Files
FiDA-Gen-Img-Flux2-klein/app/litserve_serve.py
zcr c8fb30aaee
Some checks failed
CI / lint (push) Failing after 10s
1
2026-03-18 11:18:49 +08:00

158 lines
6.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import uuid
import torch
from minio import Minio
import litserve as ls
from PIL import Image
import io
from diffusers import Flux2KleinPipeline
from app.utils.new_oss_client import oss_get_image, oss_upload_image, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
# Module-level MinIO client, built once at import time from the connection
# settings exported by app.utils.new_oss_client and reused by every request.
minio_client = Minio(
    MINIO_URL,
    access_key=MINIO_ACCESS,
    secret_key=MINIO_SECRET,
    secure=MINIO_SECURE,
)
# Helper function kept from the original implementation.
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a ``"W:H"`` aspect ratio into pixel dimensions.

    The longer side is set to ``base_long_edge`` and the shorter side is
    scaled proportionally. Both sides are then rounded down to a multiple
    of 8 and clamped to a minimum of 64.

    Args:
        aspect_ratio: Ratio string with exactly one ``":"``, e.g. ``"16:9"``.
        base_long_edge: Target length of the longer side before snapping.

    Returns:
        A ``(width, height)`` tuple.

    Raises:
        ValueError: If ``aspect_ratio`` does not split into two numbers.
    """

    def _snap(pixels):
        # Floor to a multiple of 8, but never go below 64.
        return max(64, (pixels // 8) * 8)

    ratio_w, ratio_h = map(float, aspect_ratio.split(":"))

    if ratio_w >= ratio_h:
        # Landscape or square: the width is the long edge.
        short_edge = int(round(base_long_edge * (ratio_h / ratio_w)))
        return _snap(base_long_edge), _snap(short_edge)

    # Portrait: the height is the long edge.
    short_edge = int(round(base_long_edge * (ratio_w / ratio_h)))
    return _snap(short_edge), _snap(base_long_edge)
class FluxKleinAPI(ls.LitAPI):
    """LitServe API that runs FLUX.2-klein and stores the result in OSS.

    Request flow: ``decode_request`` parses the JSON body and downloads any
    reference images, ``predict`` runs the pipeline and uploads the PNG, and
    ``encode_response`` wraps the uploaded object's path.
    """

    def setup(self, device):
        """Load the FLUX.2-klein pipeline and move it to ``device``.

        Args:
            device: Device identifier handed over by LitServe.
        """
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16
        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            cache_dir="./checkpoint/",
            torch_dtype=self.dtype,
        )
        self.pipe.to(device)

    def decode_request(self, request):
        """Parse the request body and load the referenced OSS images.

        Args:
            request: Request dictionary. Recognised keys (all optional):

                - ``input_image_paths`` (list[str] | str | None): OSS paths
                  in ``"bucket/object_name"`` form. A single string is
                  treated as a one-element list. When missing or empty no
                  images are loaded and ``predict`` runs without ``image``.
                - ``width`` (int): Output width, default 512.
                - ``height`` (int): Output height, default 512.
                - ``bucket_name`` (str): Destination bucket, default
                  ``"test"``.
                - ``object_name`` (str): Destination object key, default
                  ``"fida_generate_image/<uuid4 hex>.png"``.
                - ``prompt`` (str): Text prompt, default ``""``.
                - ``steps`` (int): Number of inference steps, default 4.
                - ``guidance`` (float): Guidance scale, default 4.0.
                - ``seed`` (int | None): Random seed. When omitted,
                  ``predict`` draws a fresh random seed for the request.

        Returns:
            dict with keys ``bucket_name``, ``object_name``, ``images``
            (loaded in the order of ``input_image_paths``), ``prompt``,
            ``steps``, ``guidance``, ``seed``, ``height`` and ``width``.

        Raises:
            ValueError: If an input path is not of the form
                ``"bucket/object_name"``.
            Exception: Whatever ``oss_get_image`` raises when an object
                cannot be fetched.
        """
        input_image_paths = request.get("input_image_paths", None)
        # A bare string would otherwise be iterated character by character.
        if isinstance(input_image_paths, str):
            input_image_paths = [input_image_paths]

        images = []
        for path in input_image_paths or []:
            # Split on the first "/" only: object keys may contain slashes.
            bucket, sep, object_name = path.partition("/")
            if not (sep and bucket and object_name):
                raise ValueError(
                    "input_image_paths entries must look like "
                    f"'bucket/object_name', got {path!r}"
                )
            images.append(
                oss_get_image(
                    oss_client=minio_client,
                    bucket=bucket,
                    object_name=object_name,
                )
            )

        return {
            "bucket_name": request.get("bucket_name", "test"),
            "object_name": request.get(
                "object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"
            ),
            "images": images,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 4),
            "guidance": request.get("guidance", 4.0),
            # None means "no seed requested"; predict then picks a random one.
            "seed": request.get("seed"),
            "height": request.get("height", 512),
            "width": request.get("width", 512),
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the pipeline, upload the PNG to OSS and return its path.

        Args:
            payload: Dictionary produced by ``decode_request``.

        Returns:
            ``"<bucket_name>/<object_name>"`` of the uploaded image, built
            from the object returned by ``oss_upload_image``.
        """
        images = payload.get("images", [])

        # Honour an explicitly requested seed so results are reproducible;
        # otherwise draw a non-deterministic seed for this request.
        gen = torch.Generator(device=self.device)
        seed = payload.get("seed")
        if seed is None:
            seed = gen.seed()
        else:
            seed = int(seed)
            gen.manual_seed(seed)
        print(f"本次使用的随机种子是: {seed}")

        pipe_kwargs = {
            "prompt": payload.get("prompt", ""),
            "height": payload.get("height", 512),
            "width": payload.get("width", 512),
            "guidance_scale": payload["guidance"],
            "num_inference_steps": payload["steps"],
            "generator": gen,
        }
        # Reference images are only passed when the request supplied some.
        if images:
            pipe_kwargs["image"] = images
        output = self.pipe(**pipe_kwargs).images[0]

        # Serialise the result to PNG bytes in memory for the upload.
        buffer = io.BytesIO()
        output.save(buffer, format='PNG')
        image_bytes = buffer.getvalue()

        req = oss_upload_image(
            oss_client=minio_client,
            bucket=payload.get("bucket_name", "test"),
            object_name=payload.get(
                "object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"
            ),
            image_bytes=image_bytes,
        )
        return req.bucket_name + "/" + req.object_name

    def encode_response(self, output_path):
        """Wrap the uploaded image path in the JSON response body."""
        return {"output_path": output_path}
if __name__ == "__main__":
    # Start the LitServe server: one CUDA device, listening on port 8451.
    ls.LitServer(FluxKleinAPI(), accelerator="cuda", devices=1).run(port=8451)