From 293f90f9d387cf382f8c6b3def369c8c938513f4 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 8 May 2025 17:46:28 +0800 Subject: [PATCH 01/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20p?= =?UTF-8?q?ose=20transform=20mq=E8=BF=9E=E6=8E=A5=E8=B6=85=E6=97=B6bug?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98?= =?UTF-8?q?=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84?= =?UTF-8?q?=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_image/service_pose_transform.py | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/app/service/generate_image/service_pose_transform.py b/app/service/generate_image/service_pose_transform.py index 07da8de..2bd81ac 100644 --- a/app/service/generate_image/service_pose_transform.py +++ b/app/service/generate_image/service_pose_transform.py @@ -29,9 +29,6 @@ logger = logging.getLogger() class PoseTransformService: def __init__(self, request_data): - if DEBUG is False: - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "pose_transform" @@ -40,7 +37,8 @@ class PoseTransformService: self.image = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] - self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''} + self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', + 'video_url': '', 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data)) self.redis_client.expire(self.tasks_id, 600) @@ -55,16 +53,20 @@ class PoseTransformService: # 第一帧图像 first_image = Image.fromarray(result_data[0]) - first_image_url = upload_first_image(first_image, user_id=self.user_id, category=f"{self.category}_first_img", file_name=f"{self.tasks_id}.png") + first_image_url = upload_first_image(first_image, user_id=self.user_id, + category=f"{self.category}_first_img", + file_name=f"{self.tasks_id}.png") # 上传GIF gif_buffer = BytesIO() imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5) gif_buffer.seek(0) - gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif", file_name=f"{self.tasks_id}.gif") + gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif", + file_name=f"{self.tasks_id}.gif") # 上传video - video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video", file_name=f"{self.tasks_id}.mp4") + video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video", + file_name=f"{self.tasks_id}.mp4") self.pose_transform_data['status'] = "SUCCESS" self.pose_transform_data['message'] = "success" @@ -82,7 +84,8 @@ class PoseTransformService: try: pose_num = [self.pose_num] * 1 pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1)) - input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype)) + input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, + np_to_triton_dtype(pose_num_obj.dtype)) input_pose_num.set_data_from_numpy(pose_num_obj) image_files = [self.image.astype(np.uint8)] * 1 @@ -90,7 +93,8 @@ class PoseTransformService: input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8") input_image_files.set_data_from_numpy(image_files_obj) - ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback, client_timeout=60000) + ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], + callback=self.callback, client_timeout=60000) time_out = 60000 while time_out > 0: pose_transform_data, _ = self.read_tasks_status() @@ -111,9 +115,22 @@ class PoseTransformService: finally: dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data) - self.connection.close() - logger.info(f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}") + publish_status(str_pose_transform_data) + logger.info( + f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}") + + +def publish_status(message): + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + channel = connection.channel() + channel.queue_declare(queue=PS_RABBITMQ_QUEUES, durable=True) + channel.basic_publish(exchange='', + routing_key=PS_RABBITMQ_QUEUES, + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2, + )) + connection.close() def infer_cancel(tasks_id): @@ -125,7 +142,8 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): - image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], + data_type="PIL") # 目标图片的尺寸 target_width = 512 target_height = 768 From 3095d2654e8747785ceaa0b5bb1091d29b286c0c Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 8 May 2025 17:59:14 +0800 Subject: [PATCH 02/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20p?= =?UTF-8?q?ose=20transform=20mq=E8=BF=9E=E6=8E=A5=E8=B6=85=E6=97=B6bug?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98?= =?UTF-8?q?=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84?= =?UTF-8?q?=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_pose_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_pose_transform.py b/app/service/generate_image/service_pose_transform.py index 2bd81ac..78ca227 100644 --- a/app/service/generate_image/service_pose_transform.py +++ b/app/service/generate_image/service_pose_transform.py @@ -104,7 +104,7 @@ class PoseTransformService: elif pose_transform_data['status'] == "SUCCESS": break time_out -= 1 - time.sleep(0.1) + time.sleep(1) pose_transform_data, _ = self.read_tasks_status() return pose_transform_data except Exception as e: From 6cb32d11a8c4a1f333bad099486b9e119cf834f9 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 15 May 2025 14:49:33 +0800 Subject: [PATCH 03/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20=E9=A1=B9=E7=9B=AE=E4=BF=A1=E6=81=AF=E6=8F=90?= =?UTF-8?q?=E5=8F=96/=E7=94=9F=E6=88=90=E6=8E=A5=E5=8F=A3=20fix=EF=BC=88?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88?= =?UTF-8?q?=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_extraction_project_info.py | 33 +++++++++++++ app/api/api_route.py | 5 +- app/schemas/project_info_extraction.py | 5 ++ .../service_generate_brand_info.py | 47 +++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 app/api/api_extraction_project_info.py create mode 100644 app/schemas/project_info_extraction.py create mode 100644 app/service/project_info_extraction/service_generate_brand_info.py diff --git a/app/api/api_extraction_project_info.py b/app/api/api_extraction_project_info.py new file mode 100644 index 0000000..51eb473 --- /dev/null +++ b/app/api/api_extraction_project_info.py @@ -0,0 +1,33 @@ +import logging + +from fastapi import APIRouter, HTTPException + +from app.schemas.project_info_extraction import ProjectInfoExtractionModel +from app.schemas.response_template import ResponseModel +from app.service.project_info_extraction.service_generate_brand_info import ProjectInfoExtraction + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/extraction_project_info") +def extraction_project_info(request_data: ProjectInfoExtractionModel): + """ + 通过prompt 提取project_name,role ,gender ,style。 + 创建一个具有以下参数的请求体: + - **prompt**: + + 示例参数: + { + "prompt": "海边派对主题的系列设计" + } + """ + try: + logger.info(f"extraction_project_info request item is : @@@@@@:{request_data}") + service = ProjectInfoExtraction(request_data) + data = service.get_result() + logger.info(f"extraction_project_info response @@@@@@:{data}") + except Exception as e: + logger.warning(f"extraction_project_info Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_route.py b/app/api/api_route.py index b82c942..d85cbce 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -5,16 +5,16 @@ from app.api import api_attribute_retrieve, api_query_image from app.api import api_brand_dna from app.api import api_brighten from app.api import api_chat_robot +from app.api import api_clothing_seg from app.api import api_design from app.api import api_design_pre_processing +from app.api import api_extraction_project_info from app.api import api_generate_image from app.api import api_image2sketch from app.api import api_mannequins_edit from app.api import api_pose_transform from app.api import api_prompt_generation -from app.api import api_clothing_seg from app.api import api_super_resolution -from app.api import api_recommendation from app.api import api_test router = APIRouter() @@ -36,3 +36,4 @@ router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api") router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api") router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api") +router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api") diff --git a/app/schemas/project_info_extraction.py b/app/schemas/project_info_extraction.py new file mode 100644 index 0000000..90def8b --- /dev/null +++ b/app/schemas/project_info_extraction.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class ProjectInfoExtractionModel(BaseModel): + prompt: str diff --git a/app/service/project_info_extraction/service_generate_brand_info.py b/app/service/project_info_extraction/service_generate_brand_info.py new file mode 100644 index 0000000..8ee7bcd --- /dev/null +++ b/app/service/project_info_extraction/service_generate_brand_info.py @@ -0,0 +1,47 @@ +from langchain.output_parsers import ResponseSchema, StructuredOutputParser +from langchain_community.chat_models import ChatTongyi +from langchain_core.prompts import PromptTemplate + +from app.schemas.project_info_extraction import ProjectInfoExtractionModel + + +class ProjectInfoExtraction: + def __init__(self, request_data): + # llm generate brand info init + self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") + + self.response_schemas = [ + ResponseSchema(name="project_name", description="project name."), + ResponseSchema(name="role", description="The target role of the project."), + ResponseSchema(name="gender", description="The gender targeted by the project."), + ResponseSchema(name="style", description="Project style.") + ] + self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas) + self.format_instructions = self.output_parser.get_format_instructions() + self.prompt = PromptTemplate( + template="你是一个时装品牌的设计师助理。根据用户输入提取出project_name,role ,gender ,style ." + "gender部分请用以下:menswear,womenswear,childrenwear,如果全部都适用即all." + "如果没有以上内容,需要你根据用户输入随意发挥.\n{format_instructions}\n{question}", + input_variables=["question"], + partial_variables={"format_instructions": self.format_instructions} + ) + self._input = self.prompt.format_prompt(question=request_data.prompt) + + self.result_data = {} + + def get_result(self): + self.llm_extraction_project_info() + return self.result_data + + def llm_extraction_project_info(self): + output = self.model(self._input.to_messages()) + project_info = self.output_parser.parse(output.content) + self.result_data = project_info + + +if __name__ == '__main__': + request_data = ProjectInfoExtractionModel( + prompt="海边派对主题的系列设计" + ) + service = ProjectInfoExtraction(request_data) + print(service.get_result()) From e4141b9e65b0ce95bb7a436694925d5a5fa561d5 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 15 May 2025 14:59:29 +0800 Subject: [PATCH 04/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20=E9=A1=B9=E7=9B=AE=E4=BF=A1=E6=81=AF=E6=8F=90?= =?UTF-8?q?=E5=8F=96/=E7=94=9F=E6=88=90=E6=8E=A5=E5=8F=A3=20fix=EF=BC=88?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88?= =?UTF-8?q?=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_route.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/api/api_route.py b/app/api/api_route.py index d85cbce..9b48c5a 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -14,6 +14,7 @@ from app.api import api_image2sketch from app.api import api_mannequins_edit from app.api import api_pose_transform from app.api import api_prompt_generation +from app.api import api_recommendation from app.api import api_super_resolution from app.api import api_test From 3a28a7e4b91b9911afac05da32976a1fed865163 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 15 May 2025 16:40:58 +0800 Subject: [PATCH 05/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20=E9=A1=B9=E7=9B=AE=E4=BF=A1=E6=81=AF=E6=8F=90?= =?UTF-8?q?=E5=8F=96/=E7=94=9F=E6=88=90=E6=8E=A5=E5=8F=A3=20fix=EF=BC=88?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88?= =?UTF-8?q?=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_extraction_project_info.py | 2 +- .../project_info_extraction/service.py | 61 +++++++++++++++++++ .../service_generate_brand_info.py | 47 -------------- 3 files changed, 62 insertions(+), 48 deletions(-) create mode 100644 app/service/project_info_extraction/service.py delete mode 100644 app/service/project_info_extraction/service_generate_brand_info.py diff --git a/app/api/api_extraction_project_info.py b/app/api/api_extraction_project_info.py index 51eb473..ad55552 100644 --- a/app/api/api_extraction_project_info.py +++ b/app/api/api_extraction_project_info.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException from app.schemas.project_info_extraction import ProjectInfoExtractionModel from app.schemas.response_template import ResponseModel -from app.service.project_info_extraction.service_generate_brand_info import ProjectInfoExtraction +from app.service.project_info_extraction.service import ProjectInfoExtraction router = APIRouter() logger = logging.getLogger() diff --git a/app/service/project_info_extraction/service.py b/app/service/project_info_extraction/service.py new file mode 100644 index 0000000..40d59ba --- /dev/null +++ b/app/service/project_info_extraction/service.py @@ -0,0 +1,61 @@ +from langchain.output_parsers import ResponseSchema, StructuredOutputParser +from langchain_community.chat_models import ChatTongyi +from langchain_core.prompts import PromptTemplate + +from app.schemas.project_info_extraction import ProjectInfoExtractionModel + +style = ['NEW_CHINESE', 'COUNTRY_STYLE', 'FUTURISM', 'MINIMALISM', 'LOLITA', 'Y2K', 'BUSINESS', 'MERLAD', + 'OUTDOOR_FUNCTIONAL', 'ROCK', 'DOPAMINE', 'GOTHIC', 'POST_APOCALYPTIC', 'ROMANTIC', 'WABI_SABI'] +position = ['Overall', 'Tops', 'Bottoms', 'Outwear', 'Blouse', 'Dress', 'Trousers', 'Skirt'] +gender = ['Female', 'Male'] +age_group = ['Adult', 'Child'] +process = ['SERIES_DESIGN', 'SINGLE_DESIGN'] + + +class ProjectInfoExtraction: + def __init__(self, request_data): + # llm generate brand info init + self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") + + self.response_schemas = [ + ResponseSchema(name="project_name", description="项目的名称."), + ResponseSchema(name="process", description="项目的类型,单品还是系列."), + ResponseSchema(name="ageGroup", description="项目设计服装的受众对象."), + ResponseSchema(name="gender", description="项目设计服装的受众性别."), + ResponseSchema(name="position", description="项目单品设计的部位."), + ResponseSchema(name="style", description="项目的设计风格.") + ] + self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas) + self.format_instructions = self.output_parser.get_format_instructions() + self.prompt = PromptTemplate( + template="你是一个时装品牌的设计师助理。根据用户输入提取出" + "[project_name] : 项目的名称," + f"[process] : 项目的类型,从{process}选择." + f"[ageGroup] : 服装的受众,从{age_group}选择." + f"[gender] : 服装的适用性别,从{gender}选择" + f"[position] : single_design的部位,如果[process]是SINGLE_DESIGN,从{position}中选择,如果[process]是SERIES_DESIGN,这项为空" + f"[style] : 设计的风格,从{style}中选择" + ".\n{format_instructions}\n{question}", + input_variables=["question"], + partial_variables={"format_instructions": self.format_instructions} + ) + self._input = self.prompt.format_prompt(question=request_data.prompt) + + self.result_data = {} + + def get_result(self): + self.llm_extraction_project_info() + return self.result_data + + def llm_extraction_project_info(self): + output = self.model(self._input.to_messages()) + project_info = self.output_parser.parse(output.content) + self.result_data = project_info + + +if __name__ == '__main__': + request_data = ProjectInfoExtractionModel( + prompt="海边派对主题的衬衫设计" + ) + service = ProjectInfoExtraction(request_data) + print(service.get_result()) diff --git a/app/service/project_info_extraction/service_generate_brand_info.py b/app/service/project_info_extraction/service_generate_brand_info.py deleted file mode 100644 index 8ee7bcd..0000000 --- a/app/service/project_info_extraction/service_generate_brand_info.py +++ /dev/null @@ -1,47 +0,0 @@ -from langchain.output_parsers import ResponseSchema, StructuredOutputParser -from langchain_community.chat_models import ChatTongyi -from langchain_core.prompts import PromptTemplate - -from app.schemas.project_info_extraction import ProjectInfoExtractionModel - - -class ProjectInfoExtraction: - def __init__(self, request_data): - # llm generate brand info init - self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") - - self.response_schemas = [ - ResponseSchema(name="project_name", description="project name."), - ResponseSchema(name="role", description="The target role of the project."), - ResponseSchema(name="gender", description="The gender targeted by the project."), - ResponseSchema(name="style", description="Project style.") - ] - self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas) - self.format_instructions = self.output_parser.get_format_instructions() - self.prompt = PromptTemplate( - template="你是一个时装品牌的设计师助理。根据用户输入提取出project_name,role ,gender ,style ." - "gender部分请用以下:menswear,womenswear,childrenwear,如果全部都适用即all." - "如果没有以上内容,需要你根据用户输入随意发挥.\n{format_instructions}\n{question}", - input_variables=["question"], - partial_variables={"format_instructions": self.format_instructions} - ) - self._input = self.prompt.format_prompt(question=request_data.prompt) - - self.result_data = {} - - def get_result(self): - self.llm_extraction_project_info() - return self.result_data - - def llm_extraction_project_info(self): - output = self.model(self._input.to_messages()) - project_info = self.output_parser.parse(output.content) - self.result_data = project_info - - -if __name__ == '__main__': - request_data = ProjectInfoExtractionModel( - prompt="海边派对主题的系列设计" - ) - service = ProjectInfoExtraction(request_data) - print(service.get_result()) From fbe939ee22f64108f4921f357d98d92761bf68ab Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Mon, 19 May 2025 13:10:51 +0800 Subject: [PATCH 06/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20=E6=96=B0=E5=A2=9E=E6=96=87=E4=BB=B6=E4=B8=8A?= =?UTF-8?q?=E4=BC=A0=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_extraction_project_info.py | 8 +++++++- app/schemas/project_info_extraction.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/app/api/api_extraction_project_info.py b/app/api/api_extraction_project_info.py index ad55552..798126a 100644 --- a/app/api/api_extraction_project_info.py +++ b/app/api/api_extraction_project_info.py @@ -19,7 +19,13 @@ def extraction_project_info(request_data: ProjectInfoExtractionModel): 示例参数: { - "prompt": "海边派对主题的系列设计" + "prompt": "海边派对主题的系列设计", + "image_list": [ + "https://www.minio-api.aida.com.hk/test/test123.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250519%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250519T050808Z&X-Amz-Expires=7200&X-Amz-SignedHeaders=host&X-Amz-Signature=296ff07cc4692d0a26ddffac582064f036494af343389fe60193dc2c5dc883ff" + ], + "file_list": [ + "" + ] } """ try: diff --git a/app/schemas/project_info_extraction.py b/app/schemas/project_info_extraction.py index 90def8b..6f579dd 100644 --- a/app/schemas/project_info_extraction.py +++ b/app/schemas/project_info_extraction.py @@ -3,3 +3,5 @@ from pydantic import BaseModel class ProjectInfoExtractionModel(BaseModel): prompt: str + image_list: list + file_list: list From f234ae29ffb5beead273ccc6205c5a1bfcd802db Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Mon, 2 Jun 2025 10:01:16 +0800 Subject: [PATCH 07/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20?= =?UTF-8?q?=20minio=E9=85=8D=E7=BD=AE=E6=9B=B4=E6=96=B0=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index ab53e0f..4930e97 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -43,7 +43,7 @@ JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/ settings = Settings() # minio 配置 -MINIO_URL = "www.minio.aida.com.hk:12024" +MINIO_URL = "www.minio-api.aida.com.hk" MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB' MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR' MINIO_SECURE = True From 90f9879edb53f5bbf0b264a621d3e20d3d2b7c0e Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Wed, 4 Jun 2025 15:55:55 +0800 Subject: [PATCH 08/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20batch=20generate=20product=20=E5=85=A5=E5=8F=82?= =?UTF-8?q?=E5=9B=9E=E5=8F=82=E4=BF=AE=E6=94=B9=20fix=EF=BC=88=E4=BF=AE?= =?UTF-8?q?=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D?= =?UTF-8?q?=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 13 +- .../service_batch_generate_product_image.py | 139 ++++++++++++------ 2 files changed, 105 insertions(+), 47 deletions(-) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 99d1836..a989f2e 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -1,3 +1,5 @@ +from typing import List + from pydantic import BaseModel @@ -43,13 +45,18 @@ class GenerateRelightImageModel(BaseModel): """ -class BatchGenerateProductImageModel(BaseModel): +class ProductItemModel(BaseModel): tasks_id: str + image_strength: float prompt: str image_url: str - image_strength: float product_type: str - batch_size: int + + +class BatchGenerateProductImageModel(BaseModel): + batch_tasks_id: str + user_id: str + batch_data_list: List[ProductItemModel] class BatchGenerateRelightImageModel(BaseModel): diff --git a/app/service/generate_batch_image/service_batch_generate_product_image.py b/app/service/generate_batch_image/service_batch_generate_product_image.py index f09fbd5..46a5695 100644 --- a/app/service/generate_batch_image/service_batch_generate_product_image.py +++ b/app/service/generate_batch_image/service_batch_generate_product_image.py @@ -19,7 +19,7 @@ from celery import Celery from tritonclient.utils import np_to_triton_dtype from app.core.config import * -from app.schemas.generate_image import BatchGenerateProductImageModel +from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image from app.service.utils.oss_client import oss_get_image @@ -35,38 +35,38 @@ category = "product_image" @celery_app.task def batch_generate_product(batch_request_data): + batch_size = len(batch_request_data['batch_data_list']) logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}") - tasks_id = batch_request_data['tasks_id'] - user_id = tasks_id.rsplit('-', 1)[1] - batch_size = batch_request_data['batch_size'] - image = pre_processing_image(batch_request_data['image_url']) - image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) - images = [image.astype(np.uint8)] * 1 + batch_tasks_id = batch_request_data['batch_tasks_id'] + user_id = batch_request_data['user_id'] + result_data_list = [] - prompts = [batch_request_data['prompt']] * 1 + for i, data in enumerate(batch_request_data['batch_data_list']): + tasks_id = data['tasks_id'] + image = pre_processing_image(data['image_url']) + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + images = [image.astype(np.uint8)] * 1 + prompts = [data['prompt']] * 1 + if data['product_type'] == "single": + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) - if batch_request_data['product_type'] == "single": - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1)) - else: - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1)) - input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") - input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_image_strength.set_data_from_numpy(image_strength_obj) - input_text.set_data_from_numpy(text_obj) - input_image.set_data_from_numpy(image_obj) - input_image_strength.set_data_from_numpy(image_strength_obj) + inputs = [input_text, input_image, input_image_strength] - inputs = [input_text, input_image, input_image_strength] - - image_url_list = [] - for i in range(batch_size): try: - if batch_request_data['product_type'] == "single": + if data['product_type'] == "single": result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) image = result.as_numpy("generated_cnet_image") else: @@ -77,7 +77,7 @@ def batch_generate_product(batch_request_data): if 'mask_list' in str(e): e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1)) + e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1)) e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype)) e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8") @@ -96,18 +96,29 @@ def batch_generate_product(batch_request_data): if isinstance(image_result, Image.Image): image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") - image_url_list.append(image_url) + data['product_img'] = image_url + result_data_list.append(data) else: image_url = image_result - if DEBUG is False: - if i + 1 < batch_size: - publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) - logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - else: - publish_status(tasks_id, f"OK", image_url_list) - logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") + data['product_img'] = image_url + result_data_list.append(data) + + # 发送每条结果 + if DEBUG: + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + else: + publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") + + # 任务完成,发送所有数据结果 + if DEBUG: + print(result_data_list) + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + else: + publish_status(batch_tasks_id, f"OK", result_data_list) + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") def pre_processing_image(image_url): @@ -180,12 +191,52 @@ def publish_status(task_id, progress, result): if __name__ == '__main__': + # rd = BatchGenerateProductImageModel( + # tasks_id="123-15-51-89", + # image_strength=0.7, + # prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", + # image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + # product_type="overall", + # batch_size=20 + # ) + # batch_generate_product(rd.dict()) + # rd = { + # "user_id": "89", + # "batch_data_list": [ + # { + # "tasks_id": "A-123-15-51-89", + # "image_strength": 0.7, + # "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", + # "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + # "product_type": "overall", + # }, + # { + # "tasks_id": "B-123-15-51-89", + # "image_strength": 0.7, + # "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR", + # "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + # "product_type": "overall", + # } + # ] + # } rd = BatchGenerateProductImageModel( - tasks_id="123-15-51-89", - image_strength=0.7, - prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", - image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", - product_type="overall", - batch_size=20 + batch_tasks_id="abcd", + user_id="89", + batch_data_list=[ + ProductItemModel( + tasks_id="123-5464", + image_strength=0.7, + product_type="overall", + image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + prompt="123" + ), + ProductItemModel( + tasks_id="123-5464123", + image_strength=0.7, + product_type="overall", + image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + prompt="123" + ) + ] ) batch_generate_product(rd.dict()) From 12bb12835126e89e5238e7b6530e5be7e3504a23 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 5 Jun 2025 15:14:36 +0800 Subject: [PATCH 09/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20batch=20generate=20relight=20=E5=85=A5=E5=8F=82?= =?UTF-8?q?=E5=9B=9E=E5=8F=82=E4=BF=AE=E6=94=B9=20fix=EF=BC=88=E4=BF=AE?= =?UTF-8?q?=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D?= =?UTF-8?q?=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 13 +- .../service_batch_generate_relight_image.py | 144 +++++++++++------- 2 files changed, 96 insertions(+), 61 deletions(-) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index a989f2e..7d1d864 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -45,6 +45,7 @@ class GenerateRelightImageModel(BaseModel): """ +# product任务子项 class ProductItemModel(BaseModel): tasks_id: str image_strength: float @@ -53,16 +54,24 @@ class ProductItemModel(BaseModel): product_type: str +# product批处理 集合 class BatchGenerateProductImageModel(BaseModel): batch_tasks_id: str user_id: str batch_data_list: List[ProductItemModel] -class BatchGenerateRelightImageModel(BaseModel): +# relight任务子项 +class RelightItemModel(BaseModel): tasks_id: str prompt: str image_url: str direction: str product_type: str - batch_size: int + + +# relight批处理集合 +class BatchGenerateRelightImageModel(BaseModel): + batch_tasks_id: str + user_id: str + batch_data_list: List[RelightItemModel] diff --git a/app/service/generate_batch_image/service_batch_generate_relight_image.py b/app/service/generate_batch_image/service_batch_generate_relight_image.py index 83a5701..d75b0a7 100644 --- a/app/service/generate_batch_image/service_batch_generate_relight_image.py +++ b/app/service/generate_batch_image/service_batch_generate_relight_image.py @@ -18,7 +18,7 @@ from celery import Celery from tritonclient.utils import np_to_triton_dtype from app.core.config import * -from app.schemas.generate_image import BatchGenerateRelightImageModel +from app.schemas.generate_image import BatchGenerateRelightImageModel, RelightItemModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image from app.service.utils.oss_client import oss_get_image @@ -34,55 +34,58 @@ category = "relight_image" @celery_app.task def batch_generate_relight(batch_request_data): + batch_size = len(batch_request_data['batch_data_list']) logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}") + batch_tasks_id = batch_request_data['batch_tasks_id'] + user_id = batch_request_data['user_id'] + result_data_list = [] negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - direction = batch_request_data['direction'] seed = "1" - prompt = batch_request_data['prompt'] - product_type = batch_request_data['product_type'] - image_url = batch_request_data['image_url'] - image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2") - tasks_id = batch_request_data['tasks_id'] - user_id = tasks_id.rsplit('-', 1)[1] - batch_size = batch_request_data['batch_size'] - prompts = [prompt] * 1 - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - image = cv2.resize(image, (512, 768)) - images = [image.astype(np.uint8)] * 1 - seeds = [seed] * 1 - nagetive_prompts = [negative_prompt] * 1 - directions = [direction] * 1 + for i, data in enumerate(batch_request_data['batch_data_list']): + direction = data['direction'] - if product_type == 'single': - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) - seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) - direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) - else: - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) - seed_obj = np.array(seeds, dtype="object").reshape((1)) - direction_obj = np.array(directions, dtype="object").reshape((1)) - input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") - input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) - input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) - input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + prompt = data['prompt'] + product_type = data['product_type'] + image_url = data['image_url'] + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2") + tasks_id = data['tasks_id'] - input_text.set_data_from_numpy(text_obj) - input_image.set_data_from_numpy(image_obj) - input_natext.set_data_from_numpy(na_text_obj) - input_seed.set_data_from_numpy(seed_obj) - input_direction.set_data_from_numpy(direction_obj) + prompts = [prompt] * 1 + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * 1 + seeds = [seed] * 1 + nagetive_prompts = [negative_prompt] * 1 + directions = [direction] * 1 - inputs = [input_text, input_natext, input_image, input_seed, input_direction] - image_url_list = [] - for i in range(batch_size): + if product_type == 'single': + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) + seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) + direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) + input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) + input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_natext.set_data_from_numpy(na_text_obj) + input_seed.set_data_from_numpy(seed_obj) + input_direction.set_data_from_numpy(direction_obj) + + inputs = [input_text, input_natext, input_image, input_seed, input_direction] try: - if batch_request_data['product_type'] == "single": + if data['product_type'] == "single": result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) image = result.as_numpy("generated_relight_image") else: @@ -121,18 +124,29 @@ def batch_generate_relight(batch_request_data): logger.error(e) if isinstance(image_result, Image.Image): image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") - image_url_list.append(image_url) + data['relight_img'] = image_url + + result_data_list.append(data) else: image_url = image_result - if DEBUG is False: - if i + 1 < batch_size: - publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) - logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - else: - publish_status(tasks_id, f"OK", image_url_list) - logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") + data['relight_img'] = image_url + result_data_list.append(data) + + # 发送每条结果 + if DEBUG: + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + else: + publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") + # 任务完成,发送所有数据结果 + if DEBUG: + print(result_data_list) + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + else: + publish_status(batch_tasks_id, f"OK", result_data_list) + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") def publish_status(task_id, progress, result): @@ -151,12 +165,24 @@ def publish_status(task_id, progress, result): if __name__ == '__main__': rd = BatchGenerateRelightImageModel( - tasks_id="123-89", - # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - prompt="Colorful black", - image_url='aida-users/89/clothing_seg/283c5c82-1a92-11f0-b72a-0242ac150002.png', - direction="Right Light", - product_type="overall", - batch_size=10 + batch_tasks_id="abcd", + user_id="89", + batch_data_list=[ + RelightItemModel( + tasks_id="123-5464", + product_type="overall", + image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png", + prompt="Colorful black", + direction="Right Light", + ), + RelightItemModel( + tasks_id="123-5464123", + product_type="overall", + image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png", + direction="Right Light", + prompt="Colorful black", + ) + ] ) + batch_generate_relight(rd.dict()) From be2d1db165fba481d0fa92afdda13b6a38690226 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Fri, 6 Jun 2025 16:50:55 +0800 Subject: [PATCH 10/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20b?= =?UTF-8?q?atch=20generate=20product=20/=20relight=20=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E5=93=8D=E5=BA=94=E5=BC=82=E5=B8=B8=E4=BF=AE=E5=A4=8D=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_batch_image/service.py | 8 +++---- .../service_batch_generate_relight_image.py | 22 ++++++++++++++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/app/service/generate_batch_image/service.py b/app/service/generate_batch_image/service.py index 2279382..6d007c8 100644 --- a/app/service/generate_batch_image/service.py +++ b/app/service/generate_batch_image/service.py @@ -6,15 +6,15 @@ from app.service.generate_batch_image.service_batch_pose_transform import batch_ async def start_product_batch_generate(data): generate_clothes_task = batch_generate_product.delay(data.dict()) print(generate_clothes_task) - product_publish_status(data.tasks_id, f"0/{data.batch_size}", "") - return {"task_id": data.tasks_id, "state": generate_clothes_task.state} + product_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "") + return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state} async def start_relight_batch_generate(data): generate_clothes_task = batch_generate_relight.delay(data.dict()) print(generate_clothes_task) - relight_publish_status(data.tasks_id, f"0/{data.batch_size}", "") - return {"task_id": data.tasks_id, "state": generate_clothes_task.state} + relight_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "") + return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state} async def start_pose_transform_batch_generate(data): diff --git a/app/service/generate_batch_image/service_batch_generate_relight_image.py b/app/service/generate_batch_image/service_batch_generate_relight_image.py index d75b0a7..0a90646 100644 --- a/app/service/generate_batch_image/service_batch_generate_relight_image.py +++ b/app/service/generate_batch_image/service_batch_generate_relight_image.py @@ -184,5 +184,25 @@ if __name__ == '__main__': ) ] ) - batch_generate_relight(rd.dict()) + # X = { + # "batch_tasks_id": "abcd", + # "user_id": "89", + # "batch_data_list": [ + # { + # "tasks_id": "123-5464", + # "product_type": "overall", + # "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png", + # "prompt": "Colorful black", + # "direction": "Right Light", + # }, + # { + # "tasks_id": "123-5464", + # "product_type": "overall", + # "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png", + # "prompt": "Colorful black", + # "direction": "Right Light", + # } + # + # ] + # } From e8cbb8569ac890de9369273c3208f1d7cf1009cd Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Fri, 6 Jun 2025 17:04:27 +0800 Subject: [PATCH 11/11] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20b?= =?UTF-8?q?atch=20generate=20product=20/=20relight=20mq=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E6=9B=B4=E6=96=B0=20docs=EF=BC=88=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88?= =?UTF-8?q?=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_batch_generate_product_image.py | 4 ++-- .../service_batch_generate_relight_image.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/app/service/generate_batch_image/service_batch_generate_product_image.py b/app/service/generate_batch_image/service_batch_generate_product_image.py index 46a5695..570354a 100644 --- a/app/service/generate_batch_image/service_batch_generate_product_image.py +++ b/app/service/generate_batch_image/service_batch_generate_product_image.py @@ -108,8 +108,8 @@ def batch_generate_product(batch_request_data): logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") else: - publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) - logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") + publish_status(tasks_id, f"{i + 1}/{batch_size}", data) + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") # 任务完成,发送所有数据结果 if DEBUG: diff --git a/app/service/generate_batch_image/service_batch_generate_relight_image.py b/app/service/generate_batch_image/service_batch_generate_relight_image.py index 0a90646..e75c0cc 100644 --- a/app/service/generate_batch_image/service_batch_generate_relight_image.py +++ b/app/service/generate_batch_image/service_batch_generate_relight_image.py @@ -137,8 +137,8 @@ def batch_generate_relight(batch_request_data): logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") else: - publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) - logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") + publish_status(tasks_id, f"{i + 1}/{batch_size}", data) + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") # 任务完成,发送所有数据结果 if DEBUG: print(result_data_list)