feat(新功能): 新增agent 工具,图片生成接口
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -3,10 +3,11 @@ import logging
|
|||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
|
|
||||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel
|
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel
|
||||||
from app.schemas.pose_transform import BatchPoseTransformModel
|
from app.schemas.pose_transform import BatchPoseTransformModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
|
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
|
||||||
|
from app.service.generate_image.service_agent_tool_generate_image import AgentToolGenerateImage
|
||||||
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
||||||
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
|
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
|
||||||
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
||||||
@@ -304,3 +305,49 @@ async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformMo
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
return await start_pose_transform_batch_generate(request_batch_item)
|
return await start_pose_transform_batch_generate(request_batch_item)
|
||||||
|
|
||||||
|
|
||||||
|
"""agent tool"""
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/agent_tool_generate_image")
|
||||||
|
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **prompt**: 想要生成图片的描述词
|
||||||
|
- **category**: 生成图片的类别,sketch print 等等
|
||||||
|
- **gender**: 生成sketch专用,服装类别
|
||||||
|
- **version**: 使用模型版本 fast 或者 high
|
||||||
|
- **size**: 生成数量
|
||||||
|
- **version**: 使用模型版本 fast 或者 high
|
||||||
|
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||||
|
"category": "sketch",
|
||||||
|
"gender": "male",
|
||||||
|
"size":2,
|
||||||
|
"version":"high"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
|
||||||
|
request_data = request_item.dict()
|
||||||
|
service = AgentToolGenerateImage(request_data['version'])
|
||||||
|
image_url_list, clothing_category_list = service.get_result(
|
||||||
|
prompt=request_data['prompt'],
|
||||||
|
size=request_data['size'],
|
||||||
|
version=request_data['version'],
|
||||||
|
category=request_data['category'],
|
||||||
|
gender=request_data['gender']
|
||||||
|
)
|
||||||
|
data = {
|
||||||
|
"image_url_list": image_url_list,
|
||||||
|
"clothing_category_list": clothing_category_list
|
||||||
|
}
|
||||||
|
logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=data)
|
||||||
|
|||||||
@@ -75,3 +75,16 @@ class BatchGenerateRelightImageModel(BaseModel):
|
|||||||
batch_tasks_id: str
|
batch_tasks_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
batch_data_list: List[RelightItemModel]
|
batch_data_list: List[RelightItemModel]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
agent tool generate image
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTollGenerateImageModel(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
category: str
|
||||||
|
gender: str
|
||||||
|
version: str
|
||||||
|
size: int
|
||||||
|
|||||||
149
app/service/generate_image/service_agent_tool_generate_image.py
Normal file
149
app/service/generate_image/service_agent_tool_generate_image.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
"""
|
||||||
|
@Project :trinity_client
|
||||||
|
@File :service_att_recognition.py
|
||||||
|
@Author :周成融
|
||||||
|
@Date :2023/7/26 12:01:05
|
||||||
|
@detail :
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import cv2
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import tritonclient.http as httpclient
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
from minio import Minio
|
||||||
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
from app.core.config import *
|
||||||
|
from app.service.utils.new_oss_client import oss_upload_image
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentToolGenerateImage:
|
||||||
|
def __init__(self, version):
|
||||||
|
if version == "fast":
|
||||||
|
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
|
||||||
|
else:
|
||||||
|
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||||
|
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||||
|
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||||
|
|
||||||
|
def get_result(self, prompt, size, version, category, gender):
|
||||||
|
|
||||||
|
image_url_list = []
|
||||||
|
image_result_list = []
|
||||||
|
clothing_category_list = []
|
||||||
|
try:
|
||||||
|
prompts = [prompt] * 1
|
||||||
|
modes = ["txt2img"] * 1
|
||||||
|
images = [self.image.astype(np.float16)] * 1
|
||||||
|
|
||||||
|
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||||
|
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||||
|
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||||
|
|
||||||
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||||
|
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||||
|
|
||||||
|
input_text.set_data_from_numpy(text_obj)
|
||||||
|
input_image.set_data_from_numpy(image_obj)
|
||||||
|
input_mode.set_data_from_numpy(mode_obj)
|
||||||
|
|
||||||
|
inputs = [input_text, input_image, input_mode]
|
||||||
|
for i in range(size):
|
||||||
|
if version == "fast":
|
||||||
|
response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
|
||||||
|
else:
|
||||||
|
response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
|
||||||
|
image = response.as_numpy("generated_image")
|
||||||
|
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||||
|
_, img_byte_array = cv2.imencode('.jpg', image_result)
|
||||||
|
|
||||||
|
req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
|
||||||
|
image_url_list.append(f"{req.bucket_name}/{req.object_name}")
|
||||||
|
image_result_list.append(image_result)
|
||||||
|
|
||||||
|
if category == "sketch":
|
||||||
|
clothing_category_list = self.get_clothing_category(image_result_list, gender)
|
||||||
|
|
||||||
|
return image_url_list, clothing_category_list
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
return image_url_list, clothing_category_list
|
||||||
|
finally:
|
||||||
|
self.grpc_client.close()
|
||||||
|
self.triton_client.close()
|
||||||
|
|
||||||
|
def preprocess(self, img):
|
||||||
|
img = mmcv.imread(img)
|
||||||
|
img_scale = (224, 224)
|
||||||
|
img = cv2.resize(img, img_scale)
|
||||||
|
img = mmcv.imnormalize(
|
||||||
|
img,
|
||||||
|
mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
|
||||||
|
to_rgb=True)
|
||||||
|
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||||
|
return preprocessed_img
|
||||||
|
|
||||||
|
def get_category(self, image):
|
||||||
|
inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
|
||||||
|
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||||
|
results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
|
||||||
|
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
|
||||||
|
scores = inference_output.detach().numpy()
|
||||||
|
colattr = list(attr_type['labelName'])
|
||||||
|
maxsc = np.max(scores[0][:5])
|
||||||
|
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||||
|
return colattr[indexs[0]]
|
||||||
|
|
||||||
|
def get_clothing_category(self, images, gender):
|
||||||
|
category_list = []
|
||||||
|
for image in images:
|
||||||
|
sketch = self.preprocess(image)
|
||||||
|
if gender.lower() == "female":
|
||||||
|
category_list.append(self.get_category(sketch))
|
||||||
|
elif gender.lower() == "male":
|
||||||
|
category = self.get_category(sketch)
|
||||||
|
if category == 'Trousers' or category == 'Skirt':
|
||||||
|
category_list.append('Bottoms')
|
||||||
|
elif category == 'Blouse' or category == 'Dress':
|
||||||
|
category_list.append('Tops')
|
||||||
|
else:
|
||||||
|
category_list.append('Outwear')
|
||||||
|
else:
|
||||||
|
category_list.append(self.get_category(sketch))
|
||||||
|
return category_list
|
||||||
|
|
||||||
|
|
||||||
|
attr_type = pd.read_csv(CATEGORY_PATH)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
request_data = {
|
||||||
|
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||||
|
"category": "sketch",
|
||||||
|
"version": "high",
|
||||||
|
"size": 2,
|
||||||
|
"gender": "Female",
|
||||||
|
}
|
||||||
|
server = AgentToolGenerateImage(request_data['version'])
|
||||||
|
image_url_list, clothing_category_list = server.get_result(
|
||||||
|
prompt=request_data['prompt'],
|
||||||
|
size=request_data['size'],
|
||||||
|
version=request_data['version'],
|
||||||
|
category=request_data['category'],
|
||||||
|
gender=request_data['gender']
|
||||||
|
)
|
||||||
|
|
||||||
|
print(image_url_list)
|
||||||
|
print(clothing_category_list)
|
||||||
@@ -186,7 +186,7 @@ def infer_cancel(tasks_id):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
rd = GenerateImageModel(
|
rd = GenerateImageModel(
|
||||||
tasks_id="123-89",
|
tasks_id="123-89",
|
||||||
prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
|
prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
|
||||||
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
||||||
mode='txt2img',
|
mode='txt2img',
|
||||||
category="test",
|
category="test",
|
||||||
|
|||||||
@@ -15,7 +15,10 @@ process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
|
|||||||
class ProjectInfoExtraction:
|
class ProjectInfoExtraction:
|
||||||
def __init__(self, request_data):
|
def __init__(self, request_data):
|
||||||
# llm generate brand info init
|
# llm generate brand info init
|
||||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
if len(request_data.image_list) or len(request_data.file_list):
|
||||||
|
self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||||
|
else:
|
||||||
|
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||||
|
|
||||||
self.response_schemas = [
|
self.response_schemas = [
|
||||||
ResponseSchema(name="project_name", description="项目的名称."),
|
ResponseSchema(name="project_name", description="项目的名称."),
|
||||||
@@ -55,7 +58,11 @@ class ProjectInfoExtraction:
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
request_data = ProjectInfoExtractionModel(
|
request_data = ProjectInfoExtractionModel(
|
||||||
prompt="海边派对主题的衬衫设计"
|
prompt="性别为儿童",
|
||||||
|
image_list=[
|
||||||
|
'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
|
||||||
|
],
|
||||||
|
file_list=[]
|
||||||
)
|
)
|
||||||
service = ProjectInfoExtraction(request_data)
|
service = ProjectInfoExtraction(request_data)
|
||||||
print(service.get_result())
|
print(service.get_result())
|
||||||
|
|||||||
Reference in New Issue
Block a user