diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py
index 2706abd..5bd5404 100644
--- a/app/api/api_generate_image.py
+++ b/app/api/api_generate_image.py
@@ -3,10 +3,11 @@
 import logging
 
 from fastapi import APIRouter, BackgroundTasks, HTTPException
-from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel
+from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel
 from app.schemas.pose_transform import BatchPoseTransformModel
 from app.schemas.response_template import ResponseModel
 from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
+from app.service.generate_image.service_agent_tool_generate_image import AgentToolGenerateImage
 from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
 from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
 from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
@@ -304,3 +305,48 @@ async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformMo
     }
     """
     return await start_pose_transform_batch_generate(request_batch_item)
+
+
+"""agent tool"""
+
+
+@router.post("/agent_tool_generate_image")
+def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
+    """
+    创建一个具有以下参数的请求体:
+    - **prompt**: 想要生成图片的描述词
+    - **category**: 生成图片的类别,sketch print 等等
+    - **gender**: 生成sketch专用,服装类别
+    - **version**: 使用模型版本 fast 或者 high
+    - **size**: 生成数量
+
+
+    示例参数:
+    {
+        "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
+        "category": "sketch",
+        "gender": "male",
+        "size":2,
+        "version":"high"
+    }
+    """
+    try:
+        logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
+        request_data = request_item.dict()
+        service = AgentToolGenerateImage(request_data['version'])
+        image_url_list, clothing_category_list = service.get_result(
+            prompt=request_data['prompt'],
+            size=request_data['size'],
+            version=request_data['version'],
+            category=request_data['category'],
+            gender=request_data['gender']
+        )
+        data = {
+            "image_url_list": image_url_list,
+            "clothing_category_list": clothing_category_list
+        }
+        logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
+    except Exception as e:
+        logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=data)
diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py
index 7d1d864..5062d78 100644
--- a/app/schemas/generate_image.py
+++ b/app/schemas/generate_image.py
@@ -75,3 +75,16 @@ class BatchGenerateRelightImageModel(BaseModel):
     batch_tasks_id: str
     user_id: str
     batch_data_list: List[RelightItemModel]
+
+
+"""
+    agent tool generate image
+"""
+
+
+class AgentTollGenerateImageModel(BaseModel):
+    prompt: str
+    category: str
+    gender: str
+    version: str
+    size: int
diff --git a/app/service/generate_image/service_agent_tool_generate_image.py b/app/service/generate_image/service_agent_tool_generate_image.py
new file mode 100644
index 0000000..a5c295c
--- /dev/null
+++ b/app/service/generate_image/service_agent_tool_generate_image.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :trinity_client
+@File :service_agent_tool_generate_image.py
+@Author :周成融
+@Date :2023/7/26 12:01:05
+@detail :
+"""
+import logging
+import time
+import uuid
+import cv2
+import mmcv
+import numpy as np
+import pandas as pd
+import torch
+import tritonclient.http as httpclient
+
+
+import tritonclient.grpc as grpcclient
+from minio import Minio
+from tritonclient.utils import np_to_triton_dtype
+from app.core.config import *
+from app.service.utils.new_oss_client import oss_upload_image
+
+logger = logging.getLogger()
+
+minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+
+class AgentToolGenerateImage:
+    def __init__(self, version):
+        if version == "fast":
+            self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
+        else:
+            self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
+        self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
+        self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
+
+    def get_result(self, prompt, size, version, category, gender):
+
+        image_url_list = []
+        image_result_list = []
+        clothing_category_list = []
+        try:
+            prompts = [prompt] * 1
+            modes = ["txt2img"] * 1
+            images = [self.image.astype(np.float16)] * 1
+
+            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
+            mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
+            image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
+
+            input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
+            input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
+            input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
+
+            input_text.set_data_from_numpy(text_obj)
+            input_image.set_data_from_numpy(image_obj)
+            input_mode.set_data_from_numpy(mode_obj)
+
+            inputs = [input_text, input_image, input_mode]
+            for i in range(size):
+                if version == "fast":
+                    response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
+                else:
+                    response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
+                image = response.as_numpy("generated_image")
+                image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
+                _, img_byte_array = cv2.imencode('.jpg', image_result)
+
+                req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
+                image_url_list.append(f"{req.bucket_name}/{req.object_name}")
+                image_result_list.append(image_result)
+
+            if category == "sketch":
+                clothing_category_list = self.get_clothing_category(image_result_list, gender)
+
+            return image_url_list, clothing_category_list
+        except Exception as e:
+            logger.error(e)
+            return image_url_list, clothing_category_list
+        finally:
+            self.grpc_client.close()
+            self.triton_client.close()
+
+    def preprocess(self, img):
+        img = mmcv.imread(img)
+        img_scale = (224, 224)
+        img = cv2.resize(img, img_scale)
+        img = mmcv.imnormalize(
+            img,
+            mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
+            to_rgb=True)
+        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+        return preprocessed_img
+
+    def get_category(self, image):
+        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
+        inputs[0].set_data_from_numpy(image, binary_data=True)
+        results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
+        inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
+        scores = inference_output.detach().numpy()
+        colattr = list(attr_type['labelName'])
+        maxsc = np.max(scores[0][:5])
+        indexs = np.argwhere(scores == maxsc)[:, 1]
+        return colattr[indexs[0]]
+
+    def get_clothing_category(self, images, gender):
+        category_list = []
+        for image in images:
+            sketch = self.preprocess(image)
+            if gender.lower() == "female":
+                category_list.append(self.get_category(sketch))
+            elif gender.lower() == "male":
+                category = self.get_category(sketch)
+                if category == 'Trousers' or category == 'Skirt':
+                    category_list.append('Bottoms')
+                elif category == 'Blouse' or category == 'Dress':
+                    category_list.append('Tops')
+                else:
+                    category_list.append('Outwear')
+            else:
+                category_list.append(self.get_category(sketch))
+        return category_list
+
+
+attr_type = pd.read_csv(CATEGORY_PATH)
+
+if __name__ == '__main__':
+    request_data = {
+        "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
+        "category": "sketch",
+        "version": "high",
+        "size": 2,
+        "gender": "Female",
+    }
+    server = AgentToolGenerateImage(request_data['version'])
+    image_url_list, clothing_category_list = server.get_result(
+        prompt=request_data['prompt'],
+        size=request_data['size'],
+        version=request_data['version'],
+        category=request_data['category'],
+        gender=request_data['gender']
+    )
+
+    print(image_url_list)
+    print(clothing_category_list)
diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py
index c3ae2d7..7d00b87 100644
--- a/app/service/generate_image/service_generate_image.py
+++ b/app/service/generate_image/service_generate_image.py
@@ -186,7 +186,7 @@ def infer_cancel(tasks_id):
 if __name__ == '__main__':
     rd = GenerateImageModel(
         tasks_id="123-89",
-        prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
+        prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
         image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
         mode='txt2img',
         category="test",
diff --git a/app/service/project_info_extraction/service.py b/app/service/project_info_extraction/service.py
index 40d59ba..cf9df6c 100644
--- a/app/service/project_info_extraction/service.py
+++ b/app/service/project_info_extraction/service.py
@@ -15,7 +15,10 @@ process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
 class ProjectInfoExtraction:
     def __init__(self, request_data):
         # llm generate brand info init
-        self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
+        if len(request_data.image_list) or len(request_data.file_list):
+            self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
+        else:
+            self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
 
         self.response_schemas = [
             ResponseSchema(name="project_name", description="项目的名称."),
@@ -55,7 +58,11 @@
 
 if __name__ == '__main__':
     request_data = ProjectInfoExtractionModel(
-        prompt="海边派对主题的衬衫设计"
+        prompt="性别为儿童",
+        image_list=[
+            'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
+        ],
+        file_list=[]
     )
     service = ProjectInfoExtraction(request_data)
     print(service.get_result())