feat(新功能): 新增 agent 工具与图片生成接口

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-06-30 11:29:19 +08:00
parent e087638828
commit 8cfe67c256
5 changed files with 220 additions and 4 deletions

View File

@@ -3,10 +3,11 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel
from app.schemas.pose_transform import BatchPoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
from app.service.generate_image.service_agent_tool_generate_image import AgentToolGenerateImage
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
@@ -304,3 +305,49 @@ async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformMo
}
"""
return await start_pose_transform_batch_generate(request_batch_item)
"""agent tool"""
@router.post("/agent_tool_generate_image")
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **prompt**: 想要生成图片的描述词
- **category**: 生成图片的类别sketch print 等等
- **gender**: 生成sketch专用服装类别
- **version**: 使用模型版本 fast 或者 high
- **size**: 生成数量
- **version**: 使用模型版本 fast 或者 high
示例参数:
{
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
"category": "sketch",
"gender": "male",
"size":2,
"version":"high"
}
"""
try:
logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
request_data = request_item.dict()
service = AgentToolGenerateImage(request_data['version'])
image_url_list, clothing_category_list = service.get_result(
prompt=request_data['prompt'],
size=request_data['size'],
version=request_data['version'],
category=request_data['category'],
gender=request_data['gender']
)
data = {
"image_url_list": image_url_list,
"clothing_category_list": clothing_category_list
}
logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
except Exception as e:
logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -75,3 +75,16 @@ class BatchGenerateRelightImageModel(BaseModel):
batch_tasks_id: str
user_id: str
batch_data_list: List[RelightItemModel]
"""
agent tool generate image
"""
class AgentTollGenerateImageModel(BaseModel):
prompt: str
category: str
gender: str
version: str
size: int

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
Agent-tool image generation service.

Calls Triton inference servers to generate images from a text prompt,
uploads the results to object storage, and (for sketches) classifies the
garment category.

@Project trinity_client
@Author  周成融
"""
import logging
import time
import uuid

import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

# NOTE(review): star import supplies MINIO_*, *_MODEL_URL/NAME and CATEGORY_PATH.
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image

logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class AgentToolGenerateImage:
    """Generate images via Triton inference servers and upload them to MinIO.

    NOTE(review): both inference clients are closed in get_result's ``finally``
    block, so an instance is effectively single-use — confirm callers never
    call get_result twice on the same instance.
    """

    def __init__(self, version):
        # "fast" routes to the low-latency generation endpoint; any other
        # value (e.g. "high") uses the default model endpoint.
        if version == "fast":
            self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
        else:
            self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        # Random placeholder: the model's "input_image" tensor must be supplied
        # even in txt2img mode.
        self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        # HTTP client for the sketch-category classification model.
        self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)

    def get_result(self, prompt, size, version, category, gender):
        """Generate ``size`` images for ``prompt`` and upload each to MinIO.

        Returns ``(image_url_list, clothing_category_list)``; the category
        list is populated only when ``category == "sketch"``. On failure the
        lists accumulated so far are returned (best-effort semantics).
        """
        image_url_list = []
        image_result_list = []
        clothing_category_list = []
        try:
            # Build single-item batched inputs for the generation model
            # (the original multiplied each list by 1, which was a no-op).
            text_obj = np.array([prompt], dtype="object").reshape((-1, 1))
            mode_obj = np.array(["txt2img"], dtype="object").reshape((-1, 1))
            image_obj = np.array([self.image.astype(np.float16)], dtype=np.float16).reshape((-1, 1024, 1024, 3))
            input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
            input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
            input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
            input_text.set_data_from_numpy(text_obj)
            input_image.set_data_from_numpy(image_obj)
            input_mode.set_data_from_numpy(mode_obj)
            inputs = [input_text, input_image, input_mode]
            for i in range(size):
                if version == "fast":
                    response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
                else:
                    response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
                image = response.as_numpy("generated_image")
                # Model output is RGB; OpenCV encodes BGR.
                image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
                _, img_byte_array = cv2.imencode('.jpg', image_result)
                # TODO(review): bucket name 'test' looks like a placeholder — confirm.
                req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
                image_url_list.append(f"{req.bucket_name}/{req.object_name}")
                image_result_list.append(image_result)
            if category == "sketch":
                clothing_category_list = self.get_clothing_category(image_result_list, gender)
            return image_url_list, clothing_category_list
        except Exception as e:
            # Best-effort: log and return whatever was produced before the failure.
            logger.error(e)
            return image_url_list, clothing_category_list
        finally:
            # Closing here makes the instance single-use (see class docstring).
            self.grpc_client.close()
            self.triton_client.close()

    def preprocess(self, img):
        """Resize to 224x224, ImageNet-normalize, and return an NCHW float array."""
        img = mmcv.imread(img)
        img = cv2.resize(img, (224, 224))
        img = mmcv.imnormalize(
            img,
            mean=np.array([123.675, 116.28, 103.53]),
            std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True)
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)

    def get_category(self, image):
        """Classify one preprocessed image and return its label name from attr_type."""
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
        # (the original wrapped 'output__0' in a placeholder-less f-string)
        inference_output = torch.from_numpy(results.as_numpy('output__0'))
        scores = inference_output.detach().numpy()
        colattr = list(attr_type['labelName'])
        maxsc = np.max(scores[0][:5])
        # NOTE(review): the max is taken over the first 5 scores, but argwhere
        # scans all of them — an equal score outside the first 5 would win.
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return colattr[indexs[0]]

    def get_clothing_category(self, images, gender):
        """Map each generated sketch to a clothing category label.

        Male sketches are coarsened to Bottoms/Tops/Outwear; any other
        gender gets the raw classifier label.
        """
        category_list = []
        for image in images:
            sketch = self.preprocess(image)
            if gender.lower() == "male":
                category = self.get_category(sketch)
                if category in ('Trousers', 'Skirt'):
                    category_list.append('Bottoms')
                elif category in ('Blouse', 'Dress'):
                    category_list.append('Tops')
                else:
                    category_list.append('Outwear')
            else:
                category_list.append(self.get_category(sketch))
        return category_list
attr_type = pd.read_csv(CATEGORY_PATH)

if __name__ == '__main__':
    # Demo invocation — requires live Triton/MinIO backends.
    demo_request = {
        "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
        "category": "sketch",
        "version": "high",
        "size": 2,
        "gender": "Female",
    }
    generator = AgentToolGenerateImage(demo_request['version'])
    urls, categories = generator.get_result(
        prompt=demo_request['prompt'],
        size=demo_request['size'],
        version=demo_request['version'],
        category=demo_request['category'],
        gender=demo_request['gender'],
    )
    print(urls)
    print(categories)

View File

@@ -186,7 +186,7 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
mode='txt2img',
category="test",

View File

@@ -15,6 +15,9 @@ process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
class ProjectInfoExtraction:
def __init__(self, request_data):
# llm generate brand info init
if len(request_data.image_list) or len(request_data.file_list):
self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
else:
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
self.response_schemas = [
@@ -55,7 +58,11 @@ class ProjectInfoExtraction:
if __name__ == '__main__':
request_data = ProjectInfoExtractionModel(
prompt="海边派对主题的衬衫设计"
prompt="性别为儿童",
image_list=[
'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
],
file_list=[]
)
service = ProjectInfoExtraction(request_data)
print(service.get_result())