feat(新功能): 新增agent 工具,图片生成接口

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-06-30 11:29:19 +08:00
parent e087638828
commit 8cfe67c256
5 changed files with 220 additions and 4 deletions

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import logging
import time
import uuid
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class AgentToolGenerateImage:
    """Agent tool that generates clothing-design images through remote Triton
    inference servers and uploads each result to MinIO.

    NOTE(review): ``get_result`` closes both inference clients in its
    ``finally`` block, so every instance is effectively single-use — confirm
    this is intended before reusing an instance across calls.
    """

    def __init__(self, version):
        """Create the inference clients.

        :param version: "fast" selects the fast generation endpoint
            (FAST_GI_MODEL_URL); any other value selects GI_MODEL_URL.
        """
        if version == "fast":
            self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
        else:
            self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        # Placeholder image: the model's "input_image" tensor must be supplied
        # even in txt2img mode, so a random 1024x1024 RGB array is sent.
        self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)

    def get_result(self, prompt, size, version, category, gender):
        """Generate ``size`` images for ``prompt``, upload each to MinIO and,
        when ``category == "sketch"``, classify the clothing category of every
        generated image.

        :param prompt: text prompt for the generation model.
        :param size: number of images to generate.
        :param version: "fast" routes each infer call to the fast model name.
        :param category: "sketch" triggers clothing-category classification.
        :param gender: forwarded to :meth:`get_clothing_category`.
        :return: ``(image_url_list, clothing_category_list)``; on error the
            lists accumulated so far are returned (best-effort).
        """
        image_url_list = []
        image_result_list = []
        clothing_category_list = []
        try:
            # The model inputs are identical on every iteration, so they are
            # built once before the loop.
            prompts = [prompt] * 1
            modes = ["txt2img"] * 1
            images = [self.image.astype(np.float16)] * 1
            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
            mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
            image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
            input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
            input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
            input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
            input_text.set_data_from_numpy(text_obj)
            input_image.set_data_from_numpy(image_obj)
            input_mode.set_data_from_numpy(mode_obj)
            inputs = [input_text, input_image, input_mode]
            for i in range(size):
                if version == "fast":
                    response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
                else:
                    response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
                image = response.as_numpy("generated_image")
                # Model output is RGB; OpenCV encodes BGR, so convert first.
                image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
                _, img_byte_array = cv2.imencode('.jpg', image_result)
                req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
                image_url_list.append(f"{req.bucket_name}/{req.object_name}")
                image_result_list.append(image_result)
            if category == "sketch":
                clothing_category_list = self.get_clothing_category(image_result_list, gender)
            return image_url_list, clothing_category_list
        except Exception:
            # logger.exception keeps the full traceback; the previous
            # logger.error(e) logged only the message.
            logger.exception("agent image generation failed")
            return image_url_list, clothing_category_list
        finally:
            # Clients are torn down here, which makes the instance single-use.
            self.grpc_client.close()
            self.triton_client.close()

    def preprocess(self, img):
        """Load, resize to 224x224 and ImageNet-normalize an image.

        :param img: anything ``mmcv.imread`` accepts (path or ndarray).
        :return: NCHW float batch of one for the category classifier.
        """
        img = mmcv.imread(img)
        img_scale = (224, 224)
        img = cv2.resize(img, img_scale)
        # Standard ImageNet mean/std; to_rgb converts the BGR input to RGB.
        img = mmcv.imnormalize(
            img,
            mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img

    def get_category(self, image):
        """Run the attribute-category classifier and return the label name.

        :param image: NCHW batch produced by :meth:`preprocess`.
        :return: label from the module-level ``attr_type`` lookup table.
        """
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
        # The numpy result is usable directly; the previous
        # torch.from_numpy(...).detach().numpy() round-trip added nothing.
        scores = results.as_numpy('output__0')
        colattr = list(attr_type['labelName'])
        # NOTE(review): the max is taken over the first 5 scores only, but the
        # argwhere search spans all scores — confirm this asymmetry is intended.
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return colattr[indexs[0]]

    def get_clothing_category(self, images, gender):
        """Map each generated image to a clothing category.

        Male labels are collapsed into Bottoms / Tops / Outwear groups;
        female and any other gender keep the raw classifier label.
        """
        category_list = []
        gender_key = gender.lower()
        for image in images:
            sketch = self.preprocess(image)
            category = self.get_category(sketch)
            if gender_key != "male":
                # Female (and unrecognized genders) keep the raw label.
                category_list.append(category)
            elif category in ('Trousers', 'Skirt'):
                category_list.append('Bottoms')
            elif category in ('Blouse', 'Dress'):
                category_list.append('Tops')
            else:
                category_list.append('Outwear')
        return category_list
# Category-label lookup table used by AgentToolGenerateImage.get_category
# (its 'labelName' column maps classifier indices to clothing labels).
# Loaded once at import time; CATEGORY_PATH comes from app.core.config.
attr_type = pd.read_csv(CATEGORY_PATH)
if __name__ == '__main__':
    # Manual smoke test: generate two sketch images and print the uploaded
    # object paths plus their classified clothing categories.
    params = {
        "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
        "category": "sketch",
        "version": "high",
        "size": 2,
        "gender": "Female",
    }
    tool = AgentToolGenerateImage(params["version"])
    urls, categories = tool.get_result(
        prompt=params["prompt"],
        size=params["size"],
        version=params["version"],
        category=params["category"],
        gender=params["gender"],
    )
    print(urls)
    print(categories)

View File

@@ -186,7 +186,7 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
mode='txt2img',
category="test",

View File

@@ -15,7 +15,10 @@ process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
class ProjectInfoExtraction:
def __init__(self, request_data):
# llm generate brand info init
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
if len(request_data.image_list) or len(request_data.file_list):
self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
else:
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
self.response_schemas = [
ResponseSchema(name="project_name", description="项目的名称."),
@@ -55,7 +58,11 @@ class ProjectInfoExtraction:
if __name__ == '__main__':
request_data = ProjectInfoExtractionModel(
prompt="海边派对主题的衬衫设计"
prompt="性别为儿童",
image_list=[
'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
],
file_list=[]
)
service = ProjectInfoExtraction(request_data)
print(service.get_result())