This commit is contained in:
zcr
2026-03-30 16:38:42 +08:00
parent 36dc3d08a1
commit 453a8fa33e

View File

@@ -35,6 +35,8 @@ class FluxKleinAPI(ls.LitAPI):
----------
request : dict
核心请求参数字典,各字段说明如下:
- num_images_per_prompt = 1,
生成图片数量
- input_image_paths : list[str] | None (可选)
OSS图片路径列表格式为 "bucket/object_name"(如 "test/typical_b/uildi/ng_space_station.png"
若不传则为None会导致后续图片加载失败建议必传
@@ -61,6 +63,7 @@ class FluxKleinAPI(ls.LitAPI):
- object_name: 请求中的对象名None/字符串)
- images: 从OSS加载的图片列表按input_image_paths顺序
- prompt: 文本提示词(默认空字符串)
- num_images_per_prompt: 生成图片数量 默认为1
- steps: 推理步数默认28
- guidance: 引导系数默认4.0
- height: 图片高度默认512
@@ -72,6 +75,7 @@ class FluxKleinAPI(ls.LitAPI):
- 若OSS图片加载失败如路径不存在oss_get_image会抛出对应异常
"""
input_image_paths = request.get("input_image_paths", None)
num_images_per_prompt = request.get("num_images_per_prompt", 1)
W = request.get("width", 512)
H = request.get("height", 512)
images = []
@@ -83,13 +87,14 @@ class FluxKleinAPI(ls.LitAPI):
images.append(image)
return {
"bucket_name": request.get("bucket_name", "test"),
"object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"),
"object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}"),
"images": images,
"prompt": request.get("prompt", ""),
"steps": request.get("steps", 4),
"guidance": request.get("guidance", 4.0),
"height": H,
"width": W
"width": W,
"num_images_per_prompt": num_images_per_prompt
}
async def predict(self, payload):
@@ -100,35 +105,43 @@ class FluxKleinAPI(ls.LitAPI):
seed = gen.seed()
print(f"本次使用的随机种子是: {seed}")
if images:
output = self.pipe(
outputs = self.pipe(
image=images,
num_images_per_prompt=payload["num_images_per_prompt"],
prompt=prompt,
height=payload.get("height", 512),
width=payload.get("width", 512),
guidance_scale=payload["guidance"],
num_inference_steps=payload["steps"],
generator=gen,
).images[0]
).images
else:
output = self.pipe(
outputs = self.pipe(
prompt=prompt,
num_images_per_prompt=payload["num_images_per_prompt"],
height=payload.get("height", 512),
width=payload.get("width", 512),
guidance_scale=payload["guidance"],
num_inference_steps=payload["steps"],
generator=gen,
).images[0]
image_data = io.BytesIO()
output.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes)
output_path = req.bucket_name + "/" + req.object_name
logging.info(f"output_path :{output_path}")
return output_path
).images
outputs_path = []
async def encode_response(self, output_path):
return {"output_path": output_path}
object_name = payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}")
for i, output in enumerate(outputs):
image_data = io.BytesIO()
output.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=f"{object_name}-{i}.png", image_bytes=image_bytes)
output_path = req.bucket_name + "/" + req.object_name
outputs_path.append(output_path)
logging.info(f"outputs_path :{outputs_path}")
return outputs_path
async def encode_response(self, outputs_path):
return {"output_path": outputs_path}
if __name__ == "__main__":