From 453a8fa33e4bdef630552496ebdba204f9956d84 Mon Sep 17 00:00:00 2001 From: zcr Date: Mon, 30 Mar 2026 16:38:42 +0800 Subject: [PATCH] 1 --- app/litserve_serve.py | 45 ++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/app/litserve_serve.py b/app/litserve_serve.py index f5f8814..1e8f7cc 100755 --- a/app/litserve_serve.py +++ b/app/litserve_serve.py @@ -35,6 +35,8 @@ class FluxKleinAPI(ls.LitAPI): ---------- request : dict 核心请求参数字典,各字段说明如下: + - num_images_per_prompt = 1, + 生成图片数量 - input_image_paths : list[str] | None (可选) OSS图片路径列表,格式为 "bucket/object_name"(如 "test/typical_b/uildi/ng_space_station.png") 若不传则为None,会导致后续图片加载失败,建议必传 @@ -61,6 +63,7 @@ class FluxKleinAPI(ls.LitAPI): - object_name: 请求中的对象名(None/字符串) - images: 从OSS加载的图片列表(按input_image_paths顺序) - prompt: 文本提示词(默认空字符串) + - num_images_per_prompt: 生成图片数量 默认为1 - steps: 推理步数(默认28) - guidance: 引导系数(默认4.0) - height: 图片高度(默认512) @@ -72,6 +75,7 @@ class FluxKleinAPI(ls.LitAPI): - 若OSS图片加载失败(如路径不存在),oss_get_image会抛出对应异常 """ input_image_paths = request.get("input_image_paths", None) + num_images_per_prompt = request.get("num_images_per_prompt", 1) W = request.get("width", 512) H = request.get("height", 512) images = [] @@ -83,13 +87,14 @@ class FluxKleinAPI(ls.LitAPI): images.append(image) return { "bucket_name": request.get("bucket_name", "test"), - "object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), + "object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}"), "images": images, "prompt": request.get("prompt", ""), "steps": request.get("steps", 4), "guidance": request.get("guidance", 4.0), "height": H, - "width": W + "width": W, + "num_images_per_prompt": num_images_per_prompt } async def predict(self, payload): @@ -100,35 +105,43 @@ class FluxKleinAPI(ls.LitAPI): seed = gen.seed() print(f"本次使用的随机种子是: {seed}") if images: - output = self.pipe( + outputs = self.pipe( image=images, + num_images_per_prompt=payload["num_images_per_prompt"], prompt=prompt, height=payload.get("height", 512), width=payload.get("width", 512), guidance_scale=payload["guidance"], num_inference_steps=payload["steps"], generator=gen, - ).images[0] + ).images else: - output = self.pipe( + outputs = self.pipe( prompt=prompt, + num_images_per_prompt=payload["num_images_per_prompt"], height=payload.get("height", 512), width=payload.get("width", 512), guidance_scale=payload["guidance"], num_inference_steps=payload["steps"], generator=gen, - ).images[0] - image_data = io.BytesIO() - output.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes) - output_path = req.bucket_name + "/" + req.object_name - logging.info(f"output_path :{output_path}") - return output_path + ).images + outputs_path = [] - async def encode_response(self, output_path): - return {"output_path": output_path} + object_name = payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}") + + for i, output in enumerate(outputs): + image_data = io.BytesIO() + output.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=f"{object_name}-{i}.png", image_bytes=image_bytes) + output_path = req.bucket_name + "/" + req.object_name + outputs_path.append(output_path) + logging.info(f"outputs_path :{outputs_path}") + return outputs_path + + async def encode_response(self, outputs_path): + return {"output_path": outputs_path} if __name__ == "__main__":