import uuid import torch from minio import Minio import litserve as ls from PIL import Image import io from diffusers import Flux2KleinPipeline from app.utils.new_oss_client import oss_get_image, oss_upload_image, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # 保持原有的辅助函数 def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]: w_str, h_str = aspect_ratio.split(":") w, h = float(w_str), float(h_str) if w >= h: width = base_long_edge height = int(round(base_long_edge * (h / w))) else: height = base_long_edge width = int(round(base_long_edge * (w / h))) width = max(64, (width // 8) * 8) height = max(64, (height // 8) * 8) return width, height class FluxKleinAPI(ls.LitAPI): def setup(self, device): # 1. 模型初始化 self.repo_id = "black-forest-labs/FLUX.2-klein-4B" self.device = device self.dtype = torch.bfloat16 self.pipe = Flux2KleinPipeline.from_pretrained( self.repo_id, torch_dtype=self.dtype ) self.pipe.to(device) def decode_request(self, request): """ 解析请求参数并加载OSS图片的接口函数 接口入参说明(request字典结构): ---------- request : dict 核心请求参数字典,各字段说明如下: - input_image_paths : list[str] | None (可选) OSS图片路径列表,格式为 "bucket/object_name"(如 "test/typical_b/uildi/ng_space_station.png") 若不传则为None,会导致后续图片加载失败,建议必传 - width : int (可选,默认值512) 图片宽度,默认512像素 - height : int (可选,默认值512) 图片高度,默认512像素 - bucket_name : str | None (可选) OSS桶名,不传则为None - object_name : str | None (可选) OSS对象名(文件路径),不传则为None - prompt : str (可选,默认值空字符串) 文本提示词,用于模型推理等场景 - steps : int (可选,默认值28) 推理步数,控制模型生成过程的迭代次数 - guidance : float (可选,默认值4.0) 引导系数,调节提示词对生成结果的影响程度 - seed : int (可选,默认值42) 随机种子,保证生成结果的可复现性 返回值说明 ------- dict 解析后的参数字典,包含: - bucket_name: 请求中的桶名(None/字符串) - object_name: 请求中的对象名(None/字符串) - images: 从OSS加载的图片列表(按input_image_paths顺序) - prompt: 文本提示词(默认空字符串) - steps: 推理步数(默认28) - guidance: 引导系数(默认4.0) - seed: 随机种子(默认42) - height: 图片高度(默认512) - width: 图片宽度(默认512) 异常说明 ------- - 若input_image_paths非None但格式错误(无"/"分割且非空),可能导致rest[0]索引错误 - 若OSS图片加载失败(如路径不存在),oss_get_image会抛出对应异常 """ input_image_paths = request.get("input_image_paths", None) W = request.get("width", 512) H = request.get("height", 512) images = [] if input_image_paths: for path in input_image_paths: bucket, *rest = path.split("/", 1) object_name = rest[0] if rest else "" image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name) images.append(image) return { "bucket_name": request.get("bucket_name", "test"), "object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), "images": images, "prompt": request.get("prompt", ""), "steps": request.get("steps", 4), "guidance": request.get("guidance", 4.0), "seed": request.get("seed", 42), "height": H, "width": W } @torch.inference_mode() def predict(self, payload): # 3. 执行推理逻辑 images = payload.get("images", []) prompt = payload.get("prompt", "") gen = torch.Generator(device=self.device) seed = gen.seed() print(f"本次使用的随机种子是: {seed}") if images: output = self.pipe( image=images, prompt=prompt, height=payload.get("height", 512), width=payload.get("width", 512), guidance_scale=payload["guidance"], num_inference_steps=payload["steps"], generator=gen, ).images[0] else: output = self.pipe( prompt=prompt, height=payload.get("height", 512), width=payload.get("width", 512), guidance_scale=payload["guidance"], num_inference_steps=payload["steps"], generator=gen, ).images[0] image_data = io.BytesIO() output.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes) output_path = req.bucket_name + "/" + req.object_name return output_path def encode_response(self, output_path): return {"output_path": output_path} if __name__ == "__main__": # 启动服务器 api = FluxKleinAPI() server = ls.LitServer(api, accelerator="cuda", devices=1) server.run(port=8451)