Files
FiDA-Gen-Img-Flux2-klein/app/litserve/app.py
zcr b3eda2f7c7
Some checks failed
CI / lint (push) Failing after 11s
1
2026-03-18 11:09:01 +08:00

111 lines
3.6 KiB
Python
Executable File

import torch
import litserve as ls
from PIL import Image
import io
import base64
from diffusers import Flux2KleinPipeline
# Helper kept from the original script (aspect-ratio -> pixel dimensions).
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert an aspect-ratio string like ``"4:3"`` into ``(width, height)``.

    The longer edge is set to ``base_long_edge`` and the shorter edge is
    scaled proportionally. Both dimensions are floored to a multiple of 8
    (a common diffusion-pipeline constraint) and clamped to at least 64 px.

    Args:
        aspect_ratio: Ratio string in ``"W:H"`` form, e.g. ``"16:9"``.
            Both parts must parse as positive numbers.
        base_long_edge: Target pixel length of the longer edge.

    Returns:
        ``(width, height)`` tuple; each value is a multiple of 8 and >= 64.

    Raises:
        ValueError: If ``aspect_ratio`` is malformed or either part is not
            a positive number (previously a malformed string surfaced as an
            unpacking error and ``"0:0"`` as a ``ZeroDivisionError``).
    """
    try:
        w_str, h_str = aspect_ratio.split(":")
        w, h = float(w_str), float(h_str)
    except ValueError as err:
        raise ValueError(
            f"aspect_ratio must look like 'W:H', got {aspect_ratio!r}"
        ) from err
    if w <= 0 or h <= 0:
        raise ValueError(
            f"aspect_ratio parts must be positive, got {aspect_ratio!r}"
        )
    if w >= h:
        width = base_long_edge
        height = int(round(base_long_edge * (h / w)))
    else:
        height = base_long_edge
        width = int(round(base_long_edge * (w / h)))
    # Floor to a multiple of 8 and enforce a 64 px minimum per edge.
    width = max(64, (width // 8) * 8)
    height = max(64, (height // 8) * 8)
    return width, height
class FluxKleinAPI(ls.LitAPI):
    """LitServe API wrapping the FLUX.2-klein diffusers pipeline.

    Accepts a JSON request containing a base64-encoded input image plus
    prompt / sampling parameters, runs the pipeline, and returns the
    generated image as a base64-encoded PNG.
    """

    def setup(self, device):
        """Load the FLUX.2-klein pipeline onto ``device`` (called once per worker)."""
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16
        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype,
        )
        self.pipe.to(device)
        # Default style prompt available to callers wanting a clean studio look.
        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        """Parse the incoming JSON request into a payload dict for ``predict``.

        Decodes the required base64 ``image`` field into a PIL RGB image and
        fills in defaults for the optional sampling parameters.

        Raises:
            ValueError: If the ``image`` field is missing or empty.
        """
        image_data = request.get("image", None)
        if not image_data:
            raise ValueError("Request must contain 'image' field with base64 data.")
        # Strip a data-URL prefix such as "data:image/png;base64," if present.
        if "," in image_data:
            image_data = image_data.split(",")[1]
        img_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024),
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the diffusion pipeline and return the generated PIL image."""
        img = payload.get("image", None)
        prompt = payload.get("prompt", "")
        W, H = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])
        # BUG FIX: the request's "seed" was decoded but never applied — the
        # generator was left unseeded, so "seed" was silently ignored. Seed it
        # so a given request is reproducible.
        gen = torch.Generator(device=self.device).manual_seed(payload["seed"])
        if img is not None:
            # Simplified resize; the original project used a
            # resize_pad_to_target helper for aspect-preserving padding.
            img = img.convert("RGB").resize((W, H), Image.Resampling.LANCZOS)
        # The original if/else branches here were byte-identical (both passed
        # image=img), so the duplication is collapsed into a single call.
        output = self.pipe(
            image=img,
            prompt=prompt,
            height=H,
            width=W,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        ).images[0]
        return output

    def encode_response(self, output):
        """Serialize the output PIL image to a base64-encoded PNG string."""
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}
if __name__ == "__main__":
    # Launch the LitServe inference server on a single CUDA device.
    server = ls.LitServer(FluxKleinAPI(), accelerator="cuda", devices=1)
    server.run(port=8451)