1
This commit is contained in:
@@ -35,8 +35,6 @@ class FluxKleinAPI(ls.LitAPI):
|
|||||||
----------
|
----------
|
||||||
request : dict
|
request : dict
|
||||||
核心请求参数字典,各字段说明如下:
|
核心请求参数字典,各字段说明如下:
|
||||||
- num_images_per_prompt = 1,
|
|
||||||
生成图片数量
|
|
||||||
- input_image_paths : list[str] | None (可选)
|
- input_image_paths : list[str] | None (可选)
|
||||||
OSS图片路径列表,格式为 "bucket/object_name"(如 "test/typical_b/uildi/ng_space_station.png")
|
OSS图片路径列表,格式为 "bucket/object_name"(如 "test/typical_b/uildi/ng_space_station.png")
|
||||||
若不传则为None,会导致后续图片加载失败,建议必传
|
若不传则为None,会导致后续图片加载失败,建议必传
|
||||||
@@ -63,7 +61,6 @@ class FluxKleinAPI(ls.LitAPI):
|
|||||||
- object_name: 请求中的对象名(None/字符串)
|
- object_name: 请求中的对象名(None/字符串)
|
||||||
- images: 从OSS加载的图片列表(按input_image_paths顺序)
|
- images: 从OSS加载的图片列表(按input_image_paths顺序)
|
||||||
- prompt: 文本提示词(默认空字符串)
|
- prompt: 文本提示词(默认空字符串)
|
||||||
- num_images_per_prompt: 生成图片数量 默认为1
|
|
||||||
- steps: 推理步数(默认28)
|
- steps: 推理步数(默认28)
|
||||||
- guidance: 引导系数(默认4.0)
|
- guidance: 引导系数(默认4.0)
|
||||||
- height: 图片高度(默认512)
|
- height: 图片高度(默认512)
|
||||||
@@ -75,7 +72,6 @@ class FluxKleinAPI(ls.LitAPI):
|
|||||||
- 若OSS图片加载失败(如路径不存在),oss_get_image会抛出对应异常
|
- 若OSS图片加载失败(如路径不存在),oss_get_image会抛出对应异常
|
||||||
"""
|
"""
|
||||||
input_image_paths = request.get("input_image_paths", None)
|
input_image_paths = request.get("input_image_paths", None)
|
||||||
num_images_per_prompt = request.get("num_images_per_prompt", 1)
|
|
||||||
W = request.get("width", 512)
|
W = request.get("width", 512)
|
||||||
H = request.get("height", 512)
|
H = request.get("height", 512)
|
||||||
images = []
|
images = []
|
||||||
@@ -87,14 +83,13 @@ class FluxKleinAPI(ls.LitAPI):
|
|||||||
images.append(image)
|
images.append(image)
|
||||||
return {
|
return {
|
||||||
"bucket_name": request.get("bucket_name", "test"),
|
"bucket_name": request.get("bucket_name", "test"),
|
||||||
"object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}"),
|
"object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"),
|
||||||
"images": images,
|
"images": images,
|
||||||
"prompt": request.get("prompt", ""),
|
"prompt": request.get("prompt", ""),
|
||||||
"steps": request.get("steps", 4),
|
"steps": request.get("steps", 4),
|
||||||
"guidance": request.get("guidance", 4.0),
|
"guidance": request.get("guidance", 4.0),
|
||||||
"height": H,
|
"height": H,
|
||||||
"width": W,
|
"width": W
|
||||||
"num_images_per_prompt": num_images_per_prompt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def predict(self, payload):
|
async def predict(self, payload):
|
||||||
@@ -105,43 +100,35 @@ class FluxKleinAPI(ls.LitAPI):
|
|||||||
seed = gen.seed()
|
seed = gen.seed()
|
||||||
print(f"本次使用的随机种子是: {seed}")
|
print(f"本次使用的随机种子是: {seed}")
|
||||||
if images:
|
if images:
|
||||||
outputs = self.pipe(
|
output = self.pipe(
|
||||||
image=images,
|
image=images,
|
||||||
num_images_per_prompt=payload["num_images_per_prompt"],
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
height=payload.get("height", 512),
|
height=payload.get("height", 512),
|
||||||
width=payload.get("width", 512),
|
width=payload.get("width", 512),
|
||||||
guidance_scale=payload["guidance"],
|
guidance_scale=payload["guidance"],
|
||||||
num_inference_steps=payload["steps"],
|
num_inference_steps=payload["steps"],
|
||||||
generator=gen,
|
generator=gen,
|
||||||
).images
|
).images[0]
|
||||||
else:
|
else:
|
||||||
outputs = self.pipe(
|
output = self.pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_images_per_prompt=payload["num_images_per_prompt"],
|
|
||||||
height=payload.get("height", 512),
|
height=payload.get("height", 512),
|
||||||
width=payload.get("width", 512),
|
width=payload.get("width", 512),
|
||||||
guidance_scale=payload["guidance"],
|
guidance_scale=payload["guidance"],
|
||||||
num_inference_steps=payload["steps"],
|
num_inference_steps=payload["steps"],
|
||||||
generator=gen,
|
generator=gen,
|
||||||
).images
|
).images[0]
|
||||||
outputs_path = []
|
image_data = io.BytesIO()
|
||||||
|
output.save(image_data, format='PNG')
|
||||||
|
image_data.seek(0)
|
||||||
|
image_bytes = image_data.read()
|
||||||
|
req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes)
|
||||||
|
output_path = req.bucket_name + "/" + req.object_name
|
||||||
|
logging.info(f"output_path :{output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
object_name = payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}")
|
async def encode_response(self, output_path):
|
||||||
|
return {"output_path": output_path}
|
||||||
for i, output in enumerate(outputs):
|
|
||||||
image_data = io.BytesIO()
|
|
||||||
output.save(image_data, format='PNG')
|
|
||||||
image_data.seek(0)
|
|
||||||
image_bytes = image_data.read()
|
|
||||||
req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=f"{object_name}-{i}.png", image_bytes=image_bytes)
|
|
||||||
output_path = req.bucket_name + "/" + req.object_name
|
|
||||||
outputs_path.append(output_path)
|
|
||||||
logging.info(f"outputs_path :{outputs_path}")
|
|
||||||
return outputs_path
|
|
||||||
|
|
||||||
async def encode_response(self, outputs_path):
|
|
||||||
return {"output_path": outputs_path}
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user