diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index d9b210c..5c15efe 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -34,13 +34,14 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): ] """ try: - logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") + for item in request_item: + logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") if DEBUG: service = AttributeRecognition(const=local_debug_const, request_data=request_item) else: service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() - logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -65,10 +66,11 @@ def category_recognition(request_item: list[CategoryRecognitionModel]): ] """ try: - logger.info(f"category_recognition request item is : @@@@@@:{request_item}") + for item in request_item: + logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}") service = CategoryRecognition(request_data=request_item) data = service.get_result() - logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"category_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py index 6f3da16..c8bcf32 100644 --- a/app/api/api_chat_robot.py +++ b/app/api/api_chat_robot.py @@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel): } """ try: - logger.info(f"chat_robot request item is : @@@@@@:{request_data}") + logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}") data = chat(post_data=request_data) - logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"chat_robot Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_design.py b/app/api/api_design.py index ec618cc..5ce6096 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -66,22 +66,57 @@ def design(request_data: DesignModel): ], "path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png", "print": { - "IfSingle": false, - "location": [ - [ - 512.0, - 512.0 + "single": { + "location": [ + [ + 200.0, + 200.0 + ] + ], + "print_angle_list": [ + 0.0 + ], + "print_path_list": [ + "aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png" + ], + "print_scale_list": [ + 1.0 ] - ], - "print_angle_list": [ - 0.0 - ], - "print_path_list": [ - "aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png" - ], - "print_scale_list": [ - 1.0 - ] + }, + "overall": { + "location": [ + [ + 512.0, + 512.0 + ] + ], + "print_angle_list": [ + 0.0 + ], + "print_path_list": [ + "aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png" + ], + "print_scale_list": [ + 1.0 + ] + }, + "element": { + "element_angle_list": [ + 0.0 + ], + "element_path_list": [ + "aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png" + ], + "element_scale_list": [ + 0.2731036750637755 + ], + "location": [ + [ + 228.63694825464364, + 406.4843844199667 + ] + ] + } }, "priority": 10, "resize_scale": [ @@ -102,9 +137,9 @@ def design(request_data: DesignModel): } """ try: - logger.info(f"design request item is : @@@@@@:{request_data.dict()}") + logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}") data = generate(request_data=request_data) - logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"design response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"design Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -124,13 +159,13 @@ def get_progress(request_data: DesignProgressModel): } """ try: - logger.info(f"get_progress request item is : @@@@@@:{request_data.dict()}") + logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}") process_id = request_data.process_id r = Redis() data = r.read(key=process_id) if data is None: raise ValueError(f"No progress ID: {process_id}") - logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") + logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}") except Exception as e: logger.warning(f"get_progress Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -150,10 +185,10 @@ def model_process(request_data: ModelProgressModel): } """ try: - logger.info(f"model_process request item is : @@@@@@:{request_data.dict()}") + logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}") data = model_transpose(image_path=request_data.model_path) - logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"model_process response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"model_process Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py index f6946dc..f260e22 100644 --- a/app/api/api_design_pre_processing.py +++ b/app/api/api_design_pre_processing.py @@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel): } """ try: - logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}") + logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}") server = DesignPreprocessing() data = server.pipeline(image_list=request_data.sketches) - logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"design response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"design Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 3f3646f..3dee667 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,7 +3,7 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel @@ -38,7 +38,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun } """ try: - logger.info(f"generate_image request item is : @@@@@@:{request_item}") + logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -52,7 +52,7 @@ def generate_image(tasks_id: str): try: logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") data = generate_image_infer_cancel(tasks_id) - logger.info(f"generate_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -78,7 +78,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_ } """ try: - logger.info(f"generate_single_logo request item is : @@@@@@:{request_item}") + logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateSingleLogoImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -92,7 +92,7 @@ def generate_single_logo_image(tasks_id: str): try: logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}") data = generate_single_logo_cancel(tasks_id) - logger.info(f"generate_single_logo_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_single_logo_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -109,16 +109,21 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 + - **image_strength**: 生成强度,越低越接近原图 + - **product_type**: 输入single item 还是 overall item + 示例参数: { "tasks_id": "123-89", "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", - "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png" + "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + "image_strength": 0.8, + "product_type": "overall" } """ try: - logger.info(f"generate_product_image request item is : @@@@@@:{request_item}") + logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateProductImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -132,7 +137,7 @@ def generate_product_image(tasks_id: str): try: logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}") data = generate_product_image_cancel(tasks_id) - logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -143,22 +148,27 @@ def generate_product_image(tasks_id: str): @router.post("/generate_relight_image") -def generate_relight_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): +def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks): """ 创建一个具有以下参数的请求体: - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 + - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light + - **product_type**: 输入single item 还是 overall item + 示例参数: { "tasks_id": "123-89", "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png" + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", + "direction": "Right Light", + "product_type": "overall" } """ try: - logger.info(f"generate_relight_image request item is : @@@@@@:{request_item}") + logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateRelightImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -172,7 +182,7 @@ def generate_relight_image(tasks_id: str): try: logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}") data = generate_relight_image_cancel(tasks_id) - logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index c7bcbcd..c227b07 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -27,7 +27,7 @@ def prompt_generation(request_data: PromptGenerationImageModel): try: logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") data = translate_to_en(request_data.text) - logger.info(f"prompt_generation response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"prompt_generation response @@@@@@:{data}") except Exception as e: logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py index 82b58f4..ce853fd 100644 --- a/app/api/api_super_resolution.py +++ b/app/api/api_super_resolution.py @@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg } """ try: - logger.info(f"super_resolution request item is : @@@@@@:{request_item}") + logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}") service = SuperResolution(request_item) background_tasks.add_task(service.sr_result) except Exception as e: @@ -41,7 +41,7 @@ def super_resolution(tasks_id: str): try: logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}") data = infer_cancel(tasks_id) - logger.info(f"sr_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"sr_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"sr_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_test.py b/app/api/api_test.py index 0ff521a..1271f95 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,8 +1,10 @@ +import json import logging -from fastapi import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS -from fastapi import FastAPI, HTTPException +from fastapi import APIRouter +from fastapi import HTTPException + +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS from app.schemas.response_template import ResponseModel logger = logging.getLogger() @@ -18,7 +20,7 @@ def test(id: int): "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, "local_oss_server": OSS } - logger.info(data) + logger.info(json.dumps(data)) if id == 1: raise HTTPException(status_code=404, detail="Item not found") diff --git a/app/core/config.py b/app/core/config.py index 09e48fc..a01a2c0 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -118,14 +118,17 @@ GSL_MINIO_BUCKET = "aida-users" GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") -# Generate Single Logo service config +# Generate Product service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") -GPI_MODEL_NAME = 'diffusion_ensemble_all' +GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' +GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' + GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") -GRI_MODEL_NAME = 'diffusion_relight_ensemble' +GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble' +GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 4f85002..3dd7cf8 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -20,9 +20,13 @@ class GenerateProductImageModel(BaseModel): tasks_id: str prompt: str image_url: str + image_strength: float + product_type: str class GenerateRelightImageModel(BaseModel): tasks_id: str prompt: str image_url: str + direction: str + product_type: str diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py index 8ee8223..cabe372 100644 --- a/app/service/chat_robot/script/main.py +++ b/app/service/chat_robot/script/main.py @@ -1,26 +1,24 @@ +import json import logging - from langchain_community.chat_models import ChatTongyi from loguru import logger from langchain.agents import Tool -from langchain_community.utilities import SerpAPIWrapper +from langchain.callbacks import FileCallbackHandler +from langchain.utilities import SerpAPIWrapper from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder from langchain.schema import SystemMessage, AIMessage -# from langchain_community.chat_models import ChatOpenAI -# from langchain_community.llms import OpenAI +from langchain.chat_models import ChatOpenAI +from langchain.llms.openai import OpenAI from langchain.callbacks import FileCallbackHandler from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler from app.service.chat_robot.script.database import CustomDatabase +from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX from app.service.chat_robot.script.service import CallQWen from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool) -from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory -from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool from app.core.config import * -import os - # os.environ["http_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890" # log callbacks @@ -138,5 +136,5 @@ def chat(post_data): 'completion_tokens': final_outputs['completion_tokens'], 'response_type': final_outputs["response_type"] } - logging.info(api_response) + logging.info(json.dumps(api_response)) return api_response diff --git a/app/service/design/items/clothing.py b/app/service/design/items/clothing.py index 5adcc70..f9f9561 100644 --- a/app/service/design/items/clothing.py +++ b/app/service/design/items/clothing.py @@ -36,7 +36,9 @@ class Clothing(object): position=start_point, resize_scale=self.result["resize_scale"], mask=cv2.resize(self.result['mask'], self.result["front_image"].size), - gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "" + gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "", + pattern_image_url=self.result['pattern_image_url'] + ) layer.insert(front_layer) @@ -51,7 +53,8 @@ class Clothing(object): position=start_point, resize_scale=self.result["resize_scale"], mask=cv2.resize(self.result['mask'], self.result["front_image"].size), - gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "" + gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "", + pattern_image_url=self.result['pattern_image_url'] ) layer.insert(back_layer) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index a738455..224e753 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -88,98 +88,113 @@ class PrintPainting(object): # @ RunTime def __call__(self, result): + single_print = result['print']['single'] + overall_print = result['print']['overall'] + element_print = result['print']['element'] + result['single_image'] = None + result['print_image'] = None + if overall_print['print_path_list']: + painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} + result['print_image'] = result['pattern_image'] + if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True) + painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True) + painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True) - if "location" not in result['print'].keys(): - result['print']["location"] = [[0, 0]] - elif result['print']["location"] == [] or result['print']["location"] is None: - result['print']["location"] = [[0, 0]] - if result['print']['IfSingle']: - if len(result['print']['print_path_list']) > 0: - print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) - mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) - # print_background = np.full((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), 255, dtype=np.uint8) - for i in range(len(result['print']['print_path_list'])): - image, image_mode = self.read_image(result['print']['print_path_list'][i]) - if image_mode == "RGBA": - new_size = (int(image.width * result['print']['print_scale_list'][i]), int(image.height * result['print']['print_scale_list'][i])) + # resize 到sketch大小 + painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + else: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False) + result['print_image'] = self.printpaint(result, painting_dict, print_=True) + result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image'] - mask = image.split()[3] - resized_source = image.resize(new_size) - resized_source_mask = mask.resize(new_size) + if single_print['print_path_list']: + print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + for i in range(len(single_print['print_path_list'])): + image, image_mode = self.read_image(single_print['print_path_list'][i]) + if image_mode == "RGBA": + new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i])) - rotated_resized_source = resized_source.rotate(-result['print']['print_angle_list'][i]) - rotated_resized_source_mask = resized_source_mask.rotate(-result['print']['print_angle_list'][i]) + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) - source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) - source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i]) - source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source) - source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask) + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) - print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) - mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask) + + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY) + else: + mask = self.get_mask_inv(image) + mask = np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) + x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1]) + + image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] + + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, :x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] + + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 + + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 else: - mask = self.get_mask_inv(image) - mask = np.expand_dims(mask, axis=2) - mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) - mask = cv2.bitwise_not(mask) - # 旋转后的坐标需要重新算 - rotate_mask, _ = self.img_rotate(mask, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) - rotate_image, rotated_new_size = self.img_rotate(image, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) - # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) - x, y = int(result['print']['location'][i][0] - rotated_new_size[0]), int(result['print']['location'][i][1] - rotated_new_size[1]) + start_x = x - image_x = print_background.shape[1] - image_y = print_background.shape[0] - print_x = rotate_image.shape[1] - print_y = rotate_image.shape[0] + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y - # 有bug - # if x + print_x > image_x: - # rotate_image = rotate_image[:, :x + print_x - image_x] - # rotate_mask = rotate_mask[:, :x + print_x - image_x] - # # - # if y + print_y > image_y: - # rotate_image = rotate_image[:y + print_y - image_y] - # rotate_mask = rotate_mask[:y + print_y - image_y] + # ------------------ + # 如果print-size大于image-size 则需要裁剪print - # 不能是并行 - # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 - # 先挪 再判断 最后裁剪 + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + rotate_mask = rotate_mask[:, :image_x - x] - # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 - if x <= 0: - rotate_image = rotate_image[:, -x:] - rotate_mask = rotate_mask[:, -x:] - start_x = x = 0 - else: - start_x = x + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] - if y <= 0: - rotate_image = rotate_image[-y:, :] - rotate_mask = rotate_mask[-y:, :] - start_y = y = 0 - else: - start_y = y + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) - # ------------------ - # 如果print-size大于image-size 则需要裁剪print - - if x + print_x > image_x: - rotate_image = rotate_image[:, :image_x - x] - rotate_mask = rotate_mask[:, :image_x - x] - - if y + print_y > image_y: - rotate_image = rotate_image[:image_y - y, :] - rotate_mask = rotate_mask[:image_y - y, :] - - # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) - # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) - - # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask - # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image - mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) - print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) @@ -197,54 +212,27 @@ class PrintPainting(object): temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) result['single_image'] = cv2.add(tmp1, tmp2) - else: - painting_dict = {} - painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2] - # no print - if len(result['print_dict']['print_path_list']) == 0 or not self.print_flag: - result['print_image'] = result['pattern_image'] - # print - else: - if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0: - painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) - painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True) - painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True) - - # resize 到sketch大小 - painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) - painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) - else: - painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) - result['print_image'] = self.printpaint(result, painting_dict, print_=True) - result['final_image'] = result['print_image'] - canvas = np.full_like(result['final_image'], 255) - temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) - tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) - temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) - tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) - result['single_image'] = cv2.add(tmp1, tmp2) - - if "element" in result.keys(): + if element_print['element_path_list']: print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) - for i in range(len(result['element']['element_path_list'])): - image, image_mode = self.read_image(result['element']['element_path_list'][i]) + for i in range(len(element_print['element_path_list'])): + image, image_mode = self.read_image(element_print['element_path_list'][i]) if image_mode == "RGBA": - new_size = (int(image.width * result['element']['element_scale_list'][i]), int(image.height * result['element']['element_scale_list'][i])) + new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i])) mask = image.split()[3] resized_source = image.resize(new_size) resized_source_mask = mask.resize(new_size) - rotated_resized_source = resized_source.rotate(-result['element']['element_angle_list'][i]) - rotated_resized_source_mask = resized_source_mask.rotate(-result['element']['element_angle_list'][i]) + rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i]) source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) - source_image_pil.paste(rotated_resized_source, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source) - source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source_mask) + source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask) print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) @@ -255,10 +243,10 @@ class PrintPainting(object): mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) mask = cv2.bitwise_not(mask) # 旋转后的坐标需要重新算 - rotate_mask, _ = self.img_rotate(mask, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) - rotate_image, rotated_new_size = self.img_rotate(image, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) + rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) - x, y = int(result['element']['location'][i][0] - rotated_new_size[0]), int(result['element']['location'][i][1] - rotated_new_size[1]) + x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1]) image_x = print_background.shape[1] image_y = print_background.shape[0] @@ -352,18 +340,18 @@ class PrintPainting(object): return print_background - def painting_collection(self, painting_dict, result, print_trigger=False): + def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False): if print_trigger: - print_ = self.get_print(result['print_dict']) - painting_dict['Trigger'] = not print_['IfSingle'] - painting_dict['location'] = print_['location'] if 'location' in print_.keys() else None + print_ = self.get_print(print_dict) + painting_dict['Trigger'] = not is_single + painting_dict['location'] = print_['location'] single_mask_inv_print = self.get_mask_inv(print_['image']) dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w']) dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) - if not print_['IfSingle']: + if not is_single: self.random_seed = random.randint(0, 1000) # 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪 - if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0: + if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0: painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) else: @@ -458,19 +446,11 @@ class PrintPainting(object): @staticmethod def get_print(print_dict): - if not 'print_scale_list' in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3: + if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3: print_dict['scale'] = 0.3 else: print_dict['scale'] = print_dict['print_scale_list'][0] - if not 'IfSingle' in print_dict.keys(): - print_dict['IfSingle'] = False - - # data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) - # data_bytes = BytesIO(data.read()) - # image = Image.open(data_bytes) - # image_mode = image.mode - bucket_name = print_dict['print_path_list'][0].split("/", 1)[0] object_name = print_dict['print_path_list'][0].split("/", 1)[1] image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL") @@ -480,13 +460,6 @@ class PrintPainting(object): new_background.paste(image, mask=image.split()[3]) image = new_background print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) - - # file = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]).data - # print_dict['image'] = cv2.imdecode(np.fromstring(file, np.uint8), 1) - - # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) - # return image - return print_dict def crop_image(self, image, image_size_h, image_size_w, location, print_shape): diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index 155347a..1e06712 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -71,6 +71,11 @@ class Split(object): result["back_image_url"] = None result["back_mask_url"] = None result['back_mask_image'] = None + + # 创建中间图层 + result_pattern_image_rgba = rgb_to_rgba((result['pattern_image'].shape[0], result['pattern_image'].shape[1]), result['pattern_image'], result['mask']) + result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA)) + _, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}') return result except Exception as e: logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") diff --git a/app/service/design/service.py b/app/service/design/service.py index 0ba5e72..54cb45b 100644 --- a/app/service/design/service.py +++ b/app/service/design/service.py @@ -1,10 +1,10 @@ +import concurrent.futures + from app.core.config import PRIORITY_DICT from app.service.design.core.layer import Layer from app.service.design.items import build_item from app.service.design.utils.redis_utils import Redis from app.service.design.utils.synthesis_item import synthesis, synthesis_single -import concurrent.futures - from app.service.utils.decorator import RunTime @@ -96,6 +96,7 @@ def process_object(cfg, process_id, total): 'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "", 'mask_url': lay['mask_url'], 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, + 'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None, # 'image': lay['image'], # 'mask_image': lay['mask_image'], diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 1bc1c91..dac211c 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -40,16 +40,17 @@ class GenerateImage: if request_data.mode == "img2img": # cv2 读图片是BGR PIL读图片是RGB self.image = self.get_image(request_data.image_url) - self.prompt = request_data.prompt else: self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - self.prompt = request_data.prompt + self.prompt = request_data.prompt self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.mode = request_data.mode self.batch_size = 1 self.category = request_data.category + if self.category == "sketch": + self.prompt = f"{self.category},{self.prompt}" self.index = 0 self.gender = request_data.gender self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''} diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index dcdf09f..5ea6f83 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -7,17 +7,17 @@ @Date :2023/7/26 12:01:05 @detail : """ -import io import json import logging import time + import cv2 +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np from PIL import Image, ImageOps -from minio import Minio from tritonclient.utils import np_to_triton_dtype + from app.core.config import * from app.schemas.generate_image import GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image @@ -37,7 +37,9 @@ class GenerateProductImage: self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" + self.image_strength = request_data.image_strength self.batch_size = 1 + self.product_type = request_data.product_type self.prompt = request_data.prompt self.image, self.image_size = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id @@ -53,7 +55,10 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") + if self.product_type == "single": + image = result.as_numpy("generated_cnet_image") + else: + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" @@ -72,17 +77,31 @@ class GenerateProductImage: self.image = cv2.resize(self.image, (512, 768)) images = [self.image.astype(np.uint8)] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape(1) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + if self.product_type == "single": + text_obj = np.array(prompts, dtype="object").reshape(-1, 1) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) + else: + text_obj = np.array(prompts, dtype="object").reshape(1) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) + + # 假设 prompts、images 和 self.image_strength 已经定义 input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) - inputs = [input_text, input_image] + inputs = [input_text, input_image, input_image_strength] + input_image_strength.set_data_from_numpy(image_strength_obj) + + if self.product_type == "single": + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() @@ -117,24 +136,39 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + # 原始图片的尺寸 + width, height = image.size - # resize 图片内尺寸 并贴到768-512的纯白图像上 - target_height = 768 - target_width = 512 - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 计算长宽比为 3:2 的新尺寸 + desired_ratio = 2 / 3 + current_ratio = width / height + + if current_ratio > desired_ratio: + # 原始图片更宽,需要在上下添加 padding + new_width = width + new_height = int(width / desired_ratio) + else: + # 原始图片更高或者长宽比已经为 3:2 + new_height = height + new_width = int(height * desired_ratio) + + # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 + pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) + + # 将原始图片粘贴到新的画布中心 + left = (new_width - width) // 2 + top = (new_height - height) // 2 + pad_image.paste(image, (left, top)) + + # 将画布 resize 成宽度 500,长度 750 + resized_image = pad_image.resize((500, 750)) + image_size = (512, 768) + + if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) + background = Image.new("RGB", image_size, (255, 255, 255)) # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) + background.paste(resized_image, mask=resized_image.split()[3]) image = np.array(background) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return image, image_size @@ -143,9 +177,11 @@ def pre_processing_image(image_url): if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", - prompt="", - # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + # prompt="", + image_strength=0.9, + prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + product_type="overall" ) server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index ca32c73..e0729ba 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -7,16 +7,15 @@ @Date :2023/7/26 12:01:05 @detail : """ -import io import json import logging import time + import cv2 +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np -from PIL import Image, ImageOps -from minio import Minio +from PIL import Image from tritonclient.utils import np_to_triton_dtype from app.core.config import * @@ -39,8 +38,9 @@ class GenerateRelightImage: self.batch_size = 1 self.prompt = request_data.prompt self.seed = "1" + self.product_type = request_data.product_type self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - self.direction = "Right Light" + self.direction = request_data.direction self.image_url = request_data.image_url self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.tasks_id = request_data.tasks_id @@ -56,7 +56,11 @@ class GenerateRelightImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") + if self.product_type == 'single': + image = result.as_numpy("generated_relight_image") + else: + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -79,11 +83,18 @@ class GenerateRelightImage: nagetive_prompts = [self.negative_prompt] * self.batch_size directions = [self.direction] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) - seed_obj = np.array(seeds, dtype="object").reshape((1)) - direction_obj = np.array(directions, dtype="object").reshape((1)) + if self.product_type == 'single': + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) + seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) + direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") @@ -98,8 +109,11 @@ class GenerateRelightImage: input_direction.set_data_from_numpy(direction_obj) inputs = [input_text, input_natext, input_image, input_seed, input_direction] + if self.product_type == 'single': + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) - ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() @@ -137,7 +151,9 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", prompt="Colorful black", - image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png' + image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png', + direction="Right Light", + product_type="single" ) server = GenerateRelightImage(rd) print(server.get_result()) diff --git a/logging_env.py b/logging_env.py index d618e37..08873b0 100644 --- a/logging_env.py +++ b/logging_env.py @@ -1,51 +1,51 @@ from app.core.config import LOGS_PATH LOGGER_CONFIG_DICT = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "simple": {"format": "%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s"} + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'} }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "level": "INFO", - "formatter": "simple", - "stream": "ext://sys.stdout", + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'INFO', + 'formatter': 'simple', + 'stream': 'ext://sys.stdout', }, - "info_file_handler": { - "class": "logging.handlers.RotatingFileHandler", - "level": "INFO", - "formatter": "simple", - "filename": f"{LOGS_PATH}info.log", - "maxBytes": 10485760, - "backupCount": 50, - "encoding": "utf8", + 'info_file_handler': { + 'class': 'logging.handlers.RotatingFileHandler', + 'level': 'INFO', + 'formatter': 'simple', + 'filename': f'{LOGS_PATH}info.log', + 'maxBytes': 10485760, + 'backupCount': 50, + 'encoding': 'utf8', }, - "error_file_handler": { - "class": "logging.handlers.RotatingFileHandler", - "level": "ERROR", - "formatter": "simple", - "filename": f"{LOGS_PATH}error.log", - "maxBytes": 10485760, - "backupCount": 20, - "encoding": "utf8", + 'error_file_handler': { + 'class': 'logging.handlers.RotatingFileHandler', + 'level': 'ERROR', + 'formatter': 'simple', + 'filename': f'{LOGS_PATH}error.log', + 'maxBytes': 10485760, + 'backupCount': 20, + 'encoding': 'utf8', }, - "debug_file_handler": { - "class": "logging.handlers.RotatingFileHandler", - "level": "DEBUG", - "formatter": "simple", - "filename": f"{LOGS_PATH}debug.log", - "maxBytes": 10485760, - "backupCount": 50, - "encoding": "utf8", + 'debug_file_handler': { + 'class': 'logging.handlers.RotatingFileHandler', + 'level': 'DEBUG', + 'formatter': 'simple', + 'filename': f'{LOGS_PATH}debug.log', + 'maxBytes': 10485760, + 'backupCount': 50, + 'encoding': 'utf8', }, }, - "loggers": { - "my_module": {"level": "INFO", "handlers": ["console"], "propagate": "no"} + 'loggers': { + 'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'} }, - "root": { - "level": "INFO", - "handlers": ["error_file_handler", "info_file_handler", "debug_file_handler", "console"], + 'root': { + 'level': 'INFO', + 'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'], }, }