This commit is contained in:
zcr
2026-03-23 10:18:30 +08:00
parent edf5ef1231
commit 8991927cd9
6 changed files with 22 additions and 44 deletions

View File

@@ -1,17 +0,0 @@
import requests
import base64
# 将你的图片转为 base64
with open("/mnt/data/workspace/Code/flux2/20260123_152354_2steps.png", "rb") as f:
img_base64 = base64.b64encode(f.read()).decode("utf-8")
response = requests.post("http://localhost:8451/predict", json={
# "image": img_base64,
"prompt": "紫色实木窗帘",
"aspect_ratio": "1:1",
"steps": 4
})
# 保存结果
with open("result.png", "wb") as f:
f.write(base64.b64decode(response.json()["image"]))

16
app/litserve_app/client.py Executable file
View File

@@ -0,0 +1,16 @@
import httpx
import asyncio
async def main():
async with httpx.AsyncClient() as client:
response = await client.post(
"http://localhost:8451/predict",
json={
"prompt": "紫色实木窗帘",
}
)
print(response.json())
asyncio.run(main())

View File

@@ -1,10 +1,10 @@
import logging
import uuid import uuid
import torch import torch
from minio import Minio from minio import Minio
import litserve as ls import litserve as ls
from PIL import Image
import io import io
from diffusers import Flux2KleinPipeline from diffusers import Flux2KleinPipeline
@@ -13,21 +13,6 @@ from app.utils.new_oss_client import oss_get_image, oss_upload_image, MINIO_URL,
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# 保持原有的辅助函数
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
w_str, h_str = aspect_ratio.split(":")
w, h = float(w_str), float(h_str)
if w >= h:
width = base_long_edge
height = int(round(base_long_edge * (h / w)))
else:
height = base_long_edge
width = int(round(base_long_edge * (w / h)))
width = max(64, (width // 8) * 8)
height = max(64, (height // 8) * 8)
return width, height
class FluxKleinAPI(ls.LitAPI): class FluxKleinAPI(ls.LitAPI):
def setup(self, device): def setup(self, device):
# 1. 模型初始化 # 1. 模型初始化
@@ -42,7 +27,7 @@ class FluxKleinAPI(ls.LitAPI):
) )
self.pipe.to(device) self.pipe.to(device)
def decode_request(self, request): async def decode_request(self, request):
""" """
解析请求参数并加载OSS图片的接口函数 解析请求参数并加载OSS图片的接口函数
@@ -67,8 +52,6 @@ class FluxKleinAPI(ls.LitAPI):
推理步数,控制模型生成过程的迭代次数 推理步数,控制模型生成过程的迭代次数
- guidance : float (可选默认值4.0) - guidance : float (可选默认值4.0)
引导系数,调节提示词对生成结果的影响程度 引导系数,调节提示词对生成结果的影响程度
- seed : int (可选默认值42)
随机种子,保证生成结果的可复现性
返回值说明 返回值说明
------- -------
@@ -80,7 +63,6 @@ class FluxKleinAPI(ls.LitAPI):
- prompt: 文本提示词(默认空字符串) - prompt: 文本提示词(默认空字符串)
- steps: 推理步数默认28 - steps: 推理步数默认28
- guidance: 引导系数默认4.0 - guidance: 引导系数默认4.0
- seed: 随机种子默认42
- height: 图片高度默认512 - height: 图片高度默认512
- width: 图片宽度默认512 - width: 图片宽度默认512
@@ -106,19 +88,15 @@ class FluxKleinAPI(ls.LitAPI):
"prompt": request.get("prompt", ""), "prompt": request.get("prompt", ""),
"steps": request.get("steps", 4), "steps": request.get("steps", 4),
"guidance": request.get("guidance", 4.0), "guidance": request.get("guidance", 4.0),
"seed": request.get("seed", 42),
"height": H, "height": H,
"width": W "width": W
} }
@torch.inference_mode() async def predict(self, payload):
def predict(self, payload):
# 3. 执行推理逻辑 # 3. 执行推理逻辑
images = payload.get("images", []) images = payload.get("images", [])
prompt = payload.get("prompt", "") prompt = payload.get("prompt", "")
gen = torch.Generator(device=self.device) gen = torch.Generator(device=self.device)
seed = gen.seed()
print(f"本次使用的随机种子是: {seed}")
if images: if images:
output = self.pipe( output = self.pipe(
image=images, image=images,
@@ -144,14 +122,15 @@ class FluxKleinAPI(ls.LitAPI):
image_bytes = image_data.read() image_bytes = image_data.read()
req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes) req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes)
output_path = req.bucket_name + "/" + req.object_name output_path = req.bucket_name + "/" + req.object_name
logging.info(f"output_path :{output_path}")
return output_path return output_path
def encode_response(self, output_path): async def encode_response(self, output_path):
return {"output_path": output_path} return {"output_path": output_path}
if __name__ == "__main__": if __name__ == "__main__":
# 启动服务器 # 启动服务器
api = FluxKleinAPI() api = FluxKleinAPI(enable_async=True)
server = ls.LitServer(api, accelerator="cuda", devices=1) server = ls.LitServer(api, accelerator="cuda", devices=1)
server.run(port=8451) server.run(port=8451)