Files
FiDA-Gen-Img-Flux2-klein/app/litserve_serve.py

139 lines
5.5 KiB
Python
Raw Normal View History

2026-03-23 10:18:30 +08:00
import io
import logging
import uuid

import torch
import litserve as ls
from diffusers import Flux2KleinPipeline
from minio import Minio

from app.utils.new_oss_client import oss_get_image, oss_upload_image, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE

# Module-level MinIO client shared by all request handlers; credentials and
# endpoint come from app.utils.new_oss_client configuration.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class FluxKleinAPI(ls.LitAPI):
    """LitServe API wrapping the FLUX.2-klein image-generation pipeline.

    Requests may reference input images by MinIO/OSS path ("bucket/object");
    the generated PNG is uploaded back to OSS and its "bucket/object" path is
    returned in the response body.
    """

    def setup(self, device):
        """Load the FLUX.2-klein pipeline onto *device* (called once per worker)."""
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16
        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            cache_dir="./checkpoint/",
            torch_dtype=self.dtype,
        )
        self.pipe.to(device)

    async def decode_request(self, request):
        """Parse request parameters and fetch any referenced OSS input images.

        Parameters
        ----------
        request : dict
            - input_image_paths : list[str], optional
              OSS paths of the form "bucket/object_name", e.g.
              "test/typical_b/uildi/ng_space_station.png". Omitted or empty
              means pure text-to-image generation (no input images).
            - width : int, optional (default 512)
            - height : int, optional (default 512)
            - bucket_name : str, optional (default "test")
              Destination bucket for the generated image.
            - object_name : str, optional
              Destination object name; defaults to a random
              "fida_generate_image/<uuid>.png" path.
            - prompt : str, optional (default "")
            - steps : int, optional (default 4)
              Number of inference steps.
            - guidance : float, optional (default 4.0)
              Guidance scale.

        Returns
        -------
        dict
            Keys: bucket_name, object_name, images (loaded PIL images in
            input order), prompt, steps, guidance, height, width.

        Raises
        ------
        Whatever ``oss_get_image`` raises when an input path cannot be
        loaded (e.g. the object does not exist).
        """
        input_image_paths = request.get("input_image_paths", None)
        width = request.get("width", 512)
        height = request.get("height", 512)
        images = []
        if input_image_paths:
            for path in input_image_paths:
                # Split only on the first "/" so object names may themselves
                # contain slashes; a path with no "/" yields an empty object name.
                bucket, *rest = path.split("/", 1)
                object_name = rest[0] if rest else ""
                images.append(oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name))
        return {
            "bucket_name": request.get("bucket_name", "test"),
            "object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"),
            "images": images,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 4),
            "guidance": request.get("guidance", 4.0),
            "height": height,
            "width": width,
        }

    async def predict(self, payload):
        """Run the diffusion pipeline and upload the resulting PNG to OSS.

        Returns the "bucket/object" path of the uploaded image.
        """
        images = payload.get("images", [])
        prompt = payload.get("prompt", "")
        gen = torch.Generator(device=self.device)
        # seed() draws a fresh non-deterministic seed; log it so a given
        # generation can be reproduced later.
        seed = gen.seed()
        logging.info("random seed for this request: %s", seed)

        # Both text-to-image and image-to-image share every argument except
        # `image`, so build the kwargs once instead of duplicating the call.
        pipe_kwargs = dict(
            prompt=prompt,
            height=payload.get("height", 512),
            width=payload.get("width", 512),
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        )
        if images:
            pipe_kwargs["image"] = images
        output = self.pipe(**pipe_kwargs).images[0]

        buffer = io.BytesIO()
        output.save(buffer, format='PNG')
        image_bytes = buffer.getvalue()
        req = oss_upload_image(
            oss_client=minio_client,
            bucket=payload.get("bucket_name", "test"),
            object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"),
            image_bytes=image_bytes,
        )
        output_path = req.bucket_name + "/" + req.object_name
        logging.info("output_path :%s", output_path)
        return output_path

    async def encode_response(self, output_path):
        """Wrap the OSS output path in the JSON response body."""
        return {"output_path": output_path}
if __name__ == "__main__":
    # Start the LitServe server on a single CUDA device with async hooks enabled.
    api = FluxKleinAPI(enable_async=True)
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8451)