111 lines
3.6 KiB
Python
Executable File
111 lines
3.6 KiB
Python
Executable File
import torch
|
|
import litserve as ls
|
|
from PIL import Image
|
|
import io
|
|
import base64
|
|
from diffusers import Flux2KleinPipeline
|
|
|
|
|
|
# Helper kept from the original script.
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a ``"W:H"`` aspect-ratio string into pixel dimensions.

    The longer edge is pinned to *base_long_edge* and the shorter edge is
    derived from the ratio. Both edges are floored to a multiple of 8 and
    clamped to at least 64 px (latent-diffusion pipelines expect /8 sizes).
    """
    ratio_w, ratio_h = (float(part) for part in aspect_ratio.split(":"))

    if ratio_w >= ratio_h:
        # Landscape or square: width is the long edge.
        width = base_long_edge
        height = int(round(base_long_edge * (ratio_h / ratio_w)))
    else:
        # Portrait: height is the long edge.
        height = base_long_edge
        width = int(round(base_long_edge * (ratio_w / ratio_h)))

    def _snap(edge: int) -> int:
        # Floor to a multiple of 8, never below 64.
        return max(64, edge - edge % 8)

    return _snap(width), _snap(height)
|
|
|
|
|
class FluxKleinAPI(ls.LitAPI):
    """LitServe API serving FLUX.2-klein-4B image generation.

    Requests carry a base64-encoded input image plus generation parameters;
    responses return the generated image as a base64-encoded PNG string.
    """

    def setup(self, device):
        """Load the FLUX.2 pipeline onto *device* once per worker.

        Downloads/loads the pretrained pipeline in bfloat16 and moves it to
        the target accelerator.
        """
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16

        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype,
        )
        self.pipe.to(device)

        # NOTE(review): defined but never referenced anywhere in this file —
        # presumably intended to be prepended to the user prompt; confirm.
        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        """Validate the request and decode the base64 image into a PIL image.

        Returns a payload dict with the decoded image and all generation
        parameters (defaults applied). Raises ValueError when the required
        'image' field is missing or empty.
        """
        # 1. Fetch the base64 string.
        image_data = request.get("image", None)
        if not image_data:
            raise ValueError("Request must contain 'image' field with base64 data.")

        # 2. Strip a "data:image/png;base64," style prefix if present.
        if "," in image_data:
            image_data = image_data.split(",")[1]

        # 3. Decode and convert to a PIL RGB image.
        img_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024),
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the diffusion pipeline and return the generated PIL image."""
        img = payload.get("image", None)
        prompt = payload.get("prompt", "")
        width, height = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])

        # BUG FIX: the requested seed was decoded but never applied, so
        # identical requests produced different images. Seed the generator.
        gen = torch.Generator(device=self.device)
        gen.manual_seed(payload["seed"])

        if img is not None:
            # Simplified resize; a production version may want to letterbox
            # (resize_pad_to_target) instead of distorting the aspect ratio.
            img = img.convert("RGB").resize((width, height), Image.Resampling.LANCZOS)

        # The original if/else branches were byte-identical (both passed
        # image=img), so a single call suffices. img is kept defensively
        # optional even though decode_request currently always supplies it.
        output = self.pipe(
            image=img,
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        ).images[0]
        return output

    def encode_response(self, output):
        """Serialize the generated PIL image back to base64 PNG."""
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}
|
|
|
|
|
if __name__ == "__main__":
    # Start the inference server: one API instance on a single CUDA device,
    # listening on port 8451.
    api = FluxKleinAPI()
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8451)
|