feat(新功能): 新增agent 工具,图片生成接口
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
149
app/service/generate_image/service_agent_tool_generate_image.py
Normal file
149
app/service/generate_image/service_agent_tool_generate_image.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
from app.core.config import *
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
class AgentToolGenerateImage:
|
||||
def __init__(self, version):
|
||||
if version == "fast":
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
|
||||
else:
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
|
||||
def get_result(self, prompt, size, version, category, gender):
|
||||
|
||||
image_url_list = []
|
||||
image_result_list = []
|
||||
clothing_category_list = []
|
||||
try:
|
||||
prompts = [prompt] * 1
|
||||
modes = ["txt2img"] * 1
|
||||
images = [self.image.astype(np.float16)] * 1
|
||||
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_mode.set_data_from_numpy(mode_obj)
|
||||
|
||||
inputs = [input_text, input_image, input_mode]
|
||||
for i in range(size):
|
||||
if version == "fast":
|
||||
response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
|
||||
else:
|
||||
response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
|
||||
image = response.as_numpy("generated_image")
|
||||
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||
_, img_byte_array = cv2.imencode('.jpg', image_result)
|
||||
|
||||
req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
|
||||
image_url_list.append(f"{req.bucket_name}/{req.object_name}")
|
||||
image_result_list.append(image_result)
|
||||
|
||||
if category == "sketch":
|
||||
clothing_category_list = self.get_clothing_category(image_result_list, gender)
|
||||
|
||||
return image_url_list, clothing_category_list
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return image_url_list, clothing_category_list
|
||||
finally:
|
||||
self.grpc_client.close()
|
||||
self.triton_client.close()
|
||||
|
||||
def preprocess(self, img):
|
||||
img = mmcv.imread(img)
|
||||
img_scale = (224, 224)
|
||||
img = cv2.resize(img, img_scale)
|
||||
img = mmcv.imnormalize(
|
||||
img,
|
||||
mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
|
||||
to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img
|
||||
|
||||
def get_category(self, image):
|
||||
inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
|
||||
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||
results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
|
||||
scores = inference_output.detach().numpy()
|
||||
colattr = list(attr_type['labelName'])
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
return colattr[indexs[0]]
|
||||
|
||||
def get_clothing_category(self, images, gender):
|
||||
category_list = []
|
||||
for image in images:
|
||||
sketch = self.preprocess(image)
|
||||
if gender.lower() == "female":
|
||||
category_list.append(self.get_category(sketch))
|
||||
elif gender.lower() == "male":
|
||||
category = self.get_category(sketch)
|
||||
if category == 'Trousers' or category == 'Skirt':
|
||||
category_list.append('Bottoms')
|
||||
elif category == 'Blouse' or category == 'Dress':
|
||||
category_list.append('Tops')
|
||||
else:
|
||||
category_list.append('Outwear')
|
||||
else:
|
||||
category_list.append(self.get_category(sketch))
|
||||
return category_list
|
||||
|
||||
|
||||
attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = {
|
||||
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||
"category": "sketch",
|
||||
"version": "high",
|
||||
"size": 2,
|
||||
"gender": "Female",
|
||||
}
|
||||
server = AgentToolGenerateImage(request_data['version'])
|
||||
image_url_list, clothing_category_list = server.get_result(
|
||||
prompt=request_data['prompt'],
|
||||
size=request_data['size'],
|
||||
version=request_data['version'],
|
||||
category=request_data['category'],
|
||||
gender=request_data['gender']
|
||||
)
|
||||
|
||||
print(image_url_list)
|
||||
print(clothing_category_list)
|
||||
@@ -186,7 +186,7 @@ def infer_cancel(tasks_id):
|
||||
if __name__ == '__main__':
|
||||
rd = GenerateImageModel(
|
||||
tasks_id="123-89",
|
||||
prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
|
||||
prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
|
||||
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
||||
mode='txt2img',
|
||||
category="test",
|
||||
|
||||
@@ -15,7 +15,10 @@ process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
|
||||
class ProjectInfoExtraction:
|
||||
def __init__(self, request_data):
|
||||
# llm generate brand info init
|
||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
if len(request_data.image_list) or len(request_data.file_list):
|
||||
self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
else:
|
||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
|
||||
self.response_schemas = [
|
||||
ResponseSchema(name="project_name", description="项目的名称."),
|
||||
@@ -55,7 +58,11 @@ class ProjectInfoExtraction:
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = ProjectInfoExtractionModel(
|
||||
prompt="海边派对主题的衬衫设计"
|
||||
prompt="性别为儿童",
|
||||
image_list=[
|
||||
'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
|
||||
],
|
||||
file_list=[]
|
||||
)
|
||||
service = ProjectInfoExtraction(request_data)
|
||||
print(service.get_result())
|
||||
|
||||
Reference in New Issue
Block a user