feat: 新增flux2klein作为moodboard的localbase 模型 ; fix:

This commit is contained in:
zcr
2026-03-23 10:46:16 +08:00
parent e25f49a776
commit d93c50ce2b
3 changed files with 71 additions and 5 deletions

View File

@@ -1,11 +1,12 @@
import json import json
import logging import logging
import httpx
import requests import requests
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.core.config import settings from app.core.config import settings
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel, Flux2ToProductImgModel, GenerateSloganImageModel from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel, Flux2ToProductImgModel, GenerateSloganImageModel, GenerateImageFlux2KleinModel
from app.schemas.pose_transform import BatchPoseTransformModel from app.schemas.pose_transform import BatchPoseTransformModel
from app.schemas.response_template import ResponseModel from app.schemas.response_template import ResponseModel
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
@@ -22,6 +23,57 @@ logger = logging.getLogger()
'''generate image''' '''generate image'''
# flux2 klein
@router.post("/generate_image_flux2_klein")
async def generate_image_flux2_klein(request_item: GenerateImageFlux2KleinModel):
"""
创建一个具有以下参数的请求体:
- **bucket_name**: OSS桶名 (必填)
- **object_name**: OSS对象名文件路径(必填)
- **width**: 图片宽度默认1024像素 (非必填,1024)
- **height**: 图片高度默认1024像素 (非必填,默认1024)
- **prompt**: 文本提示词,用于模型推理等场景 (非必填,默认"")
- **steps**: 推理步数,控制模型生成过程的迭代次数 (非必填,默认4)
- **guidance**: 引导系数,调节提示词对生成结果的影响程度 (非必填,默认 4.0 )
### 示例参数:
```
{
"bucket_name": "aida-users",
"object_name": "89/moodboard/5fdc698c-cb9b-4b36-afa9ce4-1-89.png",
"prompt": "a single item of sketch of dress, 4k, white background"
}
```
### 输出示例:
```
{
"code": 200,
"msg": "OK!",
"data": {
"output_path": "aida-users/89/moodboard/5fdc698c-cb9b-4b36-afa9ce4-1-89.png"
}
}
```
"""
try:
logger.info(f"generate_image_flux2_gen_img request: {json.dumps(request_item.model_dump(), indent=4)}")
async with httpx.AsyncClient(timeout=120) as client:
resp = await client.post(
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
json=request_item.model_dump(),
)
result = resp.json()
logger.info(f"generate_image_flux2_gen_img response: {json.dumps(result, indent=4)}")
return ResponseModel(data=result)
except Exception as e:
logger.warning(f"generate_image_flux2_gen_img Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
# sdxl
@router.post("/generate_image") @router.post("/generate_image")
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks): def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
""" """
@@ -183,7 +235,6 @@ async def generate_slogan(request_data: GenerateSloganImageModel):
"""product image flux2.0""" """product image flux2.0"""
# @router.post("/img_to_product") # @router.post("/img_to_product")
# async def img_to_product(request_data: Flux2ToProductImgModel): # async def img_to_product(request_data: Flux2ToProductImgModel):
# """ # """

View File

@@ -64,6 +64,9 @@ class Settings(BaseSettings):
# --- Design Callback Java 接口 --- # --- Design Callback Java 接口 ---
JAVA_STREAM_API_URL: str = Field(default='', description="") JAVA_STREAM_API_URL: str = Field(default='', description="")
# --- flux2 klein model url ---
FLUX2_GEN_IMG_MODEL_URL: str = Field(default='', description="")
# --- 服务器IP --- # --- 服务器IP ---
A6000_SERVICE_HOST: str = Field(default='', description="") A6000_SERVICE_HOST: str = Field(default='', description="")
B_4_X_4090_SERVICE_HOST: str = Field(default='', description="") B_4_X_4090_SERVICE_HOST: str = Field(default='', description="")

View File

@@ -1,6 +1,6 @@
from typing import List from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
class GenerateMultiViewModel(BaseModel): class GenerateMultiViewModel(BaseModel):
@@ -8,8 +8,20 @@ class GenerateMultiViewModel(BaseModel):
image_url: str image_url: str
class GenerateImageFlux2KleinModel(BaseModel):
bucket_name: str = Field(..., description="OSS桶名不传则为None")
object_name: str = Field(..., description="OSS对象名文件路径不传则为None")
# input_image_paths: Optional[List[str]] = Field(default=[], description="输入图片路径列表")
width: Optional[int] = Field(default=1024, description="图片宽度默认512像素")
height: Optional[int] = Field(default=1024, description="图片高度默认512像素")
prompt: Optional[str] = Field(default="", description="文本提示词,用于模型推理等场景")
steps: Optional[int] = Field(default=4, description="推理步数,控制模型生成过程的迭代次数")
guidance: Optional[float] = Field(default=4.0, description="引导系数,调节提示词对生成结果的影响程度")
class GenerateImageModel(BaseModel): class GenerateImageModel(BaseModel):
tasks_id: str bucket_name: str
object_name: str
prompt: str prompt: str
image_url: str image_url: str
mode: str mode: str