110
app/litserve/app.py
Executable file
110
app/litserve/app.py
Executable file
@@ -0,0 +1,110 @@
|
||||
import torch
|
||||
import litserve as ls
|
||||
from PIL import Image
|
||||
import io
|
||||
import base64
|
||||
from diffusers import Flux2KleinPipeline
|
||||
|
||||
|
||||
# Helper function kept from the original implementation
|
||||
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a ``"W:H"`` aspect ratio into pixel dimensions.

    The longer side is set to ``base_long_edge`` and the shorter side is
    scaled proportionally. Both sides are then rounded down to a multiple
    of 8 and clamped to a minimum of 64.

    Args:
        aspect_ratio: Ratio written as ``"W:H"``, e.g. ``"4:3"`` or ``"9:16"``.
            Both components must be positive, finite numbers.
        base_long_edge: Target length in pixels of the longer side.

    Returns:
        ``(width, height)`` in pixels.

    Raises:
        ValueError: If ``aspect_ratio`` is not of the form ``"W:H"`` or a
            component is zero, negative, NaN or infinite.
    """
    try:
        w_str, h_str = aspect_ratio.split(":")
        w, h = float(w_str), float(h_str)
    except ValueError as err:
        raise ValueError(
            f"aspect_ratio must look like 'W:H', got {aspect_ratio!r}"
        ) from err

    # Chained comparison rejects zero, negatives, NaN (every comparison with
    # NaN is False) and infinity, all of which would otherwise divide by zero
    # or silently collapse one side to the 64 px minimum.
    infinity = float("inf")
    if not (0 < w < infinity and 0 < h < infinity):
        raise ValueError(
            f"aspect_ratio components must be positive finite numbers, "
            f"got {aspect_ratio!r}"
        )

    if w >= h:
        width = base_long_edge
        height = int(round(base_long_edge * (h / w)))
    else:
        height = base_long_edge
        width = int(round(base_long_edge * (w / h)))

    # Snap down to a multiple of 8, never going below 64 px.
    width = max(64, (width // 8) * 8)
    height = max(64, (height // 8) * 8)
    return width, height
|
||||
|
||||
|
||||
class FluxKleinAPI(ls.LitAPI):
    """LitServe API that runs the FLUX.2 klein pipeline on an input image.

    ``decode_request`` turns the request payload into a PIL image plus
    generation parameters, ``predict`` runs the diffusion pipeline, and
    ``encode_response`` returns the result as a base64-encoded PNG.
    """

    def setup(self, device):
        """Load the pipeline and move it to ``device``.

        Args:
            device: Device identifier the pipeline is moved to.
        """
        # 1. Model initialisation
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16

        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype
        )
        self.pipe.to(device)

        # NOTE(review): this prompt is stored but not read anywhere in this
        # class; confirm whether it should be combined with the user prompt.
        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        """Decode the request payload into a PIL image and generation settings.

        Args:
            request: Mapping with a required ``"image"`` entry (base64 string,
                optionally prefixed with a ``data:...;base64,`` header) and
                optional ``"prompt"``, ``"steps"``, ``"guidance"``, ``"seed"``,
                ``"aspect_ratio"`` and ``"base_long_edge"`` entries.

        Returns:
            Dict with the decoded RGB image under ``"image"`` and the
            generation settings, with defaults filled in for missing entries.

        Raises:
            ValueError: If ``"image"`` is missing or empty.
        """
        # 1. Fetch the base64 string
        image_data = request.get("image", None)
        if not image_data:
            raise ValueError("Request must contain 'image' field with base64 data.")

        # 2. Drop a "data:image/png;base64," style prefix if present
        if "," in image_data:
            image_data = image_data.split(",", 1)[1]

        # 3. Decode and convert to a PIL image
        img_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024),
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the pipeline for one decoded request.

        Args:
            payload: Dict as produced by ``decode_request``.

        Returns:
            The first image of the pipeline output.
        """
        img = payload.get("image", None)
        prompt = payload.get("prompt", "")
        width, height = aspect_to_wh(
            payload["aspect_ratio"], payload["base_long_edge"]
        )

        # Seed the generator with the requested seed; previously the seed was
        # decoded but never applied, so it had no effect on the output.
        gen = torch.Generator(device=self.device).manual_seed(
            int(payload.get("seed", 42))
        )

        if img is not None:
            # Simplified resize: stretches straight to the target size
            # (the original resize_pad_to_target helper could be used instead).
            img = img.convert("RGB").resize(
                (width, height), Image.Resampling.LANCZOS
            )

        # Single call covers both cases; ``image`` is None when no image is given.
        return self.pipe(
            image=img,
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        ).images[0]

    def encode_response(self, output):
        """Encode the generated PIL image as a base64 PNG.

        Args:
            output: Image returned by ``predict``.

        Returns:
            Dict ``{"image": <base64-encoded PNG string>}``.
        """
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start the server: one CUDA device, listening on port 8451.
    ls.LitServer(FluxKleinAPI(), accelerator="cuda", devices=1).run(port=8451)
|
||||
Reference in New Issue
Block a user