import base64
import io

import torch
import litserve as ls
from PIL import Image
from diffusers import Flux2KleinPipeline


def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a ``"W:H"`` aspect-ratio string into pixel dimensions.

    The longer edge is pinned to *base_long_edge* and the shorter edge is
    scaled to preserve the ratio. Both dimensions are floored to a multiple
    of 8 (diffusion VAEs require /8-divisible sizes) with a minimum of 64.

    Args:
        aspect_ratio: Ratio string such as ``"4:3"`` or ``"9:16"``.
        base_long_edge: Target length in pixels of the longer edge.

    Returns:
        ``(width, height)`` tuple in pixels.

    Raises:
        ValueError: If *aspect_ratio* is not two ``:``-separated numbers.
    """
    w_str, h_str = aspect_ratio.split(":")
    w, h = float(w_str), float(h_str)
    if w >= h:
        width = base_long_edge
        height = int(round(base_long_edge * (h / w)))
    else:
        height = base_long_edge
        width = int(round(base_long_edge * (w / h)))
    # Snap to the /8 grid, never below 64px.
    width = max(64, (width // 8) * 8)
    height = max(64, (height // 8) * 8)
    return width, height


class FluxKleinAPI(ls.LitAPI):
    """LitServe API serving the FLUX.2-klein-4B image-generation pipeline.

    Accepts JSON requests carrying an optional base64 input image plus
    generation parameters, and returns the generated image as base64 PNG.
    """

    def setup(self, device):
        """Load the pipeline once per worker and move it to *device*."""
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16
        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id, torch_dtype=self.dtype
        )
        self.pipe.to(device)
        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        """Parse a JSON request into a payload dict for :meth:`predict`.

        The ``image`` field is optional: when present it must be base64
        (with or without a ``data:image/...;base64,`` prefix) and is decoded
        into an RGB PIL image; when absent the request is treated as pure
        text-to-image. (Previously a missing image raised ValueError, which
        made predict's text-to-image branch unreachable.)

        Raises:
            binascii.Error / PIL.UnidentifiedImageError: If ``image`` is
                present but not valid base64-encoded image data.
        """
        img = None
        image_data = request.get("image", None)
        if image_data:
            # Strip a data-URL prefix such as "data:image/png;base64," if present.
            if "," in image_data:
                image_data = image_data.split(",")[1]
            img_bytes = base64.b64decode(image_data)
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024),
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the diffusion pipeline and return the generated PIL image."""
        img = payload.get("image", None)
        W, H = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])

        # Seed the generator for reproducibility. (Fixes a bug where the
        # decoded "seed" field was never applied to the generator.)
        gen = torch.Generator(device=self.device)
        seed = payload.get("seed", None)
        if seed is not None:
            gen.manual_seed(int(seed))

        kwargs = dict(
            prompt=payload.get("prompt", ""),
            height=H,
            width=W,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        )
        if img is not None:
            # Simplified resize; a production path could use a
            # resize-and-pad-to-target strategy instead of plain LANCZOS.
            kwargs["image"] = img.convert("RGB").resize(
                (W, H), Image.Resampling.LANCZOS
            )
        # (The original if/else branches were identical duplicates; the
        # image kwarg is now only passed when an input image exists.)
        return self.pipe(**kwargs).images[0]

    def encode_response(self, output):
        """Serialize the output PIL image as base64-encoded PNG."""
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}


if __name__ == "__main__":
    # Start the inference server on a single CUDA device.
    api = FluxKleinAPI()
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8451)