# FLUX.2 Klein img2img inference script
"""Single-pass image-to-image generation with the FLUX.2 Klein pipeline.

Loads the 4B Klein checkpoint, runs one img2img pass over ``result1.png``,
saves the result under a timestamped filename, and prints the elapsed
inference time.
"""
import datetime
import time

import torch
from PIL import Image

from diffusers import Flux2KleinPipeline


def main() -> None:
    """Run one img2img inference pass and save the timestamped output."""
    device = "cuda"
    dtype = torch.bfloat16  # bf16 halves memory vs fp32; supported on Ampere+ GPUs

    pipe = Flux2KleinPipeline.from_pretrained(
        "black-forest-labs/FLUX.2-klein-4B",
        torch_dtype=dtype,
        is_distilled=False,
    )
    # Moves the whole pipeline onto the GPU. NOTE(review): the original
    # comment claimed this offloads to CPU — it does not; to trade speed
    # for VRAM, use pipe.enable_model_cpu_offload() instead.
    pipe.to(device)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt = ""  # empty prompt: generation is driven by the input image alone
    num_inference_steps = 4
    input_image = Image.open("result1.png")

    # perf_counter is monotonic — correct for measuring elapsed time
    # (time.time() can jump with wall-clock adjustments).
    start_time = time.perf_counter()
    image = pipe(
        image=input_image,
        prompt=prompt,
        height=768,
        width=512,
        guidance_scale=1.0,  # 1.0 effectively disables classifier-free guidance
        num_inference_steps=num_inference_steps,
        # generator=torch.Generator(device=device).manual_seed(3)  # uncomment for reproducible output
    ).images[0]
    elapsed = time.perf_counter() - start_time

    image.save(f"{timestamp}_{num_inference_steps}steps.png")
    print(f"infer time : {elapsed}")


if __name__ == "__main__":
    main()