2026-03-23 10:18:30 +08:00
|
|
|
|
import logging
|
2026-03-18 10:12:06 +08:00
|
|
|
|
import uuid
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from minio import Minio
|
|
|
|
|
|
|
|
|
|
|
|
import litserve as ls
|
|
|
|
|
|
import io
|
|
|
|
|
|
from diffusers import Flux2KleinPipeline
|
|
|
|
|
|
|
2026-03-18 11:09:01 +08:00
|
|
|
|
from app.utils.new_oss_client import oss_get_image, oss_upload_image, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
2026-03-18 10:12:06 +08:00
|
|
|
|
|
|
|
|
|
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FluxKleinAPI(ls.LitAPI):
|
|
|
|
|
|
def setup(self, device):
|
|
|
|
|
|
# 1. 模型初始化
|
|
|
|
|
|
self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
|
|
|
|
|
|
self.device = device
|
|
|
|
|
|
self.dtype = torch.bfloat16
|
|
|
|
|
|
|
|
|
|
|
|
self.pipe = Flux2KleinPipeline.from_pretrained(
|
|
|
|
|
|
self.repo_id,
|
2026-03-18 11:18:49 +08:00
|
|
|
|
cache_dir="./checkpoint/",
|
2026-03-18 10:12:06 +08:00
|
|
|
|
torch_dtype=self.dtype
|
|
|
|
|
|
)
|
|
|
|
|
|
self.pipe.to(device)
|
|
|
|
|
|
|
2026-03-23 10:18:30 +08:00
|
|
|
|
async def decode_request(self, request):
|
2026-03-18 10:12:06 +08:00
|
|
|
|
"""
|
|
|
|
|
|
解析请求参数并加载OSS图片的接口函数
|
|
|
|
|
|
|
|
|
|
|
|
接口入参说明(request字典结构):
|
|
|
|
|
|
----------
|
|
|
|
|
|
request : dict
|
|
|
|
|
|
核心请求参数字典,各字段说明如下:
|
|
|
|
|
|
- input_image_paths : list[str] | None (可选)
|
|
|
|
|
|
OSS图片路径列表,格式为 "bucket/object_name"(如 "test/typical_b/uildi/ng_space_station.png")
|
|
|
|
|
|
若不传则为None,会导致后续图片加载失败,建议必传
|
|
|
|
|
|
- width : int (可选,默认值512)
|
|
|
|
|
|
图片宽度,默认512像素
|
|
|
|
|
|
- height : int (可选,默认值512)
|
|
|
|
|
|
图片高度,默认512像素
|
|
|
|
|
|
- bucket_name : str | None (可选)
|
|
|
|
|
|
OSS桶名,不传则为None
|
|
|
|
|
|
- object_name : str | None (可选)
|
|
|
|
|
|
OSS对象名(文件路径),不传则为None
|
|
|
|
|
|
- prompt : str (可选,默认值空字符串)
|
|
|
|
|
|
文本提示词,用于模型推理等场景
|
|
|
|
|
|
- steps : int (可选,默认值28)
|
|
|
|
|
|
推理步数,控制模型生成过程的迭代次数
|
|
|
|
|
|
- guidance : float (可选,默认值4.0)
|
|
|
|
|
|
引导系数,调节提示词对生成结果的影响程度
|
|
|
|
|
|
|
|
|
|
|
|
返回值说明
|
|
|
|
|
|
-------
|
|
|
|
|
|
dict
|
|
|
|
|
|
解析后的参数字典,包含:
|
|
|
|
|
|
- bucket_name: 请求中的桶名(None/字符串)
|
|
|
|
|
|
- object_name: 请求中的对象名(None/字符串)
|
|
|
|
|
|
- images: 从OSS加载的图片列表(按input_image_paths顺序)
|
|
|
|
|
|
- prompt: 文本提示词(默认空字符串)
|
|
|
|
|
|
- steps: 推理步数(默认28)
|
|
|
|
|
|
- guidance: 引导系数(默认4.0)
|
|
|
|
|
|
- height: 图片高度(默认512)
|
|
|
|
|
|
- width: 图片宽度(默认512)
|
|
|
|
|
|
|
|
|
|
|
|
异常说明
|
|
|
|
|
|
-------
|
|
|
|
|
|
- 若input_image_paths非None但格式错误(无"/"分割且非空),可能导致rest[0]索引错误
|
|
|
|
|
|
- 若OSS图片加载失败(如路径不存在),oss_get_image会抛出对应异常
|
|
|
|
|
|
"""
|
|
|
|
|
|
input_image_paths = request.get("input_image_paths", None)
|
|
|
|
|
|
W = request.get("width", 512)
|
|
|
|
|
|
H = request.get("height", 512)
|
|
|
|
|
|
images = []
|
2026-03-18 10:23:31 +08:00
|
|
|
|
if input_image_paths:
|
|
|
|
|
|
for path in input_image_paths:
|
|
|
|
|
|
bucket, *rest = path.split("/", 1)
|
|
|
|
|
|
object_name = rest[0] if rest else ""
|
|
|
|
|
|
image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
|
|
|
|
|
images.append(image)
|
2026-03-18 10:12:06 +08:00
|
|
|
|
return {
|
2026-03-18 10:23:31 +08:00
|
|
|
|
"bucket_name": request.get("bucket_name", "test"),
|
|
|
|
|
|
"object_name": request.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"),
|
2026-03-18 10:12:06 +08:00
|
|
|
|
"images": images,
|
|
|
|
|
|
"prompt": request.get("prompt", ""),
|
|
|
|
|
|
"steps": request.get("steps", 4),
|
|
|
|
|
|
"guidance": request.get("guidance", 4.0),
|
|
|
|
|
|
"height": H,
|
|
|
|
|
|
"width": W
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-23 10:18:30 +08:00
|
|
|
|
async def predict(self, payload):
|
2026-03-18 10:12:06 +08:00
|
|
|
|
# 3. 执行推理逻辑
|
|
|
|
|
|
images = payload.get("images", [])
|
|
|
|
|
|
prompt = payload.get("prompt", "")
|
|
|
|
|
|
gen = torch.Generator(device=self.device)
|
2026-03-24 16:08:21 +08:00
|
|
|
|
seed = gen.seed()
|
|
|
|
|
|
print(f"本次使用的随机种子是: {seed}")
|
2026-03-18 10:12:06 +08:00
|
|
|
|
if images:
|
2026-03-18 11:09:01 +08:00
|
|
|
|
output = self.pipe(
|
2026-03-18 10:12:06 +08:00
|
|
|
|
image=images,
|
|
|
|
|
|
prompt=prompt,
|
2026-03-18 10:23:31 +08:00
|
|
|
|
height=payload.get("height", 512),
|
|
|
|
|
|
width=payload.get("width", 512),
|
2026-03-18 10:12:06 +08:00
|
|
|
|
guidance_scale=payload["guidance"],
|
|
|
|
|
|
num_inference_steps=payload["steps"],
|
|
|
|
|
|
generator=gen,
|
|
|
|
|
|
).images[0]
|
|
|
|
|
|
else:
|
|
|
|
|
|
output = self.pipe(
|
|
|
|
|
|
prompt=prompt,
|
2026-03-18 10:23:31 +08:00
|
|
|
|
height=payload.get("height", 512),
|
|
|
|
|
|
width=payload.get("width", 512),
|
2026-03-18 10:12:06 +08:00
|
|
|
|
guidance_scale=payload["guidance"],
|
|
|
|
|
|
num_inference_steps=payload["steps"],
|
|
|
|
|
|
generator=gen,
|
|
|
|
|
|
).images[0]
|
|
|
|
|
|
image_data = io.BytesIO()
|
|
|
|
|
|
output.save(image_data, format='PNG')
|
|
|
|
|
|
image_data.seek(0)
|
|
|
|
|
|
image_bytes = image_data.read()
|
|
|
|
|
|
req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes)
|
|
|
|
|
|
output_path = req.bucket_name + "/" + req.object_name
|
2026-03-23 10:18:30 +08:00
|
|
|
|
logging.info(f"output_path :{output_path}")
|
2026-03-18 10:12:06 +08:00
|
|
|
|
return output_path
|
|
|
|
|
|
|
2026-03-23 10:18:30 +08:00
|
|
|
|
async def encode_response(self, output_path):
|
2026-03-18 10:12:06 +08:00
|
|
|
|
return {"output_path": output_path}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
# 启动服务器
|
2026-03-23 10:18:30 +08:00
|
|
|
|
api = FluxKleinAPI(enable_async=True)
|
2026-03-18 10:12:06 +08:00
|
|
|
|
server = ls.LitServer(api, accelerator="cuda", devices=1)
|
|
|
|
|
|
server.run(port=8451)
|