diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py
index fdecfa8..5c15efe 100644
--- a/app/api/api_attribute_retrieve.py
+++ b/app/api/api_attribute_retrieve.py
@@ -1,9 +1,12 @@
 import json
 import logging
 
-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException
+
+from app.core.config import DEBUG
 from app.schemas.attribute_retrieve import *
-from app.service.attribute.config import const
+from app.schemas.response_template import ResponseModel
+from app.service.attribute.config import const, local_debug_const
 from app.service.attribute.service_att_recognition import AttributeRecognition
 from app.service.attribute.service_category_recognition import CategoryRecognition
@@ -12,34 +15,63 @@ logger = logging.getLogger()
 
 
 # Attribute recognition
-@router.post("/attribute_recognition")
+@router.post("/attribute_recognition", response_model=ResponseModel)
 def attribute_recognition(request_item: list[AttributeRecognitionModel]):
+    """
+    Get the sketch's attributes: collar, sleeve_length, etc.
+    Create a request body with the following parameters:
+    - **category**: the sketch's category, e.g. Dress
+    - **colony**: clothing group, menswear or womenswear
+    - **sketch_img_url**: S3 or MinIO URL of the image whose attributes are extracted
+
+    Example parameters:
+    [
+        {
+            "category": "Dress",
+            "colony": "Female",
+            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
+        }
+    ]
+    """
     try:
-        service = AttributeRecognition(const=const, request_data=request_item)
+        for item in request_item:
+            logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
+        if DEBUG:
+            service = AttributeRecognition(const=local_debug_const, request_data=request_item)
+        else:
+            service = AttributeRecognition(const=const, request_data=request_item)
         data = service.get_result()
-        code = 200
-        message = "access"
-        logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
+        logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
     except Exception as e:
-        code = 400
-        message = e
-        data = e
         logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
-    return {"code": code, "message": message, "data": data}
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data={"list": data})
 
 
 # Category recognition
 @router.post("/category_recognition")
 def category_recognition(request_item: list[CategoryRecognitionModel]):
+    """
+    Get the sketch's category: dress, blouse, etc.
+    Create a request body with the following parameters:
+    - **colony**: clothing group, male or Female
+    - **sketch_img_url**: S3 or MinIO URL of the sketch whose category is extracted
+
+    Example parameters:
+    [
+        {
+            "colony": "Female",
+            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
+        }
+    ]
+    """
     try:
+        for item in request_item:
+            logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
         service = CategoryRecognition(request_data=request_item)
        data = service.get_result()
-        code = 200
-        message = "access"
-        logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
+        logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
     except Exception as e:
-        code = 400
-        message = e
-        data = e
        logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
-    return {"code": code, "message": message, "data": data}
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=data)
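A minimal client-side sketch of the updated endpoint, assuming the service runs locally on port 8000 (the default in app/main.py) and that the bucket path is reachable; the host and payload values are illustrative only, not part of the patch:

    # Hypothetical smoke test for /api/attribute_recognition.
    import requests

    payload = [{
        "category": "Dress",
        "colony": "Female",
        "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg",
    }]
    resp = requests.post("http://localhost:8000/api/attribute_recognition", json=payload)
    body = resp.json()
    # On success the new ResponseModel envelope is {"code": 200, "msg": "OK!", "data": {"list": [...]}}.
    # On failure the handler raises HTTPException and the same envelope carries the error detail.
    print(body["code"], body.get("data"))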
diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py
new file mode 100644
index 0000000..c8bcf32
--- /dev/null
+++ b/app/api/api_chat_robot.py
@@ -0,0 +1,39 @@
+import json
+import logging
+
+from fastapi import APIRouter, HTTPException
+
+from app.schemas.chat_robot import ChatRobotModel
+from app.schemas.response_template import ResponseModel
+from app.service.chat_robot.script.main import chat
+
+router = APIRouter()
+logger = logging.getLogger()
+
+
+@router.post("/chat_robot")
+def chat_robot(request_data: ChatRobotModel):
+    """
+    Chat robot
+    Create a request body with the following parameters:
+    - **gender**: clothing group
+    - **message**: the message
+    - **session_id**: session ID
+    - **user_id**: user ID
+
+    Example parameters:
+    {
+        "gender": "male",
+        "message": "你好",
+        "session_id": "string-89",
+        "user_id": 89
+    }
+    """
+    try:
+        logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
+        data = chat(post_data=request_data)
+        logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
+    except Exception as e:
+        logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=data)
diff --git a/app/api/api_design.py b/app/api/api_design.py
new file mode 100644
index 0000000..5ce6096
--- /dev/null
+++ b/app/api/api_design.py
@@ -0,0 +1,195 @@
+import json
+import logging
+
+from fastapi import APIRouter, HTTPException
+
+from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel
+from app.schemas.response_template import ResponseModel
+from app.service.design.model_process_service import model_transpose
+from app.service.design.service import generate
+from app.service.design.utils.redis_utils import Redis
+
+router = APIRouter()
+logger = logging.getLogger()
+
+
+@router.post("/design")
+def design(request_data: DesignModel):
+    """
+    Create a request body with the following parameters:
+    Example parameters:
+    {
+        "objects": [
+            {
+                "basic": {
+                    "body_point_test": {
+                        "waistband_right": [203, 249],
+                        "hand_point_right": [229, 343],
+                        "waistband_left": [119, 248],
+                        "hand_point_left": [97, 343],
+                        "shoulder_left": [108, 107],
+                        "shoulder_right": [212, 107]
+                    },
+                    "layer_order": true,
+                    "scale_bag": 0.7,
+                    "scale_earrings": 0.16,
+                    "self_template": true,
+                    "single_overall": "overall",
+                    "switch_category": ""
+                },
+                "items": [
+                    {
+                        "businessId": 255303,
+                        "color": "139 148 156",
+                        "image_id": 95159,
+                        "offset": [0, 0],
+                        "path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
+                        "print": {
+                            "single": {
+                                "location": [[200.0, 200.0]],
+                                "print_angle_list": [0.0],
+                                "print_path_list": ["aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"],
+                                "print_scale_list": [1.0]
+                            },
+                            "overall": {
+                                "location": [[512.0, 512.0]],
+                                "print_angle_list": [0.0],
+                                "print_path_list": ["aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"],
+                                "print_scale_list": [1.0]
+                            },
+                            "element": {
+                                "element_angle_list": [0.0],
+                                "element_path_list": ["aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"],
+                                "element_scale_list": [0.2731036750637755],
+                                "location": [[228.63694825464364, 406.4843844199667]]
+                            }
+                        },
+                        "priority": 10,
+                        "resize_scale": [1.0, 1.0],
+                        "type": "Dress"
+                    },
+                    {
+                        "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
+                        "image_id": 67277,
+                        "type": "Body"
+                    }
+                ]
+            }
+        ],
+        "process_id": "89"
+    }
+    """
+    try:
+        logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
+        data = generate(request_data=request_data)
+        logger.info(f"design response @@@@@@:{json.dumps(data)}")
+    except Exception as e:
@@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) + + +@router.post('/get_progress') +def get_progress(request_data: DesignProgressModel): + """ + 获取design 进度 + 创建一个具有以下参数的请求体: + - **process_id**: 进度id + + 示例参数: + { + "process_id": "6878547032381675" + } + """ + try: + logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}") + process_id = request_data.process_id + r = Redis() + data = r.read(key=process_id) + if data is None: + raise ValueError(f"No progress ID: {process_id}") + logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}") + except Exception as e: + logger.warning(f"get_progress Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) + + +@router.post('/model_process') +def model_process(request_data: ModelProgressModel): + """ + 获取模特图片预处理 + 创建一个具有以下参数的请求体: + - **model_path**: 模特图片的minio或s3 url地址 + + 示例参数: + { + "model_path": "aida-users/10/models/female/9c788f5b-b8c7-424c-b149-025aeb0bda51model.jpg" + } + """ + try: + logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}") + + data = model_transpose(image_path=request_data.model_path) + logger.info(f"model_process response @@@@@@:{json.dumps(data)}") + except Exception as e: + logger.warning(f"model_process Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py new file mode 100644 index 0000000..f260e22 --- /dev/null +++ b/app/api/api_design_pre_processing.py @@ -0,0 +1,40 @@ +import json +import logging + +from fastapi import APIRouter, HTTPException + +from app.schemas.pre_processing import DesignPreProcessingModel +from app.schemas.response_template import ResponseModel +from app.service.design_pre_processing.service import DesignPreprocessing + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/design_pre_processing") +def design_pre_processing(request_data: DesignPreProcessingModel): + """ + design 预处理 获取sketch的基本信息 + 创建一个具有以下参数的请求体: + - **sketches**: sketch url等信息 + + 示例参数: + { + "sketches": [ + { + "image_category": "dress", + "image_id": "107903", + "image_url": "aida-sys-image/images/female/dress/0628000000.jpg" + } + ] + } + """ + try: + logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}") + server = DesignPreprocessing() + data = server.pipeline(image_list=request_data.sketches) + logger.info(f"design response @@@@@@:{json.dumps(data)}") + except Exception as e: + logger.warning(f"design Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 78f3a66..3dee667 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -1,28 +1,189 @@ +import json import logging -from fastapi import APIRouter, BackgroundTasks -from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.service import GenerateImage, infer_cancel + +from fastapi import APIRouter, BackgroundTasks, HTTPException + +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel +from app.schemas.response_template import ResponseModel +from 
+from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
+from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
+from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
+from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
 
 router = APIRouter()
 logger = logging.getLogger()
 
+'''generate image'''
+
 
 @router.post("/generate_image")
 def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
+    """
+    Create a request body with the following parameters:
+    - **tasks_id**: task ID, used to cancel the generation task and fetch the result
+    - **prompt**: description of the image to generate
+    - **image_url**: input for img2img, a MinIO or S3 URL
+    - **mode**: generation mode, img2img or txt2img
+    - **category**: category of the generated image: sketch, print, etc.
+    - **gender**: sketch generation only; clothing group
+
+    Example parameters:
+    {
+        "tasks_id": "123-89",
+        "prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic",
+        "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
+        "mode": "img2img",
+        "category": "sketch",
+        "gender": "male"
+    }
+    """
     try:
-        logger.info(f"request data ### : {request_item}")
+        logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
         service = GenerateImage(request_item)
         background_tasks.add_task(service.get_result)
-        code = 200
-        message = "access"
     except Exception as e:
-        code = 400
-        message = e
-        logger.warning(e)
-    return {"code": code, "message": message}
+        logger.warning(f"generate_image Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel()
 
 
-@router.get("/generate_cancel/{tasks_id}>")
-def generate_image(tasks_id):
-    result = infer_cancel(tasks_id)
-    return {"code": 200, "message": result['message'], "data": result['data']}
+@router.get("/generate_cancel/{tasks_id}")
+def cancel_generate_image(tasks_id: str):
+    try:
+        logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
+        data = generate_image_infer_cancel(tasks_id)
+        logger.info(f"generate_cancel response @@@@@@:{data}")
+    except Exception as e:
+        logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=data['data'])
+
+
+'''single logo'''
+
+
+@router.post("/generate_single_logo")
+def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_tasks: BackgroundTasks):
+    """
+    Create a request body with the following parameters:
+    - **tasks_id**: task ID, used to cancel the generation task and fetch the result
+    - **prompt**: description of the image to generate
+    - **seed**: with a fixed prompt and a fixed seed, every run produces the same result
+
+    Example parameters:
+    {
+        "tasks_id": "123-89",
+        "prompt": "an apple",
+        "seed": "2"
+    }
+    """
+    try:
+        logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
+        service = GenerateSingleLogoImage(request_item)
+        background_tasks.add_task(service.get_result)
+    except Exception as e:
logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) + + +'''product image''' + + +@router.post("/generate_product_image") +def generate_product_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **image_url**: 被生成图片的S3或minio url地址 + - **image_strength**: 生成强度,越低越接近原图 + - **product_type**: 输入single item 还是 overall item + + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + "image_strength": 0.8, + "product_type": "overall" + } + """ + try: + logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = GenerateProductImage(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_product_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_product_image_cancel_cancel/{tasks_id}") +def generate_product_image(tasks_id: str): + try: + logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}") + data = generate_product_image_cancel(tasks_id) + logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) + + +'''relight image''' + + +@router.post("/generate_relight_image") +def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **image_url**: 被生成图片的S3或minio url地址 + - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light + - **product_type**: 输入single item 还是 overall item + + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", + "direction": "Right Light", + "product_type": "overall" + } + """ + try: + logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = GenerateRelightImage(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_relight_image_cancel_cancel/{tasks_id}") +def generate_relight_image(tasks_id: str): + try: + logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}") + data = generate_relight_image_cancel(tasks_id) + logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) diff --git 
diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py
new file mode 100644
index 0000000..c227b07
--- /dev/null
+++ b/app/api/api_prompt_generation.py
@@ -0,0 +1,33 @@
+import json
+import logging
+
+from fastapi import APIRouter, HTTPException
+
+from app.schemas.prompt_generation import PromptGenerationImageModel
+from app.schemas.response_template import ResponseModel
+from app.service.prompt_generation.chatgpt_for_translation import translate_to_en
+
+router = APIRouter()
+logger = logging.getLogger()
+
+
+@router.post("/translateToEN")
+def prompt_generation(request_data: PromptGenerationImageModel):
+    """
+    Prompt translation endpoint
+    Create a request body with the following parameters:
+    - **text**: the sentence to translate
+
+    Example parameters:
+    {
+        "text": "你好"
+    }
+    """
+    try:
+        logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
+        data = translate_to_en(request_data.text)
+        logger.info(f"prompt_generation response @@@@@@:{data}")
+    except Exception as e:
+        logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=data)
diff --git a/app/api/api_route.py b/app/api/api_route.py
index 2513204..c2bd2d2 100644
--- a/app/api/api_route.py
+++ b/app/api/api_route.py
@@ -4,6 +4,11 @@ from app.api import api_test
 from app.api import api_super_resolution
 from app.api import api_generate_image
 from app.api import api_attribute_retrieve
+from app.api import api_design
+from app.api import api_chat_robot
+from app.api import api_prompt_generation
+from app.api import api_design_pre_processing
+
 
 router = APIRouter()
@@ -11,3 +16,7 @@ router.include_router(api_test.router, tags=["test"], prefix="/test")
 router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
 router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
 router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
+router.include_router(api_design.router, tags=['design'], prefix="/api")
+router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
+router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
+router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py
index 63f4498..ce853fd 100644
--- a/app/api/api_super_resolution.py
+++ b/app/api/api_super_resolution.py
@@ -1,8 +1,9 @@
 import json
 import logging
 
-from fastapi import APIRouter, BackgroundTasks
+from fastapi import APIRouter, BackgroundTasks, HTTPException
 
+from app.schemas.response_template import ResponseModel
 from app.schemas.super_resolution import SuperResolutionModel
 from app.service.super_resolution.service import SuperResolution, infer_cancel
@@ -12,19 +13,36 @@ logger = logging.getLogger()
 
 
 @router.post("/super_resolution")
 def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks):
+    """
+    Create a request body with the following parameters:
+    - **sr_image_url**: MinIO or S3 URL of the image to super-resolve
+    - **sr_xn**: super-resolution factor; only 2 or 4 is accepted
+    - **sr_tasks_id**: task ID, used to cancel the super-resolution task and fetch the result
+
+    Example parameters:
+    {
+        "sr_image_url": "aida-sys-image/images/female/blouse/0628000098.jpg",
+        "sr_xn": 2,
+        "sr_tasks_id": "12341556-89"
+    }
+    """
     try:
+        logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
         service = SuperResolution(request_item)
         background_tasks.add_task(service.sr_result)
-        code = 200
-        message = "access"
     except Exception as e:
-        code = 400
-        message = e
-        logger.warning(e)
-    return {"code": code, "message": message}
+        logger.warning(f"super_resolution Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel()
 
 
-@router.get("/sr_cancel/{tasks_id}>")
-def super_resolution(tasks_id):
-    result = infer_cancel(tasks_id)
-    return {"code": 200, "message": result['message'], "data": result['data']}
+@router.get("/sr_cancel/{tasks_id}")
+def cancel_super_resolution(tasks_id: str):
+    try:
+        logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}")
+        data = infer_cancel(tasks_id)
+        logger.info(f"sr_cancel response @@@@@@:{data}")
+    except Exception as e:
+        logger.warning(f"sr_cancel Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=data['data'])
diff --git a/app/api/api_test.py b/app/api/api_test.py
index 63ef1aa..1271f95 100644
--- a/app/api/api_test.py
+++ b/app/api/api_test.py
@@ -1,13 +1,27 @@
+import json
 import logging
+
 from fastapi import APIRouter
-from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES
+from fastapi import HTTPException
+
+from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
+from app.schemas.response_template import ResponseModel
 
 logger = logging.getLogger()
 router = APIRouter()
 
 
-@router.get("")
-def test():
-    logger.info(SR_RABBITMQ_QUEUES)
-    logger.info("test")
-    return {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES}
+@router.get("/{id}")
+def test(id: int):
+    data = {
+        "SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
+        "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
+        "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
+        "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
+        "local_oss_server": OSS
+    }
+    logger.info(json.dumps(data))
+    if id == 1:
+        raise HTTPException(status_code=404, detail="Item not found")
+
+    return ResponseModel(data=data)
diff --git a/app/core/config.py b/app/core/config.py
index 0671754..8b94834 100644
--- a/app/core/config.py
+++ b/app/core/config.py
@@ -19,15 +19,16 @@ class Settings(BaseSettings):
 
 LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
 
+OSS = "minio"
 DEBUG = False
 if DEBUG:
     LOGS_PATH = "logs/"
     CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
-    FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
+    # FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
 else:
     LOGS_PATH = "app/logs/"
     CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
-    FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
+    # FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
 
 RABBITMQ_ENV = ""  # production
 # RABBITMQ_ENV = "-dev"  # development
@@ -41,6 +42,11 @@ MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
 MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
 MINIO_SECURE = True
 
+# S3 config
+S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
+S3_AWS_SECRET_ACCESS_KEY = "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"
+S3_REGION_NAME = "ap-east-1"
+
 # Redis config
 REDIS_HOST = "10.1.1.240"
 REDIS_PORT = "6379"
@@ -55,12 +61,38 @@ RABBITMQ_PARAMS = {
 }
 
 # Milvus config
-MILVUS_DB_HOST = "10.1.1.240"
+MILVUS_URL = "http://10.1.1.240:19530"
+MILVUS_TOKEN = "root:Milvus"
 MILVUS_ALIAS = "default"
-MILVUS_PORT = "19530"
 MILVUS_TABLE_KEYPOINT = "keypoint_cache"
 MILVUS_TABLE_SEG = "seg_cache"
 
+# MySQL config
+DB_HOST = '18.167.251.121'  # database host
+# DB_PORT = 33006
+DB_PORT = 33008  # database port
+DB_USERNAME = 'aida_con_python'  # database username
+DB_PASSWORD = '123456'  # database password
+DB_NAME = 'aida'  # database name
+
+# openai
+os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
+OPENAI_STREAM = True
+BUFFER_THRESHOLD = 6  # must be even number
+SINGLE_TOKEN_THRESHOLD = 200
+TOKEN_THRESHOLD = 600
+OPENAI_TEMPERATURE = 0
+
+# OPENAI_API_KEY = "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"
+OPENAI_API_KEY = "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"
+OPENAI_MODEL = "gpt-3.5-turbo-0613"
+OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
+                     "gpt-3.5-turbo-16k-0613",
+                     "gpt-4-0314",
+                     "gpt-4-32k-0314",
+                     "gpt-4-0613",
+                     "gpt-4-32k-0613", }
+
 
 # attribute service config
 ATT_TRITON_URL = "10.1.1.240:10000"
@@ -77,6 +109,28 @@ GI_MINIO_BUCKET = "aida-users"
 GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
 GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
 
+# SLOGAN service config
+SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}")
+
+# Generate Single Logo service config
+GSL_MODEL_URL = '10.1.1.240:10041'
+GSL_MINIO_BUCKET = "aida-users"
+GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
+GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
+
+# Generate Product service config
+GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
+GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
+GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
+
+GPI_MODEL_URL = '10.1.1.240:10041'
+
+# Generate Relight service config
+GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
+GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
+GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
+GRI_MODEL_URL = '10.1.1.240:10051'
+
 # SEG service config
 SEG_MODEL_URL = '10.1.1.240:10000'
 SEGMENTATION = {
@@ -87,9 +141,13 @@ SEGMENTATION = {
 }
 
 # DESIGN config
-DESIGN_MODEL_URL = '10.1.1.240:9000'
-
+DESIGN_MODEL_URL = '10.1.1.240:10000'
 AIDA_CLOTHING = "aida-clothing"
+KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
+                                   'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
+
+# DESIGN pre-processing
+IF_DEBUG_SHOW = False
 
 # Priority
 PRIORITY_DICT = {
@@ -117,4 +175,3 @@ PRIORITY_DICT = {
     'bag_back': -98,
     'earring_back': -99,
 }
-
diff --git a/app/main.py b/app/main.py
index 07bd258..b085d7d 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,14 +1,17 @@
 import logging.config
+from fastapi.responses import JSONResponse
+from fastapi import FastAPI, HTTPException, Request
 import uvicorn
 from fastapi import FastAPI
 
 from app.api.api_route import router
 from app.core.config import settings
+from app.schemas.response_template import ResponseModel
 from logging_env import LOGGER_CONFIG_DICT
 
-
 logging.config.dictConfig(LOGGER_CONFIG_DICT)
+logging.getLogger("pika").setLevel(logging.WARNING)
 
 from starlette.middleware.cors import CORSMiddleware
@@ -35,5 +39,15 @@ def get_application() -> FastAPI:
 
 app = get_application()
 
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
+    )
+
+
 if __name__ == '__main__':
     uvicorn.run(app, host="0.0.0.0", port=8000)
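With the new @app.exception_handler, every HTTPException raised in the routers above is re-serialized through ResponseModel, so clients see one envelope for both success and failure. A hedged check using FastAPI's test client (the /test/1 route is the one from api_test.py above, which deliberately raises a 404):

    # Hypothetical verification of the unified error envelope; not part of the patch.
    from fastapi.testclient import TestClient
    from app.main import app

    client = TestClient(app)
    resp = client.get("/test/1")  # the test route raises HTTPException(404) for id == 1
    assert resp.status_code == 404
    # The custom handler rewraps the error in the ResponseModel shape:
    # {"code": 404, "msg": "Item not found", "data": "Item not found"}
    print(resp.json())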
diff --git a/app/schemas/chat_robot.py b/app/schemas/chat_robot.py
new file mode 100644
index 0000000..cebf74a
--- /dev/null
+++ b/app/schemas/chat_robot.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class ChatRobotModel(BaseModel):
+    gender: str
+    message: str
+    session_id: str
+    user_id: int
diff --git a/app/schemas/design.py b/app/schemas/design.py
new file mode 100644
index 0000000..edcc392
--- /dev/null
+++ b/app/schemas/design.py
@@ -0,0 +1,58 @@
+from pydantic import BaseModel
+
+
+# class BodyPointModel(BaseModel):
+#     waistband_right: list[int]
+#     hand_point_right: list[int]
+#     waistband_left: list[int]
+#     hand_point_left: list[int]
+#     shoulder_left: list[int]
+#     shoulder_right: list[int]
+#
+#
+# class BasicModel(BaseModel):
+#     body_point: BodyPointModel
+#     layer_order: bool
+#     scale_bag: float
+#     scale_earrings: float
+#     self_template: bool
+#     single_overall: str
+#     switch_category: str
+#     body_path: str
+#
+#
+# class PrintModel(BaseModel):
+#     if_single: bool
+#     print_path_list: list[str]
+#
+#
+# class ItemModel(BaseModel):
+#     color: str
+#     image_id: str
+#     offset: list[int]
+#     path: str
+#     print: PrintModel
+#     resize_scale: float
+#     type: str
+#
+#
+# class CollocationModel(BaseModel):
+#     basic: BasicModel
+#     item: list[ItemModel]
+#
+#
+# class DesignModel(BaseModel):
+#     object: list[CollocationModel]
+#     process_id: str
+
+
+class DesignModel(BaseModel):
+    objects: list[dict]
+    process_id: str
+
+
+class DesignProgressModel(BaseModel):
+    process_id: str
+
+
+class ModelProgressModel(BaseModel):
+    model_path: str
diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py
index b8f5441..3dd7cf8 100644
--- a/app/schemas/generate_image.py
+++ b/app/schemas/generate_image.py
@@ -8,3 +8,25 @@ class GenerateImageModel(BaseModel):
     mode: str
     category: str
     gender: str
+
+
+class GenerateSingleLogoImageModel(BaseModel):
+    tasks_id: str
+    prompt: str
+    seed: str
+
+
+class GenerateProductImageModel(BaseModel):
+    tasks_id: str
+    prompt: str
+    image_url: str
+    image_strength: float
+    product_type: str
+
+
+class GenerateRelightImageModel(BaseModel):
+    tasks_id: str
+    prompt: str
+    image_url: str
+    direction: str
+    product_type: str
diff --git a/app/schemas/pre_processing.py b/app/schemas/pre_processing.py
new file mode 100644
index 0000000..47d9297
--- /dev/null
+++ b/app/schemas/pre_processing.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class DesignPreProcessingModel(BaseModel):
+    sketches: list[dict]
diff --git a/app/schemas/prompt_generation.py b/app/schemas/prompt_generation.py
new file mode 100644
index 0000000..195291b
--- /dev/null
+++ b/app/schemas/prompt_generation.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class PromptGenerationImageModel(BaseModel):
+    text: str
diff --git a/app/schemas/response_template.py b/app/schemas/response_template.py
new file mode 100644
index 0000000..b3b773c
--- /dev/null
+++ b/app/schemas/response_template.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+from typing import Any, Optional
+
+
+class ResponseModel(BaseModel):
+    code: int = 200
+    msg: str = "OK!"
+    data: Optional[Any] = None
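For reference, a minimal illustration (not part of the patch) of the default-filled envelope the new schema produces:

    # Sketch of how the shared envelope behaves; pydantic fills code/msg defaults.
    from app.schemas.response_template import ResponseModel

    ok = ResponseModel(data={"list": [1, 2, 3]})
    print(ok.dict())   # {'code': 200, 'msg': 'OK!', 'data': {'list': [1, 2, 3]}}
    err = ResponseModel(code=404, msg="Item not found")
    print(err.dict())  # data defaults to None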
diff --git a/app/service/attribute/config/const.py b/app/service/attribute/config/const.py
index 24d9412..738e486 100644
--- a/app/service/attribute/config/const.py
+++ b/app/service/attribute/config/const.py
@@ -1,13 +1,13 @@
-top_description_list = ['service/attribute/config/descriptor/top/length.csv',
-                        'service/attribute/config/descriptor/top/type.csv',
-                        'service/attribute/config/descriptor/top/sleeve_length.csv',
-                        'service/attribute/config/descriptor/top/sleeve_shape.csv',
-                        'service/attribute/config/descriptor/top/sleeve_shoulder.csv',
-                        'service/attribute/config/descriptor/top/neckline.csv',
-                        'service/attribute/config/descriptor/top/design.csv',
-                        'service/attribute/config/descriptor/top/opening_type.csv',
-                        'service/attribute/config/descriptor/top/silhouette.csv',
-                        'service/attribute/config/descriptor/top/collar.csv']
+top_description_list = ['app/service/attribute/config/descriptor/top/length.csv',
+                        'app/service/attribute/config/descriptor/top/type.csv',
+                        'app/service/attribute/config/descriptor/top/sleeve_length.csv',
+                        'app/service/attribute/config/descriptor/top/sleeve_shape.csv',
+                        'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv',
+                        'app/service/attribute/config/descriptor/top/neckline.csv',
+                        'app/service/attribute/config/descriptor/top/design.csv',
+                        'app/service/attribute/config/descriptor/top/opening_type.csv',
+                        'app/service/attribute/config/descriptor/top/silhouette.csv',
+                        'app/service/attribute/config/descriptor/top/collar.csv']
 
 top_model_list = ['attr_retrieve_T_length',
                   'attr_retrieve_T_type',
@@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length',
                   ]
 
 bottom_description_list = [
-    'service/attribute/config/descriptor/bottom/subtype.csv',
-    'service/attribute/config/descriptor/bottom/length.csv',
-    'service/attribute/config/descriptor/bottom/silhouette.csv',
-    'service/attribute/config/descriptor/bottom/opening_type.csv',
-    'service/attribute/config/descriptor/bottom/design.csv']
+    'app/service/attribute/config/descriptor/bottom/subtype.csv',
+    'app/service/attribute/config/descriptor/bottom/length.csv',
+    'app/service/attribute/config/descriptor/bottom/silhouette.csv',
+    'app/service/attribute/config/descriptor/bottom/opening_type.csv',
+    'app/service/attribute/config/descriptor/bottom/design.csv']
 
 bottom_model_list = [
     'attr_retrieve_B_subtype',
@@ -35,14 +35,14 @@ bottom_model_list = [
     'attr_recong_B_optype',
     'attr_retrieve_B_design']
 
-outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv',
-                            'service/attribute/config/descriptor/outwear/sleeve_length.csv',
-                            'service/attribute/config/descriptor/outwear/sleeve_shape.csv',
-                            'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
-                            'service/attribute/config/descriptor/outwear/collar.csv',
-                            'service/attribute/config/descriptor/outwear/design.csv',
-                            'service/attribute/config/descriptor/outwear/opening_type.csv',
-                            'service/attribute/config/descriptor/outwear/silhouette.csv', ]
+outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv',
+                            'app/service/attribute/config/descriptor/outwear/sleeve_length.csv',
+                            'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv',
+                            'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
+                            'app/service/attribute/config/descriptor/outwear/collar.csv',
+                            'app/service/attribute/config/descriptor/outwear/design.csv',
+                            'app/service/attribute/config/descriptor/outwear/opening_type.csv',
+                            'app/service/attribute/config/descriptor/outwear/silhouette.csv', ]
 
 outwear_model_list = ['attr_recong_O_length',
                       'attr_retrieve_O_sleeve_length',
@@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length',
                       'attr_recong_O_optype',
                       'attr_retrieve_O_silhouette']
 
-dress_description_list = [  # 'service/attribute/config/descriptor/dress/D_length.csv',
-    'service/attribute/config/descriptor/dress/sleeve_length.csv',
-    'service/attribute/config/descriptor/dress/sleeve_shape.csv',
-    # 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
-    'service/attribute/config/descriptor/dress/neckline.csv',
-    'service/attribute/config/descriptor/dress/collar.csv',
-    'service/attribute/config/descriptor/dress/design.csv',
-    'service/attribute/config/descriptor/dress/silhouette.csv',
-    'service/attribute/config/descriptor/dress/type.csv']
+dress_description_list = [  # 'app/service/attribute/config/descriptor/dress/D_length.csv',
+    'app/service/attribute/config/descriptor/dress/sleeve_length.csv',
+    'app/service/attribute/config/descriptor/dress/sleeve_shape.csv',
+    # 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
+    'app/service/attribute/config/descriptor/dress/neckline.csv',
+    'app/service/attribute/config/descriptor/dress/collar.csv',
+    'app/service/attribute/config/descriptor/dress/design.csv',
+    'app/service/attribute/config/descriptor/dress/silhouette.csv',
+    'app/service/attribute/config/descriptor/dress/type.csv']
 
 dress_model_list = [  # 'attr_recong_D_length',
     'attr_retrieve_D_sleeve_length',
diff --git a/app/service/attribute/service_att_recognition.py b/app/service/attribute/service_att_recognition.py
index da71c16..ddcfd1c 100644
--- a/app/service/attribute/service_att_recognition.py
+++ b/app/service/attribute/service_att_recognition.py
@@ -11,12 +11,12 @@ from minio import Minio
 import tritonclient.http as httpclient
 from app.core.config import *
 from app.schemas.attribute_retrieve import AttributeRecognitionModel
+from app.service.utils.oss_client import oss_get_image
 
 
 class AttributeRecognition:
     def __init__(self, const, request_data):
-        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
-        logging.info("实例化完成")
+        # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
         self.request_data = []
         for i, sketch in enumerate(request_data):
             self.request_data.append(
@@ -97,9 +97,10 @@ class AttributeRecognition:
         return res
 
     def get_image(self, url):
-        response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
-        img = np.frombuffer(response.data, np.uint8)  # convert to uint8
-        img = cv2.imdecode(img, cv2.IMREAD_COLOR)  # decode
+        # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
+        # img = np.frombuffer(response.data, np.uint8)  # convert to uint8
+        # img = cv2.imdecode(img, cv2.IMREAD_COLOR)  # decode
+        img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         return img
diff --git a/app/service/attribute/service_category_recognition.py b/app/service/attribute/service_category_recognition.py
index 18ee043..fb997e9 100644
--- a/app/service/attribute/service_category_recognition.py
+++ b/app/service/attribute/service_category_recognition.py
@@ -18,12 +18,13 @@ import torch
 
 from app.core.config import *
 from app.schemas.attribute_retrieve import CategoryRecognitionModel
+from app.service.utils.oss_client import oss_get_image
 
 
 class CategoryRecognition:
     def __init__(self, request_data):
         self.attr_type = pd.read_csv(CATEGORY_PATH)
-        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+        # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
         self.request_data = []
         self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
         for sketch in request_data:
@@ -51,9 +52,10 @@ class CategoryRecognition:
     def get_image(self, url):
         # Get data of an object.
         # Read data from response.
-        response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
-        img = np.frombuffer(response.data, np.uint8)  # convert to uint8
-        img = cv2.imdecode(img, cv2.IMREAD_COLOR)  # decode
+        # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
+        # img = np.frombuffer(response.data, np.uint8)  # convert to uint8
+        # img = cv2.imdecode(img, cv2.IMREAD_COLOR)  # decode
+        img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         return img
diff --git a/app/service/chat_robot/script/agents/__init__.py b/app/service/chat_robot/script/agents/__init__.py
new file mode 100644
index 0000000..30c40f9
--- /dev/null
+++ b/app/service/chat_robot/script/agents/__init__.py
@@ -0,0 +1,7 @@
+from .agent_executor import CustomAgentExecutor
+from .conversational_functions_agent import ConversationalFunctionsAgent
+
+__all__ = [
+    "CustomAgentExecutor",
+    "ConversationalFunctionsAgent"
+]
diff --git a/app/service/chat_robot/script/agents/agent_executor.py b/app/service/chat_robot/script/agents/agent_executor.py
new file mode 100644
index 0000000..cc69936
--- /dev/null
+++ b/app/service/chat_robot/script/agents/agent_executor.py
@@ -0,0 +1,132 @@
+import inspect
+import json
+import logging
+from typing import Any, Dict, List, Optional, Union, Tuple
+
+from langchain.agents import AgentExecutor
+from langchain.callbacks.manager import Callbacks, CallbackManager
+from langchain.load.dump import dumpd
+from langchain.schema import RUN_KEY, RunInfo
+from langchain_core.agents import AgentAction, AgentFinish
+
+
+class CustomAgentExecutor(AgentExecutor):
+    def __call__(
+        self,
+        inputs: Union[Dict[str, Any], Any],
+        return_only_outputs: bool = False,
+        callbacks: Callbacks = None,
+        session_key: str = "",
+        *,
+        tags: Optional[List[str]] = None,
+        include_run_info: bool = False,
+    ) -> Dict[str, Any]:
+        """Run the logic of this chain and add to output if desired.
+
+        Args:
+            inputs: Dictionary of inputs, or single input if chain expects
+                only one param.
+            return_only_outputs: boolean for whether to return only outputs in the
+                response. If True, only new keys generated by this chain will be
+                returned. If False, both input keys and new keys generated by this
+                chain will be returned. Defaults to False.
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
+ """ + inputs = self.prep_inputs(inputs, session_key) + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose, tags, self.tags + ) + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") + run_manager = callback_manager.on_chain_start( + dumpd(self), + inputs, + ) + try: + outputs = ( + self._call(inputs, run_manager=run_manager) + if new_arg_supported + else self._call(inputs) + ) + except (KeyboardInterrupt, Exception) as e: + logging.exception(e) + run_manager.on_chain_error(e) + raise e + run_manager.on_chain_end(outputs) + final_outputs: Dict[str, Any] = self.prep_outputs( + inputs, outputs, return_only_outputs, session_key + ) + if include_run_info: + final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) + return final_outputs + + def prep_outputs( + self, + inputs: Dict[str, str], + outputs: Dict[str, str], + return_only_outputs: bool = False, + session_key: str = "" + ) -> Dict[str, str]: + """Validate and prep outputs.""" + self._validate_outputs(outputs) + if self.memory is not None and outputs['need_record']: + self.memory.save_context(inputs, outputs, session_key) + if return_only_outputs: + return outputs + else: + return {**inputs, **outputs} + + def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]: + """Validate and prep inputs.""" + if not isinstance(inputs, dict): + _input_keys = set(self.input_keys) + if self.memory is not None: + # If there are multiple input keys, but some get set by memory so that + # only one is not set, we can still figure out which key it is. + _input_keys = _input_keys.difference(self.memory.memory_variables) + if len(_input_keys) != 1: + raise ValueError( + f"A single string input was passed in, but this chain expects " + f"multiple inputs ({_input_keys}). When a chain expects " + f"multiple inputs, please call it by passing in a dictionary, " + "eg `chain({'foo': 1, 'bar': 2})`" + ) + inputs = {list(_input_keys)[0]: inputs} + if self.memory is not None: + external_context = self.memory.load_memory_variables(inputs, session_key) + inputs = dict(inputs, **external_context) + self._validate_inputs(inputs) + return inputs + + def _get_tool_return( + self, next_step_output: Tuple[AgentAction, str] + ) -> Optional[AgentFinish]: + """Check if the tool is a returning tool.""" + agent_action, observation = next_step_output + name_to_tool_map = {tool.name: tool for tool in self.tools} + return_value_key = "output" + + if len(self.agent.return_values) > 0: + return_value_key = self.agent.return_values[0] + + try: + observation_list = json.loads(observation) + if agent_action.tool == "sql_db_query" and isinstance(observation_list, + list) and observation_list.__len__() != 0: + return AgentFinish( + {return_value_key: observation}, + "", + ) + except: + pass + + # Invalid tools won't be in the map, so we return False. 
+        if agent_action.tool in name_to_tool_map:
+            if name_to_tool_map[agent_action.tool].return_direct:
+                return AgentFinish(
+                    {return_value_key: observation},
+                    "",
+                )
+        return None
diff --git a/app/service/chat_robot/script/agents/conversational_functions_agent.py b/app/service/chat_robot/script/agents/conversational_functions_agent.py
new file mode 100644
index 0000000..eb362a7
--- /dev/null
+++ b/app/service/chat_robot/script/agents/conversational_functions_agent.py
@@ -0,0 +1,198 @@
+import json
+import re
+from json import JSONDecodeError
+from typing import List, Tuple, Any, Union
+from dataclasses import dataclass
+
+from langchain.callbacks.manager import Callbacks
+from langchain.agents import (
+    OpenAIFunctionsAgent,
+)
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BaseMessage,
+    OutputParserException
+)
+from langchain.schema.messages import (
+    AIMessage,
+    FunctionMessage
+)
+from langchain.tools import BaseTool, StructuredTool
+# from langchain.tools.convert_to_openai import FunctionDescription
+from langchain.utils.openai_functions import FunctionDescription
+
+
+@dataclass
+class _FunctionsAgentAction(AgentAction):
+    """Add message_log to the AgentAction class for _FunctionsAgentAction."""
+    message_log: List[BaseMessage]
+
+    def __init__(
+        self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
+    ):
+        """Override init to support instantiation by position for backward compat."""
+        super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
+
+
+def _convert_agent_action_to_messages(
+    agent_action: AgentAction, observation: str
+) -> List[BaseMessage]:
+    """Convert an agent action to a message.
+
+    This code is used to reconstruct the original AI message from the agent action.
+
+    Args:
+        agent_action: Agent action to convert.
+
+    Returns:
+        AIMessage that corresponds to the original tool invocation.
+    """
+    if isinstance(agent_action, _FunctionsAgentAction):
+        return agent_action.message_log + [
+            _create_function_message(agent_action, observation)
+        ]
+    else:
+        return [AIMessage(content=agent_action.log)]
+
+
+def _create_function_message(
+    agent_action: AgentAction, observation: str
+) -> FunctionMessage:
+    """Convert an agent action and observation into a function message.
+    Args:
+        agent_action: the tool invocation request from the agent
+        observation: the result of the tool invocation
+    Returns:
+        FunctionMessage that corresponds to the original tool invocation
+    """
+    if not isinstance(observation, str):
+        try:
+            content = json.dumps(observation, ensure_ascii=False)
+        except Exception:
+            content = str(observation)
+    else:
+        content = observation
+    return FunctionMessage(
+        name=agent_action.tool,
+        content=content,
+    )
+
+
+def _format_intermediate_steps(
+    intermediate_steps: List[Tuple[AgentAction, str]],
+) -> List[BaseMessage]:
+    """Format intermediate steps.
+    Args:
+        intermediate_steps: Steps the LLM has taken to date, along with observations
+    Returns:
+        list of messages to send to the LLM for the next prediction
+    """
+    messages = []
+
+    for intermediate_step in intermediate_steps:
+        agent_action, observation = intermediate_step
+        messages.extend(_convert_agent_action_to_messages(agent_action, observation))
+
+    return messages
+
+
+def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
+    """Format tools into the OpenAI function API."""
+    parameters = {
+        "type": "object",
+        "properties": {
+            "query": {
+                "type": "string",
+                "description": tool.param_description if hasattr(tool, 'param_description') else "",
+            },
+        },
+        "required": ["query"],
+    }
+
+    return {
+        "name": tool.name,
+        "description": tool.description,
+        "parameters": parameters,
+    }
+
+
+def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
+    if not isinstance(message, AIMessage):
+        raise TypeError(f"Expected an AI message but got {type(message)}")
+
+    function_call = message.additional_kwargs.get("function_call", {})
+
+    if function_call:
+        function_call = message.additional_kwargs["function_call"]
+        function_name = function_call["name"]
+        try:
+            _tool_input = json.loads(function_call["arguments"])
+        except JSONDecodeError:
+            raise OutputParserException(
+                f"Could not parse tool input: {function_call} because "
+                f"the `arguments` is not valid JSON."
+            )
+
+        if "query" in _tool_input:
+            tool_input = _tool_input["query"]
+        else:
+            tool_input = _tool_input
+
+        return _FunctionsAgentAction(
+            tool=function_name,
+            tool_input=tool_input,
+            log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
+            message_log=[message]
+        )
+
+    # pattern = r'\((.*?)\)'
+    # matches = re.findall(pattern, message.content)
+    # result = []
+    #
+    # for match in matches:
+    #     result.append(match)
+    #
+    # if result:
+    #     output = result
+    # else:
+    #     output = message.content
+
+    return AgentFinish(return_values={"output": message.content}, log=message.content)
+
+
+class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
+    @property
+    def functions(self) -> List[dict]:
+        return [dict(_format_tool_to_openai_function(t)) for t in self.tools]
+
+    def plan(self,
+             intermediate_steps: List[Tuple[AgentAction, str]],
+             callbacks: Callbacks = None,
+             **kwargs: Any
+             ) -> Union[AgentAction, AgentFinish]:
+        """Decide how the agent should move after receiving an input. The difference from
+        OpenAIFunctionsAgent lies in the '_parse_ai_message' function: we add an OutputParser
+        into it.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date, along with observations
+            **kwargs: User inputs, including the user's input string
+
+        Returns:
+            Action specifying what tool to use.
+ """ + agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps) + selected_inputs = { + k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" + } + full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) + prompt = self.prompt.format_prompt(**full_inputs) + messages: List[BaseMessage] = prompt.to_messages() + predicted_message = self.llm.predict_messages( + messages, functions=self.functions, callbacks=callbacks + ) + agent_decision = _parse_ai_message(predicted_message) + return agent_decision diff --git a/app/service/chat_robot/script/callbacks/__init__.py b/app/service/chat_robot/script/callbacks/__init__.py new file mode 100644 index 0000000..8f644bd --- /dev/null +++ b/app/service/chat_robot/script/callbacks/__init__.py @@ -0,0 +1,6 @@ +from .openai_token_record_callback import OpenAITokenRecordCallbackHandler + + +__all__ = [ + 'OpenAITokenRecordCallbackHandler' +] diff --git a/app/service/chat_robot/script/callbacks/openai_token_record_callback.py b/app/service/chat_robot/script/callbacks/openai_token_record_callback.py new file mode 100644 index 0000000..64ed7f4 --- /dev/null +++ b/app/service/chat_robot/script/callbacks/openai_token_record_callback.py @@ -0,0 +1,46 @@ +"""Callback Handler that add on_chain_end function to record Token usage.""" +from typing import Any, Dict + +from langchain.callbacks import OpenAICallbackHandler +from langchain.schema import LLMResult +from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model + + +class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler): + need_record: bool = True + response_type: str = "string" + """Callback Handler that tracks OpenAI info and write to redis after agent finish""" + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Collect token usage.""" + if response.llm_output is None: + return None + self.successful_requests += 1 + if "token_usage" not in response.llm_output: + return None + if "function_call" in response.generations[0][0].message.additional_kwargs: + if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]: + self.need_record = False + if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query": + self.response_type = "image" + token_usage = response.llm_output["token_usage"] + completion_tokens = token_usage.get("completion_tokens", 0) + prompt_tokens = token_usage.get("prompt_tokens", 0) + model_name = standardize_model_name(response.llm_output.get("model_name", "")) + if model_name in MODEL_COST_PER_1K_TOKENS: + completion_cost = get_openai_token_cost_for_model( + model_name, completion_tokens, is_completion=True + ) + prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens) + self.total_cost += prompt_cost + completion_cost + self.total_tokens += token_usage.get("total_tokens", 0) + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + + def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None: + """Write token usage to redis.""" + outputs["total_tokens"] = self.total_tokens + outputs["total_cost"] = self.total_cost + outputs["prompt_tokens"] = self.prompt_tokens + outputs["completion_tokens"] = self.completion_tokens + outputs["need_record"] = self.need_record + outputs["response_type"] = self.response_type diff --git a/app/service/chat_robot/script/database.py 
diff --git a/app/service/chat_robot/script/database.py b/app/service/chat_robot/script/database.py
new file mode 100644
index 0000000..8a5dfdb
--- /dev/null
+++ b/app/service/chat_robot/script/database.py
@@ -0,0 +1,79 @@
+from typing import Optional, List
+import json
+
+from sqlalchemy import text
+# from langchain import SQLDatabase
+from langchain.utilities import SQLDatabase
+
+
+class CustomDatabase(SQLDatabase):
+    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
+        # def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+        connection = self._engine.connect()
+        all_table_names = self.get_usable_table_names()
+        if table_names is not None:
+            missing_tables = set(table_names).difference(all_table_names)
+            if missing_tables:
+                # raise ValueError(f"table_names {missing_tables} not found in database")
+                return f"Table {','.join(missing_tables)} cannot be found in the database"
+            all_table_names = table_names
+        meta_tables = [
+            tbl
+            for tbl in self._metadata.sorted_tables
+            if tbl.name in set(all_table_names)
+        ]
+
+        tables = []
+        for table in meta_tables:
+            table_name = table.name
+            column_names = table.columns.keys()
+            table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n"
+            for column_name in column_names:
+                if column_name not in ["ID", "img_name"]:
+                    query = text(f"SELECT DISTINCT {column_name} FROM {table_name}")
+                    result = connection.execute(query)
+                    enum_values: List[str] = [row[0] for row in result.fetchall()]
+                    column_info = f"{column_name}: {', '.join(enum_values)}\n"
+                    table_info += column_info
+
+            # table_info = f"Table: {table_name}\n"
+            #
+            # if self._sample_rows_in_table_info:
+            #     table_info += f"{self._get_sample_rows(table)}\n"
+            tables.append(table_info)
+        final_str = "\n\n".join(tables)
+        return final_str
+
+    def run(self, command: str, fetch: str = "all") -> str:
+        """Execute a SQL command and return a string representing the results.
+
+        If the statement returns rows, a string of the results is returned.
+        If the statement returns no rows, an empty string is returned.
+ + """ + with self._engine.begin() as connection: + if self._schema is not None: + if self.dialect == "snowflake": + connection.exec_driver_sql( + f"ALTER SESSION SET search_path='{self._schema}'" + ) + elif self.dialect == "bigquery": + connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'") + else: + connection.exec_driver_sql(f"SET search_path TO {self._schema}") + cursor = connection.execute(text(command)) + if cursor.rowcount: + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone() # type: ignore + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") + + # Convert columns values to string to avoid issues with sqlalchmey + # trunacating text + if isinstance(result, list): + return json.dumps([r[0] for r in result]) + + return json.dumps([result[0]]) + return "" diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py new file mode 100644 index 0000000..1e64ca4 --- /dev/null +++ b/app/service/chat_robot/script/main.py @@ -0,0 +1,115 @@ +import json +import logging + +from langchain.agents import Tool +from langchain.callbacks import FileCallbackHandler +from langchain.chat_models import ChatOpenAI +from langchain.llms.openai import OpenAI +from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder +from langchain.schema import SystemMessage, AIMessage +from langchain.utilities import SerpAPIWrapper +from loguru import logger + +from app.core.config import * +from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent +from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler +from app.service.chat_robot.script.database import CustomDatabase +from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory +from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX +from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool) +from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool + +# os.environ["http_proxy"] = "http://127.0.0.1:7890" +# os.environ["https_proxy"] = "http://127.0.0.1:7890" +# log callbacks +logfile = "logs/chat_debug.log" +logger.add(logfile, colorize=True, enqueue=True) +log_handler = FileCallbackHandler(logfile) + +# Initiate our LLM 'gpt-3.5-turbo' +llm = ChatOpenAI(temperature=0.1, + openai_api_key=OPENAI_API_KEY, + # callbacks=[OpenAICallbackHandler()] + ) + +search = SerpAPIWrapper() +db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3', + include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress', + 'female_outwear', 'male_bottom', 'male_top', 'male_outwear'], + engine_args={"pool_recycle": 7200}) +tools = [ + Tool( + name="internet_search", + description="Can be used to perform Internet searches", + func=search.run + ), + QuerySQLDataBaseTool(db=db, return_direct=False), + InfoSQLDatabaseTool(db=db), + ListSQLDatabaseTool(db=db), + QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)), + Tool( + name="tutorial_tool", + description="Utilize this tool to retrieve specific statements related to user guidance tutorials." 
+messages = [
+    SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
+    MessagesPlaceholder(variable_name="history"),
+    HumanMessagePromptTemplate.from_template(
+        "{input} "
+        "Question from a {gender}."
+    ),
+    AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
+    MessagesPlaceholder(variable_name="agent_scratchpad"),
+]
+
+prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages)
+agent = ConversationalFunctionsAgent(
+    llm=llm,
+    tools=tools,
+    prompt=prompt
+)
+
+memory = UserConversationBufferWindowMemory.from_redis(
+    return_messages=True, k=2, input_key='input', output_key='output'
+)
+agent_executor = CustomAgentExecutor.from_agent_and_tools(
+    agent=agent,
+    tools=tools,
+    verbose=True,
+    memory=memory,
+)
+
+
+def chat(post_data):
+    user_id = post_data.user_id
+    session_id = post_data.session_id
+    input_message = post_data.message
+    gender = post_data.gender
+
+    final_outputs = agent_executor(
+        {"input": input_message, "gender": gender},
+        callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
+        session_key=f"buffer:{user_id}:{session_id}",
+    )
+    api_response = {
+        'user_id': user_id,
+        'session_id': session_id,
+        'input': final_outputs['input'],
+        'output': final_outputs['output'],
+        'total_tokens': final_outputs['total_tokens'],
+        'total_cost': final_outputs['total_cost'],
+        'prompt_tokens': final_outputs['prompt_tokens'],
+        'completion_tokens': final_outputs['completion_tokens'],
+        'response_type': final_outputs['response_type']
+    }
+    logging.info(json.dumps(api_response))
+    return api_response
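A minimal driving sketch for `chat` (the request object stands in for the `ChatRobotModel` used by the API layer; field values are illustrative, and the OpenAI/Redis/SerpAPI configuration must be in place):

    from types import SimpleNamespace

    post_data = SimpleNamespace(user_id=89, session_id="string-89",
                                gender="female", message="Show me a long-sleeve blouse")
    response = chat(post_data)
    # response bundles the answer with token accounting:
    # {'user_id': 89, 'session_id': 'string-89', 'input': ..., 'output': ...,
    #  'total_tokens': ..., 'total_cost': ..., 'response_type': ...}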
diff --git a/app/service/chat_robot/script/memory/__init__.py b/app/service/chat_robot/script/memory/__init__.py new file mode 100644 index 0000000..9586157 --- /dev/null +++ b/app/service/chat_robot/script/memory/__init__.py @@ -0,0 +1,3 @@
+from .user_buffer_window import UserConversationBufferWindowMemory
+
+__all__ = ['UserConversationBufferWindowMemory']
diff --git a/app/service/chat_robot/script/memory/user_buffer_window.py b/app/service/chat_robot/script/memory/user_buffer_window.py new file mode 100644 index 0000000..9fbc2d6 --- /dev/null +++ b/app/service/chat_robot/script/memory/user_buffer_window.py @@ -0,0 +1,93 @@
+import logging
+from typing import Any, Dict, List, Tuple
+import json
+
+import redis
+from redis import Redis
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
+from langchain.schema.messages import _message_to_dict, messages_from_dict
+from langchain.memory.utils import get_prompt_input_key
+
+from app.core.config import *
+
+
+class UserConversationBufferWindowMemory(BaseChatMemory):
+    """Buffer for storing conversation memory."""
+
+    redis_client: Redis
+    human_prefix: str = "Human"
+    ai_prefix: str = "AI"
+    memory_key: str = "history"  #: :meta private:
+    k: int = 5
+
+    @classmethod
+    def from_redis(
+            cls,
+            host: str = REDIS_HOST,
+            port: int = REDIS_PORT,
+            db: int = 3,
+            **kwargs
+    ):
+        redis_client = Redis(host=host, port=port, db=db)
+        try:
+            response = redis_client.ping()
+            if response:
+                logging.info("Connected to Redis server successfully.")
+            else:
+                logging.info("Failed to connect to Redis server.")
+        except redis.RedisError as e:
+            logging.info(f"Error occurred when connecting to Redis server: {str(e)}")
+        return cls(redis_client=redis_client, **kwargs)
+
+    @property
+    def memory_variables(self) -> List[str]:
+        """Will always return list of memory variables.
+
+        :meta private:
+        """
+        return [self.memory_key]
+
+    def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]:
+        """Return history buffer."""
+        # The newest messages sit at the head of the Redis list, so read the first
+        # 2*k entries and reverse them back into chronological order.
+        _items: Any = self.redis_client.lrange(key, 0, self.k * 2) if self.k > 0 else []
+        items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
+        buffer = messages_from_dict(items)
+        if not self.return_messages:
+            buffer = get_buffer_string(
+                buffer,
+                human_prefix=self.human_prefix,
+                ai_prefix=self.ai_prefix,
+            )
+        return {self.memory_key: buffer}
+
+    def _get_input_output(
+            self, inputs: Dict[str, Any], outputs: Dict[str, str]
+    ) -> Tuple[str, str]:
+        if self.input_key is None:
+            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
+        else:
+            prompt_input_key = self.input_key
+        if self.output_key is None:
+            if len(outputs) != 1:
+                raise ValueError(f"One output key expected, got {outputs.keys()}")
+            output_key = list(outputs.keys())[0]
+        else:
+            output_key = self.output_key
+        return inputs[prompt_input_key], outputs[output_key]
+
+    def add_message(self, key: str, message: BaseMessage) -> None:
+        self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
+        """Save context from this conversation to buffer."""
+        input_str, output_str = self._get_input_output(inputs, outputs)
+        self.add_message(key, HumanMessage(content=input_str))
+        self.add_message(key, AIMessage(content=output_str))
diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py new file mode 100644 index 0000000..a88044d --- /dev/null +++ b/app/service/chat_robot/script/prompt.py @@ -0,0 +1,52 @@
+FASHION_CHAT_BOT_PREFIX = """
+You are a helpful assistant for fashion designers. You can chat with the users or answer their queries as well as you can.
+The most crucial aspect is to accurately determine whether the user's inquiry requires an internet search or querying the database.
+Remember your answer should be very precise and the final output answer should not exceed 20 words.
+
+You may encounter the following types of questions:
+1) If the query is related to clothing retrieval, you are an agent designed to interact with a SQL database.
+Given an input question, create a syntactically correct MySQL query to run, always fetching random data from tables.
+Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 4 results.
+Never query for all the columns from a specific table, only ask for the relevant columns given the question.
+You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
+DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
+If the question does not seem related to the database, just return "I don't know" as the answer.
+
+2) If the query is related to current events, you should use internet_search to seek help from the internet.
+
+3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.
+
+Be careful when using the tools, since you are actually a chat bot. Tools should only be used when essential.
+"""
+
+TOOL_SELECT_SUFFIX = """
+Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
+For database-related questions, use SQL tools to identify relevant tables and query their schemas.
+The use of online resources should be limited to inquiries pertaining to current subjects.
+"""
+
+SQL_FUNCTIONS_SUFFIX = """
+For database-related questions, use SQL tools to identify relevant tables and query their schemas.
+"""
+
+INTERNET_SEARCH_SUFFIX = """
+If the question should be answered using internet search tools, I should seek help from the internet.
+"""
+
+ANSWER_FORMAT_SUFFIX = """
+My final answer is limited to 20 words and is as precise as possible.
+"""
+
+TOOLS_FUNCTIONS_SUFFIX = (
+    "If the input involves clothing queries, "
+    "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. "
+    "All SQL statements must use 'ORDER BY RAND()', for example: "
+    "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1' "
+    "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2' "
+    "If the input does not involve clothing queries, "
+    "I should engage in conversation as an assistant or search the internet with the internet_search tool. "
+    "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?' "
+    "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool."
+)
+
+TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
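A minimal sketch of the Redis round trip that UserConversationBufferWindowMemory relies on (assumes a reachable local Redis; the session key mirrors the format built in main.py, and the message dicts follow langchain's `_message_to_dict` shape):

    import json
    from redis import Redis

    r = Redis(host="localhost", port=6379, db=3)
    key = "buffer:89:string-89"
    r.lpush(key, json.dumps({"type": "human", "data": {"content": "hi"}}))
    r.lpush(key, json.dumps({"type": "ai", "data": {"content": "hello"}}))
    # Newest-first list; with k=2 the memory window reads the last 2*k entries.
    history = [json.loads(m) for m in r.lrange(key, 0, 3)][::-1]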
diff --git a/app/service/chat_robot/script/tools/__init__.py b/app/service/chat_robot/script/tools/__init__.py new file mode 100644 index 0000000..4a40a33 --- /dev/null +++ b/app/service/chat_robot/script/tools/__init__.py @@ -0,0 +1,10 @@
+from .sql_tools import (
+    QuerySQLDataBaseTool,
+    InfoSQLDatabaseTool,
+    ListSQLDatabaseTool,
+    QuerySQLCheckerTool
+)
+
+__all__ = [
+    "QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool"
+]
diff --git a/app/service/chat_robot/script/tools/sql_tools.py b/app/service/chat_robot/script/tools/sql_tools.py new file mode 100644 index 0000000..92b8003 --- /dev/null +++ b/app/service/chat_robot/script/tools/sql_tools.py @@ -0,0 +1,183 @@
+# flake8: noqa
+"""Tools for interacting with a SQL database."""
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Extra, Field, root_validator
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain.chains.llm import LLMChain
+from langchain.prompts import PromptTemplate
+from langchain.utilities import SQLDatabase
+from langchain.tools.base import BaseTool
+from langchain.tools.sql_database.prompt import QUERY_CHECKER
+
+
+class BaseSQLDatabaseTool(BaseModel):
+    """Base tools for interacting with a SQL database."""
+
+    db: SQLDatabase = Field(exclude=True)
+    param_description: str = ""
+
+    # Override BaseTool.Config to appease mypy
+    # See https://github.com/pydantic/pydantic/issues/4173
+    class Config(BaseTool.Config):
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+        extra = Extra.forbid
+
+
+class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
+    """Tool for querying a SQL database."""
+
+    name = "sql_db_query"
+    description: str = (
+        "The input of this tool is a detailed and correct SQL select query statement, "
+        "and the output is the result of the database, and it can only return up to 4 results. "
+        "If the query is not correct, an error message will be returned. "
+        "If an error is returned, rewrite the query, check the query, and try again. "
+        "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist, "
+        "use sql_db_schema to query the correct table fields. "
+        "Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1' "
+        "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
+    )
+
+    def _run(
+        self,
+        query: str,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Execute the query, return the results or an error message."""
+        result = self.db.run_no_throw(query)
+        return result
+
+    async def _arun(
+        self,
+        query: str,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        raise NotImplementedError("QuerySqlDbTool does not support async")
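+# For a quick smoke test, the tool can be exercised directly (a sketch; assumes
+# the CustomDatabase instance `db` from main.py is importable, and the result
+# format follows CustomDatabase.run, a JSON string of first-column values):
+#
+#   tool = QuerySQLDataBaseTool(db=db)
+#   tool.run("SELECT img_name FROM female_skirt WHERE opening_type = 'Button' "
+#            "ORDER BY RAND() LIMIT 1")
+#   # -> '["0916000602.jpg"]'-style JSON string, or an error message for bad SQL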
+
+
+class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
+    """Tool for getting metadata about a SQL database."""
+
+    name = "sql_db_schema"
+    description: str = (
+        "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. "
+        "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, "
+        "female_skirt, female_outwear, male_bottom, male_top, and male_outwear. "
+        "Example Input: 'female_outwear, male_top'"
+    )
+
+    def _run(
+        self,
+        table_names: str,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Get the schema for tables in a comma-separated list."""
+        return self.db.get_table_info_no_throw(table_names.split(", "))
+
+    async def _arun(
+        self,
+        table_name: str,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        raise NotImplementedError("SchemaSqlDbTool does not support async")
+
+
+class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
+    """Tool for getting table names."""
+
+    name = "sql_db_list_tables"
+    description = "Input is an empty string, output is a comma separated list of tables in the database."
+
+    def _run(
+        self,
+        tool_input: str = "",
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """List all usable table names in the database."""
+        return ", ".join(self.db.get_usable_table_names())
+
+    async def _arun(
+        self,
+        tool_input: str = "",
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        raise NotImplementedError("ListTablesSqlDbTool does not support async")
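+# Chained together, the two lightweight tools above give the agent its grounding
+# loop (a sketch; output values illustrative):
+#
+#   ListSQLDatabaseTool(db=db).run("")              # -> 'female_top, female_skirt, ...'
+#   InfoSQLDatabaseTool(db=db).run("female_skirt")  # -> per-column distinct values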
+
+
+class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
+    """Use an LLM to check if a query is correct.
+    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
+
+    template: str = QUERY_CHECKER
+    llm: BaseLanguageModel
+    llm_chain: LLMChain = Field(init=False)
+    name = "sql_db_query_checker"
+    description = (
+        "Use this tool to double check if your query is correct before executing it. "
+        "Always use this tool before executing a query with sql_db_query!"
+    )
+
+    @root_validator(pre=True)
+    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if "llm_chain" not in values:
+            values["llm_chain"] = LLMChain(
+                llm=values.get("llm"),
+                prompt=PromptTemplate(
+                    template=QUERY_CHECKER,
+                    input_variables=["query", "dialect"]
+                ),
+            )
+
+        if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
+            raise ValueError(
+                "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
+            )
+
+        return values
+
+    def _run(
+        self,
+        query: str,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the LLM to check the query."""
+        return self.llm_chain.predict(query=query, dialect=self.db.dialect)
+
+    async def _arun(
+        self,
+        query: str,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)
diff --git a/app/service/chat_robot/script/tools/tutorial_tool.py b/app/service/chat_robot/script/tools/tutorial_tool.py new file mode 100644 index 0000000..c08eb9d --- /dev/null +++ b/app/service/chat_robot/script/tools/tutorial_tool.py @@ -0,0 +1,19 @@
+from typing import Any
+
+from langchain.tools.base import BaseTool
+
+from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
+
+
+# Handles input related to the system's guided tutorial
+class CustomTutorialTool(BaseTool):
+    name = "tutorial_tool"
+
+    description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials. "
+                   "Input is an empty string")
+
+    def _run(self, tool_input, **kwargs: Any) -> str:
+        return TUTORIAL_TOOL_RETURN
+
+    async def _arun(self, tool_input, **kwargs: Any) -> str:
+        raise NotImplementedError("CustomTutorialTool does not support async")
diff --git a/app/service/chat_robot/script/utils/__init__.py b/app/service/chat_robot/script/utils/__init__.py new file mode 100644 index 0000000..92a2f16 --- /dev/null +++ b/app/service/chat_robot/script/utils/__init__.py @@ -0,0 +1 @@
+from .logger import Logger
diff --git a/app/service/chat_robot/script/utils/logger.py b/app/service/chat_robot/script/utils/logger.py new file mode 100644 index 0000000..cb52c18 --- /dev/null +++ b/app/service/chat_robot/script/utils/logger.py @@ -0,0 +1,26 @@
+import logging
+from logging import handlers
+
+
+class Logger(object):
+    level_relations = {
+        'debug': logging.DEBUG,
+        'info': logging.INFO,
+        'warning': logging.WARNING,
+        'error': logging.ERROR,
+        'crit': logging.CRITICAL
+    }
+
+    def __init__(self, filename, level='info', when='D', backCount=3,
+                 fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
+        self.logger = logging.getLogger(filename)
+        format_str = logging.Formatter(fmt)  # set log format
+        self.logger.setLevel(self.level_relations.get(level))  # set log level
+        sh = logging.StreamHandler()  # output to terminal
+        sh.setFormatter(format_str)  # set format for terminal log
+        th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
+                                               encoding='utf-8')  # rotate the log file per `when`, keeping `backCount` backups
+        th.setFormatter(format_str)  # set format for file log
+        self.logger.addHandler(sh)  # output to terminal
+        self.logger.addHandler(th)  # output to file
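A quick usage sketch for the Logger helper (the file path is illustrative):

    log = Logger('logs/design.log', level='debug').logger
    log.info('service started')  # written to both the terminal and the rotating file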
diff --git a/app/service/design/core/__init__.py b/app/service/design/core/__init__.py new file mode 100644 index 0000000..e69de29
diff --git a/app/service/design/core/layer.py b/app/service/design/core/layer.py new file mode 100644 index 0000000..0628851 --- /dev/null +++ b/app/service/design/core/layer.py @@ -0,0 +1,116 @@
+import cv2
+
+
+def show(img, win_name="temp"):
+    cv2.imshow(win_name, img)
+    cv2.waitKey(0)
+
+
+def crop(img):
+    # Crop a fixed 1040x680 window centred slightly below the image midpoint.
+    mid_point_h, mid_point_w = int(img.shape[0] / 2 + 30), int(img.shape[1] / 2)
+    img_roi = img[mid_point_h - 520: mid_point_h + 520, mid_point_w - 340: mid_point_w + 340]
+    return img_roi
+
+
+class Layer(object):
+    def __init__(self):
+        self._layer = []
+
+    @property
+    def layer(self):
+        return self._layer
+
+    def insert(self, layer_instance):
+        if layer_instance['name'] == 'body':
+            self._body = layer_instance
+        self._layer.append(layer_instance)
+
+    def sort(self, priority):
+        self._layer.sort(key=lambda x: priority[x['name']])
diff --git a/app/service/design/core/priority.py b/app/service/design/core/priority.py new file mode 100644 index 0000000..dc111ea --- /dev/null +++ b/app/service/design/core/priority.py @@ -0,0 +1,45 @@
+class Priority(object):
+    """Item layer priority levels.
+ """ + + def __init__(self, item_list): + self._priority = dict( + earring_front=99, + bag_front=98, + hairstyle_front=97, + outwear_front=20, + bottoms_front=19, + dress_front=18, + blouse_front=17, + skirt_front=16, + trousers_front=15, + tops_front=14, + shoes_right=1, + shoes_left=1, + body=0, + tops_back=-14, + trousers_back=-15, + skirt_back=-16, + blouse_back=-17, + dress_back=-18, + bottoms_back=-19, + outwear_back=-20, + hairstyle_back=-97, + bag_back=-98, + earring_back=-99, + ) + self.clothing_start_num = 10 + if not isinstance(item_list, list): + raise ValueError('item_list must be a list!') + for cate in item_list: + cate = cate.lower() + if cate not in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): + raise ValueError(f'Item type error. Cannot recognize {cate}') + for i, cate in enumerate(item_list): + cate = cate.lower() + self._priority[f'{cate}_front'] = self.clothing_start_num - i + self._priority[f'{cate}_back'] = -(self.clothing_start_num - i) + + @property + def priority(self): + return self._priority diff --git a/app/service/design/fastapi_request.json b/app/service/design/fastapi_request.json new file mode 100644 index 0000000..8c27a56 --- /dev/null +++ b/app/service/design/fastapi_request.json @@ -0,0 +1,771 @@ +{ + "objects": [ + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493827, + "color": "127 61 21", + "elementId": 493827, + "icon": "none", + "image_id": 110201, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketch/62302527-2910-4740-808d-2cb8221daa34-3-31.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Dress" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "27 25 23", + "icon": "none", + "image_id": 110202, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/0916000602.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "businessId": 493825, + "color": "229 214 200", + "elementId": 493825, + "icon": "none", + "image_id": 107101, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "businessId": 493824, + "color": "76 124 124", + "elementId": 493824, + 
"icon": "none", + "image_id": 104522, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "229 214 200", + "icon": "none", + "image_id": 110203, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0825001576.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "76 124 124", + "icon": "none", + "image_id": 96071, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/903000097.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "color": "209 125 29", + "icon": "none", + "image_id": 93798, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/outwear_p4_561.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493824, + "color": "209 125 29", + "elementId": 493824, + "icon": "none", + "image_id": 104522, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "color": "118 123 115", + "icon": "none", + "image_id": 110204, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902000457.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "118 123 115", + "icon": "none", + "image_id": 79259, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/826000094.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "body_path": 
"aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "127 61 21", + "icon": "none", + "image_id": 96038, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/dress/0902003549.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Dress" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493822, + "color": "127 61 21", + "elementId": 493822, + "icon": "none", + "image_id": 62309, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/trousers/c37c2ea6-8955-4b40-8339-c737e672ca3d.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "businessId": 493825, + "color": "118 123 115", + "elementId": 493825, + "icon": "none", + "image_id": 107101, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493826, + "color": "127 61 21", + "elementId": 493826, + "icon": "none", + "image_id": 107105, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Skirt/58710352-6301-450d-b69a-fb2922b5429a.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + 
"color": "118 123 115", + "icon": "none", + "image_id": 79114, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/903000169.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "229 214 200", + "icon": "none", + "image_id": 90573, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0628000541.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "229 214 200", + "icon": "none", + "image_id": 110205, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0916000217.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "businessId": 493825, + "color": "209 125 29", + "elementId": 493825, + "icon": "none", + "image_id": 107101, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + } + ], + "process_id": "6878547032381675" +} \ No newline at end of file diff --git a/app/service/design/items/__init__.py b/app/service/design/items/__init__.py new file mode 100644 index 0000000..23f35bf --- /dev/null +++ b/app/service/design/items/__init__.py @@ -0,0 +1,16 @@ +from .builder import ITEMS, build_item +from .clothing import Clothing # 4.0 sec +from .body import Body +from .top import Top, Blouse, Outwear, Dress +from .bottom import Bottom, Trousers, Skirt +from .shoes import Shoes +from .bag import Bag +from .accessories import Hairstyle, Earring + +__all__ = [ + 'ITEMS', 'build_item', + 'Clothing', 'Body', + 'Top', 'Blouse', 'Outwear', 'Dress', + 'Bottom', 'Trousers', 'Skirt', + 'Shoes', 'Bag', 'Hairstyle', 'Earring' +] diff --git a/app/service/design/items/accessories.py b/app/service/design/items/accessories.py new file mode 100644 index 0000000..5cb5796 --- /dev/null +++ b/app/service/design/items/accessories.py @@ -0,0 +1,59 @@ +from .builder import ITEMS +from .clothing import Clothing + + +@ITEMS.register_module() +class Hairstyle(Clothing): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + dict(type='Scaling'), + dict(type='Split'), + # 
dict(type='ImageShow', key=['image', 'mask', 'pattern_image']), + ] + kwargs.update(pipeline=pipeline) + super(Hairstyle, self).__init__(**kwargs) + + @staticmethod + def calculate_start_point(keypoint_type, scale, clothes_point, body_point): + """ + align up + Args: + keypoint_type: string, "head_point" + scale: float + clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]} + body_point: dict, containing keypoint data of body figure + + Returns: + start_point: tuple (x', y') + x' = y_body - y1 * scale + y' = x_body - x1 * scale + """ + side_indicator = f'{keypoint_type}_up' + # clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()} + # logging.info(clothes_point[side_indicator]) + + start_point = ( + int(body_point[side_indicator][1] - int(clothes_point[side_indicator].split("_")[1] * scale)), + int(body_point[side_indicator][0] - int(clothes_point[side_indicator].split("_")[0] * scale)) + ) + return start_point + + +@ITEMS.register_module() +class Earring(Clothing): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']), + ] + kwargs.update(pipeline=pipeline) + super(Earring, self).__init__(**kwargs) diff --git a/app/service/design/items/bag.py b/app/service/design/items/bag.py new file mode 100644 index 0000000..12b4c68 --- /dev/null +++ b/app/service/design/items/bag.py @@ -0,0 +1,45 @@ +import random + +from .builder import ITEMS +from .clothing import Clothing + + +@ITEMS.register_module() +class Bag(Clothing): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']), + ] + kwargs.update(pipeline=pipeline) + super(Bag, self).__init__(**kwargs) + + @staticmethod + def calculate_start_point(keypoint_type, scale, clothes_point, body_point): + """ + align left + Args: + keypoint_type: string, "hand_point" + scale: float + clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]} + body_point: dict, containing keypoint data of body figure + + Returns: + start_point: tuple (y', x') + x' = y_body - y1 * scale + y' = x_body - x1 * scale + """ + location = random.choice(seq=['left', 'right']) + if location == 'left': + side_indicator = f'{keypoint_type}_left' + else: + side_indicator = f'{keypoint_type}_right' + # clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()} + start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale), + body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale)) + return start_point diff --git a/app/service/design/items/body.py b/app/service/design/items/body.py new file mode 100644 index 0000000..c336ae9 --- /dev/null +++ b/app/service/design/items/body.py @@ -0,0 +1,36 @@ +import cv2 + +from .builder import ITEMS +from .pipelines import Compose + + +@ITEMS.register_module() +class Body(object): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']), + # dict(type='ImageShow', key=['body_image', 
"body_mask"]) + ] + self.pipeline = Compose(pipeline) + self.result = dict() + + def process(self): + self.pipeline(self.result) + pass + + def organize(self, layer): + body_layer = dict(priority=0, + name=type(self).__name__.lower(), + image=self.result['body_image'], + image_url=self.result['image_url'], + mask_image=None, + mask_url=None, + sacle=1, + # mask=self.result['body_mask'], + position=(0, 0)) + layer.insert(body_layer) + + @staticmethod + def show(img): + cv2.imshow('', img) + cv2.waitKey(0) diff --git a/app/service/design/items/bottom.py b/app/service/design/items/bottom.py new file mode 100644 index 0000000..eb575fb --- /dev/null +++ b/app/service/design/items/bottom.py @@ -0,0 +1,38 @@ +from .builder import ITEMS +from .clothing import Clothing + + +@ITEMS.register_module() +class Bottom(Clothing): + def __init__(self, pipeline, **kwargs): + if pipeline is None: + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting', painting_flag=True), + dict(type='PrintPainting', print_flag=True), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']), + ] + kwargs.update(pipeline=pipeline) + super(Bottom, self).__init__(**kwargs) + + +@ITEMS.register_module() +class Trousers(Bottom): + def __init__(self, pipeline=None, **kwargs): + super(Trousers, self).__init__(pipeline, **kwargs) + + +@ITEMS.register_module() +class Skirt(Bottom): + def __init__(self, pipeline=None, **kwargs): + super(Skirt, self).__init__(pipeline, **kwargs) + + +@ITEMS.register_module() +class Bottoms(Bottom): + def __init__(self, pipeline=None, **kwargs): + super(Bottoms, self).__init__(pipeline, **kwargs) diff --git a/app/service/design/items/builder.py b/app/service/design/items/builder.py new file mode 100644 index 0000000..26e04f1 --- /dev/null +++ b/app/service/design/items/builder.py @@ -0,0 +1,9 @@ +from mmcv.utils import Registry, build_from_cfg + +ITEMS = Registry('item') +PIPELINES = Registry('pipeline') + + +def build_item(cfg, default_args=None): + item = build_from_cfg(cfg, ITEMS, default_args) + return item diff --git a/app/service/design/items/clothing.py b/app/service/design/items/clothing.py new file mode 100644 index 0000000..f9f9561 --- /dev/null +++ b/app/service/design/items/clothing.py @@ -0,0 +1,99 @@ +import cv2 + +from app.core.config import PRIORITY_DICT +from .builder import ITEMS +from .pipelines import Compose + + +@ITEMS.register_module() +class Clothing(object): + def __init__(self, pipeline, **kwargs): + self.pipeline = Compose(pipeline) + self.result = dict(name=type(self).__name__.lower(), **kwargs) + + def process(self): + self.pipeline(self.result) + + def apply_scale(self, img): + scale = self.result['scale'] + height, width = img.shape[0: 2] + if len(img.shape) > 2: + height, width = img.shape[0: 2] + scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA) + return scaled_img + + def organize(self, layer): + start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'], self.result['clothes_keypoint'], self.result['body_point_test'], self.result["offset"], self.result["resize_scale"]) + + front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None), + 
name=f'{type(self).__name__.lower()}_front',
+                           image=self.result["front_image"],
+                           image_url=self.result['front_image_url'],
+                           mask_url=self.result['front_mask_url'],
+                           sacle=self.result['scale'],  # NOTE: key kept as 'sacle' (sic); downstream consumers read this spelling
+                           clothes_keypoint=self.result['clothes_keypoint'],
+                           position=start_point,
+                           resize_scale=self.result["resize_scale"],
+                           mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
+                           gradient_string=self.result.get('gradient_string', ""),
+                           pattern_image_url=self.result['pattern_image_url']
+                           )
+        layer.insert(front_layer)
+
+        back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
+                          name=f'{type(self).__name__.lower()}_back',
+                          image=self.result["back_image"],
+                          image_url=self.result['back_image_url'],
+                          mask_url=self.result['back_mask_url'],
+                          sacle=self.result['scale'],
+                          clothes_keypoint=self.result['clothes_keypoint'],
+                          position=start_point,
+                          resize_scale=self.result["resize_scale"],
+                          mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
+                          gradient_string=self.result.get('gradient_string', ""),
+                          pattern_image_url=self.result['pattern_image_url']
+                          )
+        layer.insert(back_layer)
+
+    @staticmethod
+    def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
+        """
+        Align on the left-hand keypoint.
+        Args:
+            keypoint_type: string, "waistband" | "shoulder" | "ear_point"
+            scale: float
+            clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
+            body_point: dict, containing keypoint data of the body figure
+            offset: (y, x) pixel offset applied on top of the keypoint alignment
+            resize_scale: per-axis resize factors (not used in this calculation)
+
+        Returns:
+            start_point: (y', x') pixel tuple, computed as
+                body keypoint + offset - clothes keypoint * scale
+        """
+
+        side_indicator = f'{keypoint_type}_left'
+
+        # Clothes keypoints come from the Milvus keypoint cache as coordinate pairs.
+        start_point = (
+            int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale),  # y
+            int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale)  # x
+        )
+
+        return start_point
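A worked instance of the alignment formula (illustrative numbers): with the body's shoulder_left at (x, y) = (108, 107), offset (0, 0), a cached clothes keypoint of [40, 60] and scale 1.0:

    y' = 107 + 0 - 40 * 1.0 = 67
    x' = 108 + 0 - 60 * 1.0 = 48

so the garment layer is pasted with start_point == (67, 48).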
diff --git a/app/service/design/items/pipelines/__init__.py b/app/service/design/items/pipelines/__init__.py new file mode 100644 index 0000000..9abb09c --- /dev/null +++ b/app/service/design/items/pipelines/__init__.py @@ -0,0 +1,19 @@
+from .compose import Compose
+from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
+from .keypoints import KeypointDetection
+from .segmentation import Segmentation
+from .painting import Painting, PrintPainting
+from .scale import Scaling
+from .contour_detection import ContourDetection
+from .split import Split
+
+__all__ = [
+    'Compose',
+    'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
+    'KeypointDetection',
+    'Segmentation',
+    'Painting', 'PrintPainting',
+    'Scaling',
+    'ContourDetection',
+    'Split',
+]
diff --git a/app/service/design/items/pipelines/compose.py b/app/service/design/items/pipelines/compose.py new file mode 100644 index 0000000..daf6977 --- /dev/null +++ b/app/service/design/items/pipelines/compose.py @@ -0,0 +1,36 @@
+import collections
+
+from mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+    def __init__(self, transforms):
+        assert isinstance(transforms, collections.abc.Sequence)
+        self.transforms = []
+        for transform in transforms:
+            if isinstance(transform, dict):
+                transform = build_from_cfg(transform, PIPELINES)
+                self.transforms.append(transform)
+            elif callable(transform):
+                self.transforms.append(transform)
+            else:
+                raise TypeError('transform must be callable or a dict')
+
+    def __call__(self, data):
+        """Call function to apply transforms sequentially.
+
+        Args:
+            data (dict): A result dict contains the data to transform.
+
+        Returns:
+            dict: Transformed data.
+        """
+
+        for t in self.transforms:
+            data = t(data)
+            if data is None:
+                return None
+        return data
diff --git a/app/service/design/items/pipelines/contour_detection.py b/app/service/design/items/pipelines/contour_detection.py new file mode 100644 index 0000000..018dbca --- /dev/null +++ b/app/service/design/items/pipelines/contour_detection.py @@ -0,0 +1,58 @@
+import cv2
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class ContourDetection(object):
+    def __init__(self):
+        pass
+
+    # @ RunTime
+    def __call__(self, result):
+        # Shoes are a special case: keep the two largest contours (left and right shoe).
+        if result['name'] == 'shoes':
+            Contour = self.get_contours(result['image'])
+            Mask = np.zeros(result['image'].shape[:2], np.uint8)
+            for i in range(2):
+                Max_contour = Contour[i]
+                Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
+                Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
+                cv2.drawContours(Mask, [Approx], -1, 255, -1)
+            if result['pre_mask'] is None:
+                result['mask'] = Mask
+            else:
+                result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
+        else:
+            Contour = self.get_contours(result['image'])
+            Mask = np.zeros(result['image'].shape[:2], np.uint8)
+            if len(Contour):
+                Max_contour = Contour[0]
+                Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
+                Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
+                cv2.drawContours(Mask, [Approx], -1, 255, -1)
+            else:
+                Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
+            # TODO: fix images that come out partially transparent; planned for the next release
+            # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
+            # ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
+            # Mask = cv2.bitwise_not(Mask)
+            if result['pre_mask'] is None:
+                result['mask'] = Mask
+            else:
+                result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
+
+        return result
+
+    @staticmethod
+    def get_contours(image):
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        Edge = cv2.Canny(gray, 10, 150)
+        kernel = np.ones((5, 5), np.uint8)
+        Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
+        Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
+        Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
+        return Contour
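A minimal sketch of running stages through Compose (the config dicts follow the mmcv `build_from_cfg` convention from builder.py; the path is illustrative and requires reachable object storage):

    pipeline = Compose([
        dict(type='LoadImageFromFile', path='aida-sys-image/images/female/skirt/0916000602.jpg', color='139 148 156'),
        dict(type='ContourDetection'),
    ])
    result = pipeline(dict(name='skirt'))  # each stage reads and enriches the shared result dict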
diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py new file mode 100644 index 0000000..1f53ced --- /dev/null +++ b/app/service/design/items/pipelines/keypoints.py @@ -0,0 +1,139 @@
+import logging
+import time
+
+import numpy as np
+from pymilvus import MilvusClient
+
+from app.core.config import *
+from ..builder import PIPELINES
+from ...utils.design_ensemble import get_keypoint_result
+
+
+@PIPELINES.register_module()
+class KeypointDetection(object):
+    """
+    path here: abstract path
+    """
+
+    # @ RunTime
+    def __call__(self, result):
+        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:
+            # Check whether cached keypoints exist for this image and garment site;
+            # on a hit read them directly, otherwise run inference and update the cache.
+            site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
+            keypoint_cache = self.keypoint_cache(result, site)
+
+            if keypoint_cache is False:
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
+            else:
+                result['clothes_keypoint'] = keypoint_cache
+        return result
+
+    @staticmethod
+    def infer_keypoint_result(result):
+        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
+        keypoint_infer_result = get_keypoint_result(result["image"], site)  # model inference result
+        return keypoint_infer_result, site
+
+    @staticmethod
+    # @ RunTime
+    def save_keypoint_cache(keypoint_id, cache, site):
+        # Pad to a fixed 24-value vector: 'up' garments fill the first 20 slots and
+        # 'down' garments the last 4, so both halves share one cache entry layout.
+        if site == "down":
+            zeros = np.zeros(20, dtype=int)
+            result = np.concatenate([zeros, cache.flatten()])
+        else:
+            zeros = np.zeros(4, dtype=int)
+            result = np.concatenate([cache.flatten(), zeros])
+        data = [
+            {"keypoint_id": keypoint_id,
+             "keypoint_site": site,
+             "keypoint_vector": result.tolist()
+             }
+        ]
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
+            client.close()
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+        except Exception as e:
+            logging.info(f"save keypoint cache milvus error : {e}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+
+    @staticmethod
+    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
+        if site == "up":
+            # 'up' was requested, so the inferred half is 'up' and the cached half is 'down'.
+            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
+        else:
+            # 'down' was requested, so the inferred half is 'down' and the cached half is 'up'.
+            result = np.concatenate([search_result[:20], infer_result.flatten()])
+        data = [
+            {"keypoint_id": keypoint_id,
+             "keypoint_site": "all",
+             "keypoint_vector": result.tolist()
+             }
+        ]
+
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            client.upsert(
+                collection_name=MILVUS_TABLE_KEYPOINT,
+                data=data
+            )
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+        except Exception as e:
+            logging.info(f"save keypoint cache milvus error : {e}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+
+    # @ RunTime
+    def keypoint_cache(self, result, site):
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            keypoint_id = result['image_id']
+            res = client.query(
+                collection_name=MILVUS_TABLE_KEYPOINT,
+                filter=f"keypoint_id == {keypoint_id}",
+                output_fields=['keypoint_vector', 'keypoint_site']
+            )
+            if len(res) == 0:
+                # No cached entry: run inference, save the result, and return it.
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
+            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
+                # The cached site matches the requested one (or covers both): return it directly.
+                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
+            elif res[0]["keypoint_site"] != site:
+                # The cached site differs from the requested one: infer the missing half
+                # and upgrade the cached entry to site "all".
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
+        except Exception as e:
+            logging.info(f"search keypoint cache milvus error {e}")
+            return False
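The cache contract in sketch form (a hypothetical round trip; the MILVUS_* names come from app.core.config and the id is illustrative):

    client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
    res = client.query(collection_name=MILVUS_TABLE_KEYPOINT,
                       filter="keypoint_id == 95159",
                       output_fields=['keypoint_vector', 'keypoint_site'])
    # empty result         -> infer, then upsert with site 'up' or 'down'
    # site matches / 'all' -> reshape the 24-value vector into 12 (x, y) pairs
    # site differs         -> infer the missing half, upsert the merged vector as 'all'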
1)[0], image_path.split("/", 1)[1]).data + # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + + image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2") + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + if image.shape[2] == 4: # 如果是四通道 mask + image_mask = image[:, :, 3] + image = image[:, :, :3] + return image, image_mask + + +@PIPELINES.register_module() +class LoadBodyImageFromFile(object): + def __init__(self, body_path): + self.body_path = body_path + # self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + # response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png") + + # @ RunTime + def __call__(self, result): + result["image_url"] = result['body_path'] = self.body_path + result["name"] = "mannequin" + # if not result['image_url'].lower().endswith(".png"): + # bucket = self.body_path.split("/", 1)[0] + # object_name = self.body_path.split("/", 1)[1] + # new_object_name = f'{object_name[:object_name.rfind(".")]}.png' + # image = self.minioClient.get_object(bucket, object_name) + # image = Image.open(io.BytesIO(image.data)) + # image = image.convert("RGBA") + # data = image.getdata() + # # + # new_data = [] + # for item in data: + # if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + # new_data.append((255, 255, 255, 0)) + # else: + # new_data.append(item) + # image.putdata(new_data) + # image_data = io.BytesIO() + # image.save(image_data, format='PNG') + # image_data.seek(0) + # image_bytes = image_data.read() + # image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # self.body_path = image_path + # result["image_url"] = result['body_path'] = self.body_path + # response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) + # put_image_time = time.time() + # result['body_image'] = Image.open(io.BytesIO(response.read())) + result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL") + # logging.info(f"Image.open time is : {time.time() - put_image_time}") + return result + + +@PIPELINES.register_module() +class ImageShow(object): + def __init__(self, key): + self.key = key + + # @ RunTime + def __call__(self, result): + import matplotlib.pyplot as plt + if isinstance(self.key, list): + for key in self.key: + plt.imshow(result[key]) + plt.title(key) + plt.show() + elif isinstance(self.key, str): + img = self._resize_img(result[self.key]) + cv2.imshow(self.key, img) + cv2.waitKey(0) + else: + raise TypeError(f'key should be string but got type {type(self.key)}.') + return result + + @staticmethod + def _resize_img(img): + shape = img.shape + if shape[0] > 400 or shape[1] > 400: + ratio = min(400 / shape[0], 400 / shape[1]) + img = cv2.resize(img, (int(ratio * shape[1]), int(ratio * shape[0]))) + return img diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py new file mode 100644 index 0000000..224e753 --- /dev/null +++ b/app/service/design/items/pipelines/painting.py @@ -0,0 +1,611 @@ +import random + +import cv2 +import numpy as np +from PIL import Image + +from app.service.utils.oss_client import oss_get_image +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Painting(object): + def 
__init__(self, painting_flag=True):
+        self.painting_flag = painting_flag
+
+    # @ RunTime
+    def __call__(self, result):
+        if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
+            dim_image_h, dim_image_w = result['image'].shape[0:2]
+            if "gradient" in result.keys() and result['gradient'] != "":
+                bucket_name = result['gradient'].split('/')[0]
+                object_name = result['gradient'][result['gradient'].find('/') + 1:]
+                pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
+            else:
+                pattern = self.get_pattern(result['color'])
+            resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
+            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+            get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
+            result['pattern_image'] = get_image_fir.astype(np.uint8)
+            result['final_image'] = result['pattern_image']
+            canvas = np.full_like(result['final_image'], 255)
+            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
+            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
+            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
+            result['single_image'] = cv2.add(tmp1, tmp2)
+            result['alpha'] = 100 / 255.0
+        else:
+            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+            get_image_fir = result['image'] * (closed_mo / 255)
+            result['pattern_image'] = get_image_fir.astype(np.uint8)
+            result['final_image'] = result['pattern_image']
+        return result
+
+    @staticmethod
+    def get_gradient(bucket_name, object_name):
+        image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
+        if image.shape[2] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
+        return image
+
+    @staticmethod
+    def crop_image(image, image_size_h, image_size_w):
+        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
+        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
+        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
+        return image
+
+    @staticmethod
+    def get_pattern(single_color):
+        if single_color is None:
+            # a missing color is a caller error, so raise a real exception
+            raise ValueError("single_color must be an 'R G B' string, got None")
+        R, G, B = single_color.split(' ')
+        pattern = np.zeros([1, 1, 3], np.uint8)
+        pattern[0, 0, 0] = int(B)
+        pattern[0, 0, 1] = int(G)
+        pattern[0, 0, 2] = int(R)
+        return pattern
+
+
+@PIPELINES.register_module()
+class PrintPainting(object):
+    def __init__(self, print_flag=True):
+        self.print_flag = print_flag
+
+    # @ RunTime
+    def __call__(self, result):
+        single_print = result['print']['single']
+        overall_print = result['print']['overall']
+        element_print = result['print']['element']
+        result['single_image'] = None
+        result['print_image'] = None
+        if overall_print['print_path_list']:
+            painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
+            result['print_image'] = result['pattern_image']
+            if 
"print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True) + painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True) + painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True) + + # resize 到sketch大小 + painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + else: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False) + result['print_image'] = self.printpaint(result, painting_dict, print_=True) + result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image'] + + if single_print['print_path_list']: + print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + for i in range(len(single_print['print_path_list'])): + image, image_mode = self.read_image(single_print['print_path_list'][i]) + if image_mode == "RGBA": + new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i])) + + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) + + rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i]) + + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + + source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask) + + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY) + else: + mask = self.get_mask_inv(image) + mask = np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) + x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1]) + + 
image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] + + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, :x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] + + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 + + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 + else: + start_x = x + + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y + + # ------------------ + # 如果print-size大于image-size 则需要裁剪print + + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + rotate_mask = rotate_mask[:, :image_x - x] + + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + + # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) + # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + + print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) + img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) + img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask)) + mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + result['final_image'] = cv2.add(img_bg, img_fg) + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) + + if element_print['element_path_list']: + print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) + mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) + for i in range(len(element_print['element_path_list'])): + image, image_mode = self.read_image(element_print['element_path_list'][i]) + if image_mode == "RGBA": + new_size = 
(int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i])) + + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) + + rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i]) + + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + + source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask) + + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + print(1) + else: + mask = self.get_mask_inv(image) + mask = np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) + x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1]) + + image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] + + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, :x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] + + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 + + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 + else: + start_x = x + + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y + + # ------------------ + # 如果print-size大于image-size 则需要裁剪print + + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + rotate_mask = rotate_mask[:, :image_x - x] + + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + 
rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + + # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) + # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + + print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) + img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) + # TODO element 丢失信息 + three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)]) + img_bg = cv2.bitwise_and(result['final_image'], three_channel_image) + # mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + # gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + # img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + result['final_image'] = cv2.add(img_bg, img_fg) + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) + return result + + @staticmethod + def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x): + temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8) + + temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + + img2gray = cv2.cvtColor(print_background, cv2.COLOR_BGR2GRAY) + + ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY) + + mask_inv = cv2.bitwise_not(mask_) + + img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_) + + img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_inv) + + print_background = img1_bg + img2_fg + + return print_background + + def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False): + if print_trigger: + print_ = self.get_print(print_dict) + painting_dict['Trigger'] = not is_single + painting_dict['location'] = print_['location'] + single_mask_inv_print = self.get_mask_inv(print_['image']) + dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w']) + dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) + if not is_single: + self.random_seed = random.randint(0, 1000) + # 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪 + if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) + else: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) + painting_dict['tile_print'] = self.tile_image(print_['image'], 
dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) + else: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location']) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location']) + painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern + return painting_dict + + def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False): + tile = None + if not trigger: + tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA) + else: + resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA) + if len(pattern.shape) == 2: + tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4)) + if len(pattern.shape) == 3: + tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1)) + tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape) + return tile + + def get_mask_inv(self, print_): + if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255: + bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0] + print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB) + bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2] + bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True) + bg_a_high, bg_a_low = self.get_low_high_lab(bg_a) + bg_b_high, bg_b_low = self.get_low_high_lab(bg_b) + lower = np.array([bg_L_low, bg_a_low, bg_b_low]) + upper = np.array([bg_L_high, bg_a_high, bg_b_high]) + mask_inv = cv2.inRange(print_tile, lower, upper) + return mask_inv + else: + # bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0] + # print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB) + # bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2] + # bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True) + # bg_a_high, bg_a_low = self.get_low_high_lab(bg_a) + # bg_b_high, bg_b_low = self.get_low_high_lab(bg_b) + # lower = np.array([bg_L_low, bg_a_low, bg_b_low]) + # upper = np.array([bg_L_high, bg_a_high, bg_b_high]) + + # print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB) + # mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY) + + # mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY) + mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8) + return mask_inv + + @staticmethod + def printpaint(result, painting_dict, print_=False): + + if print_ and painting_dict['Trigger']: + print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print'])) + img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask) + else: + print_mask = result['mask'] + img_fg = result['final_image'] + if print_ and not painting_dict['Trigger']: + index_ = None + try: + index_ = len(painting_dict['location']) + except: + assert f'there must be parameter of location if choose IfSingle' + + for i in range(index_): + start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0]) + + length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0]) + length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1]) + + change_region = img_fg[start_h: length_h, start_w: length_w, :] + # problem in change_mask + change_mask = print_mask[start_h: 
length_h, start_w: length_w] + # get real part into change mask + _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY) + mask = cv2.bitwise_not(painting_dict['mask_inv_print']) + img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region + + clothes_mask_print = cv2.bitwise_not(print_mask) + + img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print) + mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + print_image = cv2.add(img_bg, img_fg) + return print_image + + @staticmethod + def get_print(print_dict): + if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3: + print_dict['scale'] = 0.3 + else: + print_dict['scale'] = print_dict['print_scale_list'][0] + + bucket_name = print_dict['print_path_list'][0].split("/", 1)[0] + object_name = print_dict['print_path_list'][0].split("/", 1)[1] + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL") + # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 + if image.mode == "RGBA": + new_background = Image.new('RGB', image.size, (255, 255, 255)) + new_background.paste(image, mask=image.split()[3]) + image = new_background + print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + return print_dict + + def crop_image(self, image, image_size_h, image_size_w, location, print_shape): + print_w = print_shape[1] + print_h = print_shape[0] + + random.seed(self.random_seed) + # logging.info(f'overall print location : {location}') + # x_offset = random.randint(0, image.shape[0] - image_size_h) + # y_offset = random.randint(0, image.shape[1] - image_size_w) + + # 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量 + x_offset = print_w - int(location[0][1] % print_w) + y_offset = print_w - int(location[0][0] % print_h) + + # y_offset = int(location[0][0]) + # x_offset = int(location[0][1]) + + if len(image.shape) == 2: + image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w] + elif len(image.shape) == 3: + image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :] + return image + + @staticmethod + def get_low_high_lab(Lab_value, L=False): + if L: + high = Lab_value + 30 if Lab_value + 30 < 255 else 255 + low = Lab_value - 30 if Lab_value - 30 > 0 else 0 + else: + high = Lab_value + 30 if Lab_value + 30 < 255 else 255 + low = Lab_value - 30 if Lab_value - 30 > 0 else 0 + return high, low + + @staticmethod + def img_rotate(image, angel, scale): + """顺时针旋转图像任意角度 + + Args: + image (np.array): [原始图像] + angel (float): [逆时针旋转的角度] + + Returns: + [array]: [旋转后的图像] + """ + + h, w = image.shape[:2] + center = (w // 2, h // 2) + # if type(angel) is not int: + # angel = 0 + M = cv2.getRotationMatrix2D(center, -angel, scale) + # 调整旋转后的图像长宽 + rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0])))) + rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0])))) + M[0, 2] += (rotated_w - w) // 2 + M[1, 2] += (rotated_h - h) // 2 + # 旋转图像 + rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h)) + + return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2) + # return rotated_img, (0, 0) + + @staticmethod + def rotate_crop_image(img, angle, crop): + """ + angle: 旋转的角度 + crop: 是否需要进行裁剪,布尔向量 + """ + crop_image = 
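lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
+        # How the crop factor below works (sketch of the reasoning, not from the
+        # original comments): rotating without resizing the canvas pads the corners
+        # with black, so crop_mult shrinks the original (w, h) by a single factor
+        # until the axis-aligned rectangle (w_crop, h_crop) fits inside the rotated
+        # content. The angle is first reduced modulo 180 and mirrored, aiming at the
+        # [0, 90] range where the inscribed-rectangle formula is valid.
+        crop_image = 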
lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w] + w, h = img.shape[:2] + # 旋转角度的周期是360° + angle %= 360 + # 计算仿射变换矩阵 + M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + # 得到旋转后的图像 + img_rotated = cv2.warpAffine(img, M_rotation, (w, h)) + + # 如果需要去除黑边 + if crop: + # 裁剪角度的等效周期是180° + angle_crop = angle % 180 + if angle > 90: + angle_crop = 180 - angle_crop + # 转化角度为弧度 + theta = angle_crop * np.pi / 180 + # 计算高宽比 + hw_ratio = float(h) / float(w) + # 计算裁剪边长系数的分子项 + tan_theta = np.tan(theta) + numerator = np.cos(theta) + np.sin(theta) * np.tan(theta) + + # 计算分母中和高宽比相关的项 + r = hw_ratio if h > w else 1 / hw_ratio + # 计算分母项 + denominator = r * tan_theta + 1 + # 最终的边长系数 + crop_mult = numerator / denominator + + # 得到裁剪区域 + w_crop = int(crop_mult * w) + h_crop = int(crop_mult * h) + x0 = int((w - w_crop) / 2) + y0 = int((h - h_crop) / 2) + + img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop) + + return img_rotated + + @staticmethod + def read_image(image_url): + image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2") + if image.shape[2] == 4: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image = Image.fromarray(image_rgb) + image_mode = "RGBA" + else: + image_mode = "RGB" + return image, image_mode + + @staticmethod + def resize_and_crop(img, target_width, target_height): + # 获取原始图像的尺寸 + original_height, original_width = img.shape[:2] + + # 计算目标尺寸的宽高比 + target_ratio = target_width / target_height + + # 计算原始图像的宽高比 + original_ratio = original_width / original_height + + # 调整尺寸 + if original_ratio > target_ratio: + # 原始图像更宽,按高度resize,然后裁剪宽度 + new_height = target_height + new_width = int(original_width * (target_height / original_height)) + resized_img = cv2.resize(img, (new_width, new_height)) + # 裁剪宽度 + start_x = (new_width - target_width) // 2 + cropped_img = resized_img[:, start_x:start_x + target_width] + else: + # 原始图像更高,按宽度resize,然后裁剪高度 + new_width = target_width + new_height = int(original_height * (target_width / original_width)) + resized_img = cv2.resize(img, (new_width, new_height)) + # 裁剪高度 + start_y = (new_height - target_height) // 2 + cropped_img = resized_img[start_y:start_y + target_height, :] + + return cropped_img diff --git a/app/service/design/items/pipelines/scale.py b/app/service/design/items/pipelines/scale.py new file mode 100644 index 0000000..d101530 --- /dev/null +++ b/app/service/design/items/pipelines/scale.py @@ -0,0 +1,56 @@ +import math + +import cv2 + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Scaling(object): + def __init__(self): + pass + + # @ RunTime + def __call__(self, result): + if result['keypoint'] in ['waistband', 'shoulder', 'head_point']: + # milvus_db_keypoint_cache + distance_clo = math.sqrt( + (int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2 + + + (int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2) + + distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1) + # distance_clo = math.sqrt( + # (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[0])) ** 2 + # + + # (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[1]) - 
int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[1])) ** 2) + # + # distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1) + if distance_clo == 0: + result['scale'] = 1 + else: + result['scale'] = distance_bdy / distance_clo + elif result['keypoint'] == 'toe': + distance_bdy = math.sqrt( + (int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2 + + + (int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2 + ) + + Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0) + Edge = cv2.Canny(Blur, 10, 200) + Edge = cv2.dilate(Edge, None) + Edge = cv2.erode(Edge, None) + Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + Contours = sorted(Contour, key=cv2.contourArea, reverse=True) + + Max_contour = Contours[0] + x, y, w, h = cv2.boundingRect(Max_contour) + width = w + distance_clo = width + result['scale'] = distance_bdy / distance_clo + elif result['keypoint'] == 'hand_point': + result['scale'] = result['scale_bag'] + elif result['keypoint'] == 'ear_point': + result['scale'] = result['scale_earrings'] + return result diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py new file mode 100644 index 0000000..d9f8ac0 --- /dev/null +++ b/app/service/design/items/pipelines/segmentation.py @@ -0,0 +1,14 @@ +from ..builder import PIPELINES +from ...utils.design_ensemble import get_seg_result + + +@PIPELINES.register_module() +class Segmentation(object): + def __init__(self, device='cpu', show=False, debug=None): + self.show = show + self.device = device + self.debug = debug + + def __call__(self, result): + result['seg_result'] = get_seg_result(result["image_id"], result['image']) + return result diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py new file mode 100644 index 0000000..1e06712 --- /dev/null +++ b/app/service/design/items/pipelines/split.py @@ -0,0 +1,81 @@ +import logging + +import cv2 +import numpy as np +from PIL import Image +from cv2 import cvtColor, COLOR_BGR2RGBA + +from app.service.utils.generate_uuid import generate_uuid +from ..builder import PIPELINES +from ...utils.conversion_image import rgb_to_rgba +from ...utils.upload_image import upload_png_mask + + +@PIPELINES.register_module() +class Split(object): + """ + Split image into front and back layer according to the segmentation result + """ + + # KNet + def __call__(self, result): + try: + if 'mask' not in result.keys(): + raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.') + if 'seg_result' not in result.keys(): # 没过seg模型 + result['front_mask'] = result['mask'].copy() + result['back_mask'] = np.zeros_like(result['mask']) + else: + temp_front = result['seg_result'] == 1 + result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8)) + temp_back = result['seg_result'] == 2 + result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8)) + + if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): + if len(result['front_mask'].shape) > 2: + front_mask = result['front_mask'][0] + else: + front_mask = result['front_mask'] + + if len(result['back_mask'].shape) > 2: + back_mask = result['back_mask'][0] + else: + back_mask = 
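result['back_mask']
+
+                # Note (inferred from the assignments above): the segmentation model
+                # labels pixels 1 = front panel and 2 = back panel, and both masks are
+                # squeezed to single-channel (H, W) here so they can be resized and
+                # used as RGBA paste masks below.
+                back_mask = 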
result['back_mask'] + + rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) + result_front_image = np.zeros_like(rgba_image) + result_front_image[front_mask != 0] = rgba_image[front_mask != 0] + + result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) + front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"][1])) + result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS) + # result['front_mask_image'] = cv2.resize(front_mask, front_new_size) + # result['front_image'] = result_front_image_pil + front_mask = cv2.resize(front_mask, front_new_size) + result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask) + + if result["name"] in ('blouse', 'dress', 'outwear', 'tops'): + result_back_image = np.zeros_like(rgba_image) + result_back_image[back_mask != 0] = rgba_image[back_mask != 0] + + result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA)) + back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"][1])) + result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS) + # result['back_mask_image'] = cv2.resize(back_mask, back_new_size) + # result['back_image'] = result_back_image_pil + + back_mask = cv2.resize(back_mask, back_new_size) + result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask) + else: + result['back_image'] = None + result["back_image_url"] = None + result["back_mask_url"] = None + result['back_mask_image'] = None + + # 创建中间图层 + result_pattern_image_rgba = rgb_to_rgba((result['pattern_image'].shape[0], result['pattern_image'].shape[1]), result['pattern_image'], result['mask']) + result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA)) + _, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}') + return result + except Exception as e: + logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") diff --git a/app/service/design/items/shoes.py b/app/service/design/items/shoes.py new file mode 100644 index 0000000..aa20d3c --- /dev/null +++ b/app/service/design/items/shoes.py @@ -0,0 +1,121 @@ +import cv2 +import numpy as np +from PIL import Image + +from .builder import ITEMS +from .clothing import Clothing +from ..utils.conversion_image import rgb_to_rgba +from ..utils.upload_image import upload_png_mask +from ...utils.generate_uuid import generate_uuid + + +@ITEMS.register_module() +class Shoes(Clothing): + # TODO location of shoes has little mismatch + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']), + ] + kwargs.update(pipeline=pipeline) + super(Shoes, self).__init__(**kwargs) + + def organize(self, layer): + left_shoe_mask, right_shoe_mask = self.cut() + + left_layer = 
dict(name=f'{type(self).__name__.lower()}_left', + image=self.result['shoes_left'], + image_url=self.result['left_image_url'], + mask_url=self.result['left_mask_url'], + sacle=self.result['scale'], + clothes_keypoint=self.result['clothes_keypoint'], + position=self.calculate_start_point(self.result['keypoint'], + self.result['scale'], + self.result['clothes_keypoint'], + self.result['body_point'], + 'left')) + layer.insert(left_layer) + + right_layer = dict(name=f'{type(self).__name__.lower()}_right', + image=self.result['shoes_right'], + image_url=self.result['right_image_url'], + mask_url=self.result['right_mask_url'], + sacle=self.result['scale'], + clothes_keypoint=self.result['clothes_keypoint'], + position=self.calculate_start_point(self.result['keypoint'], + self.result['scale'], + self.result['clothes_keypoint'], + self.result['body_point'], + 'right')) + + layer.insert(right_layer) + + def cut(self): + """ + Cut shoes mask into two pieces + Returns: + """ + contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours = sorted(contour, key=cv2.contourArea, reverse=True) + + bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]] + (contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False)) + + epsilon_left = 0.001 * cv2.arcLength(contours[0], True) + + approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True) + mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8) + cv2.drawContours(mask_left, [approx_left], -1, 255, -1) + item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0) + + rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left) + result_image = np.zeros_like(rgba_image) + result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0] + result_left_image_pil = Image.fromarray(result_image, 'RGBA') + result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS) + self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}") + + epsilon_right = 0.001 * cv2.arcLength(contours[1], True) + approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True) + mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8) + cv2.drawContours(mask_right, [approx_right], -1, 255, -1) + item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0) + + rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right) + result_image = np.zeros_like(rgba_image) + result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0] + result_right_image_pil = Image.fromarray(result_image, 'RGBA') + result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS) + self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}") + + return item_mask_left, item_mask_right + + @staticmethod + def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location): + """ + left shoes align left + right shoes align right + Args: 
+            keypoint_type: string, "toe"
+            scale: float
+            clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
+            body_point: dict, containing keypoint data of body figure
+            location: string, indicates whether the start point belongs to the right or left shoe
+
+        Returns:
+            start_point: tuple (x', y')
+                x' = y_body - y1 * scale
+                y' = x_body - x1 * scale
+        """
+        if location not in ['left', 'right']:
+            raise KeyError(f'location value must be left or right but got {location}')
+        side_indicator = f'{keypoint_type}_{location}'
+        # clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
+        start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
+                       body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
+        return start_point
diff --git a/app/service/design/items/top.py b/app/service/design/items/top.py new file mode 100644 index 0000000..135328f --- /dev/null +++ b/app/service/design/items/top.py @@ -0,0 +1,46 @@
+from .builder import ITEMS
+from .clothing import Clothing
+
+
+@ITEMS.register_module()
+class Top(Clothing):
+    def __init__(self, pipeline, **kwargs):
+        if pipeline is None:
+            pipeline = [
+                dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
+                dict(type='KeypointDetection'),
+                dict(type='ContourDetection'),
+                dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
+                dict(type='Painting', painting_flag=True),
+                dict(type='PrintPainting', print_flag=True),
+                # dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
+                dict(type='Scaling'),
+                dict(type='Split'),
+            ]
+        kwargs.update(pipeline=pipeline)
+        super(Top, self).__init__(**kwargs)
+
+
+@ITEMS.register_module()
+class Blouse(Top):
+    def __init__(self, pipeline=None, **kwargs):
+        super(Blouse, self).__init__(pipeline, **kwargs)
+
+
+@ITEMS.register_module()
+class Outwear(Top):
+    def __init__(self, pipeline=None, **kwargs):
+        super(Outwear, self).__init__(pipeline, **kwargs)
+
+
+@ITEMS.register_module()
+class Dress(Top):
+    def __init__(self, pipeline=None, **kwargs):
+        super(Dress, self).__init__(pipeline, **kwargs)
+
+
+# Men's clothing
+@ITEMS.register_module()
+class Tops(Top):
+    def __init__(self, pipeline=None, **kwargs):
+        super(Tops, self).__init__(pipeline, **kwargs)
diff --git a/app/service/design/model_process_service.py b/app/service/design/model_process_service.py new file mode 100644 index 0000000..076e04d --- /dev/null +++ b/app/service/design/model_process_service.py @@ -0,0 +1,28 @@
+import io
+
+from app.service.utils.oss_client import oss_get_image, oss_upload_image
+
+
+def model_transpose(image_path):
+    # re-save a model image as PNG with a transparent near-white background
+    bucket = image_path.split("/", 1)[0]
+    object_name = image_path.split("/", 1)[1]
+    new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
+    image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
+    image = image.convert("RGBA")
+    data = image.getdata()
+    new_data = []
+    for item in data:
+        # 8-bit channels never reach 256, so the previous >= 256 test matched nothing;
+        # use the same near-white threshold (230) as the commented-out variant in loading.py
+        if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
+            new_data.append((255, 255, 255, 0))
+        else:
+            new_data.append(item)
+    image.putdata(new_data)
+
+    image_data = io.BytesIO()
+    image.save(image_data, format='PNG')
+    image_data.seek(0)
+    image_bytes = image_data.read()
+    oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
+    image_path = f"{bucket}/{new_object_name}"
+    return image_path
diff --git 
a/app/service/design/service.py b/app/service/design/service.py new file mode 100644 index 0000000..54cb45b --- /dev/null +++ b/app/service/design/service.py @@ -0,0 +1,134 @@
+import concurrent.futures
+
+from app.core.config import PRIORITY_DICT
+from app.service.design.core.layer import Layer
+from app.service.design.items import build_item
+from app.service.design.utils.redis_utils import Redis
+from app.service.design.utils.synthesis_item import synthesis, synthesis_single
+from app.service.utils.decorator import RunTime
+
+
+def process_item(item, layers):
+    item.process()
+    item.organize(layers)
+    # only the mannequin item reports a canvas size
+    if item.result['name'] == "mannequin":
+        return item.result['body_image'].size
+
+
+def update_progress(process_id, total):
+    r = Redis()
+    progress = r.read(key=process_id)
+    if progress and total != 1:
+        if int(progress) <= 100:
+            r.write(key=process_id, value=int(progress) + int(100 / total))
+        else:
+            r.write(key=process_id, value=100)
+        return progress
+    elif total == 1:
+        r.write(key=process_id, value=100)
+        return progress
+    else:
+        r.write(key=process_id, value=int(100 / total))
+        return progress
+
+
+def final_progress(process_id):
+    r = Redis()
+    progress = r.read(key=process_id)
+    r.write(key=process_id, value=100)
+    return progress
+
+
+@RunTime
+def generate(request_data):
+    return_response = {}
+    request_data = request_data.dict()
+    assert "process_id" in request_data.keys(), "process_id parameter is required"
+
+    objects = request_data['objects']
+    # insert_keypoint_cache(objects)
+    process_id = request_data['process_id']
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        # submit one processing task per object
+        futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj for obj, cfg in enumerate(objects)}
+        # collect the results as they complete
+        for future in concurrent.futures.as_completed(futures):
+            obj = futures[future]
+            result = future.result()
+            return_response[obj] = result
+    final_progress(process_id)
+    return return_response
+
+
+def process_object(cfg, process_id, total):
+    basic_info = cfg.get('basic')
+    items_response = {
+        'layers': []
+    }
+    if cfg.get('basic')['single_overall'] == 'overall':
+        basic_info['debug'] = False
+        items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
+        layers = Layer()
+        body_size = None
+        results = []
+        for item in items:
+            # items run sequentially here; keep every return value, since only the
+            # mannequin item yields a body size
+            results.append(process_item(item, layers))
+        for res in results:
+            if res is not None:
+                body_size = res
+        # custom layer ordering?
+        if basic_info.get('layer_order', False):
+            layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
+        else:
+            layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
+        # composite all layers
+        items_response['synthesis_url'] = synthesis(layers, body_size)
+
+        for lay in layers:
+            items_response['layers'].append({
+                'image_category': lay['name'],
+                'position': lay['position'],
+                'priority': lay.get("priority", None),
+                'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
+                'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
+                'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
+                'mask_url': lay['mask_url'],
+                'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
+                'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
+            })
+    elif cfg.get('basic')['single_overall'] == 'single':
+        assert cfg.get('basic')['switch_category'] 
in [x['type'] for x in cfg.get('items')], "switch_category must match the type of one of the items"
+        basic_info['debug'] = False
+        for item in cfg.get('items'):
+            if item['type'] == cfg.get('basic')['switch_category']:
+                item = build_item(item, default_args=cfg.get('basic'))
+                item.process()
+                items_response['layers'].append({
+                    'image_category': f"{item.result['name']}_front",
+                    'image_size': item.result['front_image'].size if item.result['front_image'] else None,
+                    'position': None,
+                    'priority': 0,
+                    'image_url': item.result['front_image_url'],
+                    'mask_url': item.result['front_mask_url'],
+                    "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
+                })
+                items_response['layers'].append({
+                    'image_category': f"{item.result['name']}_back",
+                    'image_size': item.result['back_image'].size if item.result['back_image'] else None,
+                    'position': None,
+                    'priority': 0,
+                    'image_url': item.result['back_image_url'],
+                    'mask_url': item.result['back_mask_url'],
+                    "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
+                })
+                items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
+                break
+    update_progress(process_id, total)
+    return items_response
diff --git a/app/service/design/utils/__init__.py b/app/service/design/utils/__init__.py new file mode 100644 index 0000000..e69de29
diff --git a/app/service/design/utils/conversion_image.py b/app/service/design/utils/conversion_image.py new file mode 100644 index 0000000..77848cc --- /dev/null +++ b/app/service/design/utils/conversion_image.py @@ -0,0 +1,24 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :trinity_client
+@File    :conversion_image.py
+@Author  :周成融
+@Date    :2023/8/21 10:40:29
+@detail  : mask-driven RGB -> RGBA conversion helper
+"""
+import numpy as np
+
+
+def rgb_to_rgba(rgb_size, rgb_image, mask):
+    alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
+    # build the 4-channel result image
+    rgba_image = np.dstack((rgb_image, alpha_channel))
+    # transparent wherever the mask is empty
+    alpha_channel = np.where(mask > 0, 255, 0)
+    # update the alpha channel of the RGBA image
+    rgba_image[:, :, 3] = alpha_channel
+    return rgba_image
diff --git a/app/service/design/utils/design_ensemble.py b/app/service/design/utils/design_ensemble.py new file mode 100644 index 0000000..00d391f --- /dev/null +++ b/app/service/design/utils/design_ensemble.py @@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :trinity_client
+@File    :design_ensemble.py
+@Author  :周成融
+@Date    :2023/8/16 19:36:21
+@detail  : send inference requests and collect the results
+"""
+import logging
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tritonclient.http as httpclient
+
+from app.core.config import *
+
+"""
+    keypoint: preprocessing, inference, postprocessing
+"""
+
+
+def keypoint_preprocess(img_path):
+    img = mmcv.imread(img_path)
+    img_scale = (256, 256)
+    img, w_scale, h_scale = mmcv.imresize(img, img_scale, return_scale=True)
+    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+    return preprocessed_img, (w_scale, h_scale)
+
+
+# @ RunTime
+# inference
+def get_keypoint_result(image, site):
+    keypoint_result = None
+    try:
+        image, scale_factor = keypoint_preprocess(image)
+        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
+        transformed_img = image.astype(np.float32)
+        inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
+        
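# Decoding note (inferred from keypoint_postprocess below): the "output"
+        # tensor holds one low-resolution heatmap per keypoint; argmax over each
+        # flattened heatmap gives (row, col) = (idx / W, idx % W), which is then
+        # divided by the resize scale and multiplied by 4 to map the downsampled
+        # heatmap coordinates back onto the original image.
+        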
inputs[0].set_data_from_numpy(transformed_img, binary_data=True) + outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)] + results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs) + inference_output = torch.from_numpy(results.as_numpy(f'output')) + keypoint_result = keypoint_postprocess(inference_output, scale_factor) + except Exception as e: + logging.warning(f"get_keypoint_result : {e}") + return keypoint_result + + +def keypoint_postprocess(output, scale_factor): + max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2) + max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2) + segment_result = max_coords.numpy() + scale_factor = [1 / x for x in scale_factor[::-1]] + scale_matrix = np.diag(scale_factor) + nan = np.isinf(scale_matrix) + scale_matrix[nan] = 0 + return np.ceil(np.dot(segment_result, scale_matrix) * 4) + + +""" + seg + 预处理 推理 后处理 +""" + + +# KNet +def seg_preprocess(img_path): + img = mmcv.imread(img_path) + ori_shape = img.shape[:2] + img_scale_w, img_scale_h = ori_shape + if ori_shape[0] > 1024: + img_scale_w = 1024 + if ori_shape[1] > 1024: + img_scale_h = 1024 + scale_factor = [] + img, x, y = mmcv.imresize(img, (img_scale_w, img_scale_h), return_scale=True) + scale_factor.append(x) + scale_factor.append(y) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) + return preprocessed_img, ori_shape + + +# @ RunTime +def get_seg_result(image_id, image): + image, ori_shape = seg_preprocess(image) + client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}") + transformed_img = image.astype(np.float32) + # 输入集 + inputs = [ + httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(transformed_img, binary_data=True) + # 输出集 + outputs = [ + httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True), + ] + results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs) + # 推理 + # 取结果 + inference_output1 = results.as_numpy(SEGMENTATION['output']) + seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape) + return seg_result + + +# no cache +def seg_postprocess(image_id, output, ori_shape): + seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) + seg_pred = seg_logit.cpu().numpy() + return seg_pred[0] + + +def key_point_show(image_path, key_point_result=None): + img = cv2.imread(image_path) + points_list = key_point_result + point_size = 1 + point_color = (0, 0, 255) # BGR + thickness = 4 # 可以为 0 、4、8 + for point in points_list: + cv2.circle(img, point[::-1], point_size, point_color, thickness) + cv2.imshow("0", img) + cv2.waitKey(0) + + +if __name__ == '__main__': + image = cv2.imread("./14162b58-f259-4833-98cb-89b9b496b251.jfif") + a = get_keypoint_result(image, "up") + new_list = [] + print(list) + for i in a[0]: + new_list.append((int(i[0]), int(i[1]))) + key_point_show("./14162b58-f259-4833-98cb-89b9b496b251.jfif", new_list) + # a = get_seg_result(1, image) + print(a) diff --git a/app/service/design/utils/redis_utils.py b/app/service/design/utils/redis_utils.py new file mode 100644 index 0000000..012fbe0 --- /dev/null +++ b/app/service/design/utils/redis_utils.py @@ -0,0 +1,99 @@ +import 
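redis
+
+# Usage sketch (hypothetical polling client, based on how app.service.design.service
+# drives this class): update_progress() bumps the <process_id> key by int(100 / total)
+# per finished object, final_progress() pins it to 100, and keys expire after 100 s
+# by default, so a caller could poll roughly like:
+#
+#     r = Redis()
+#     while (p := r.read("some-process-id")) is None or int(p) < 100:
+#         time.sleep(0.5)
+
+import 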
redis + +from app.core.config import REDIS_HOST, REDIS_PORT + + +class Redis(object): + """ + redis数据库操作 + """ + + @staticmethod + def _get_r(): + host = REDIS_HOST + port = REDIS_PORT + db = 0 + r = redis.StrictRedis(host, port, db) + return r + + @classmethod + def write(cls, key, value, expire=None): + """ + 写入键值对 + """ + # 判断是否有过期时间,没有就设置默认值 + if expire: + expire_in_seconds = expire + else: + expire_in_seconds = 100 + r = cls._get_r() + r.set(key, value, ex=expire_in_seconds) + + @classmethod + def read(cls, key): + """ + 读取键值对内容 + """ + r = cls._get_r() + value = r.get(key) + return value.decode('utf-8') if value else value + + @classmethod + def hset(cls, name, key, value): + """ + 写入hash表 + """ + r = cls._get_r() + r.hset(name, key, value) + + @classmethod + def hget(cls, name, key): + """ + 读取指定hash表的键值 + """ + r = cls._get_r() + value = r.hget(name, key) + return value.decode('utf-8') if value else value + + @classmethod + def hgetall(cls, name): + """ + 获取指定hash表所有的值 + """ + r = cls._get_r() + return r.hgetall(name) + + @classmethod + def delete(cls, *names): + """ + 删除一个或者多个 + """ + r = cls._get_r() + r.delete(*names) + + @classmethod + def hdel(cls, name, key): + """ + 删除指定hash表的键值 + """ + r = cls._get_r() + r.hdel(name, key) + + @classmethod + def expire(cls, name, expire=None): + """ + 设置过期时间 + """ + if expire: + expire_in_seconds = expire + else: + expire_in_seconds = 100 + r = cls._get_r() + r.expire(name, expire_in_seconds) + + +if __name__ == '__main__': + redis_client = Redis() + # print(redis_client.write(key="1230", value=0)) + redis_client.write(key="1230", value=10) + # print(redis_client.read(key="1230")) diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py new file mode 100644 index 0000000..dc8e427 --- /dev/null +++ b/app/service/design/utils/synthesis_item.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :synthesis_item.py +@Author :周成融 +@Date :2023/8/26 14:13:04 +@detail : +""" +import io +import logging + +import cv2 +import numpy as np +from PIL import Image + +from app.service.utils.generate_uuid import generate_uuid +from app.service.utils.oss_client import oss_upload_image + + +def positioning(all_mask_shape, mask_shape, offset): + all_start = 0 + all_end = 0 + mask_start = 0 + mask_end = 0 + if offset == 0: + all_start = 0 + all_end = min(all_mask_shape, mask_shape) + + mask_start = 0 + mask_end = min(all_mask_shape, mask_shape) + elif offset > 0: + all_start = min(offset, all_mask_shape) + all_end = min(offset + mask_shape, all_mask_shape) + + mask_start = 0 + mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape) + elif offset < 0: + if abs(offset) > mask_shape: + all_start = 0 + all_end = 0 + else: + all_start = 0 + if mask_shape - abs(offset) > all_mask_shape: + all_end = min(mask_shape - abs(offset), all_mask_shape) + else: + all_end = mask_shape - abs(offset) + + if abs(offset) > mask_shape: + mask_start = mask_shape + mask_end = mask_shape + else: + mask_start = abs(offset) + if mask_shape - abs(offset) >= all_mask_shape: + mask_end = all_mask_shape + abs(offset) + else: + mask_end = mask_shape + return all_start, all_end, mask_start, mask_end + + +# @RunTime +def synthesis(data, size): + # 创建底图 + base_image = Image.new('RGBA', size, (0, 0, 0, 0)) + try: + + all_mask_shape = (size[1], size[0]) + top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8) + bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8) 
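+
+        # Worked example of positioning() above (hypothetical numbers), for a
+        # canvas dimension of 800 and a mask dimension of 300:
+        #   offset = 600 -> canvas [600:800] takes mask [0:200]   (clipped right/bottom)
+        #   offset = -50 -> canvas [0:250]  takes mask [50:300]   (clipped left/top)
+        # i.e. only the overlapping spans are copied in the loop below.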
diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py
new file mode 100644
index 0000000..dc8e427
--- /dev/null
+++ b/app/service/design/utils/synthesis_item.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project    :trinity_client
+@File       :synthesis_item.py
+@Author     :周成融
+@Date       :2023/8/26 14:13:04
+@detail     :
+"""
+import io
+import logging
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from app.service.utils.generate_uuid import generate_uuid
+from app.service.utils.oss_client import oss_upload_image
+
+
+def positioning(all_mask_shape, mask_shape, offset):
+    all_start = 0
+    all_end = 0
+    mask_start = 0
+    mask_end = 0
+    if offset == 0:
+        all_start = 0
+        all_end = min(all_mask_shape, mask_shape)
+
+        mask_start = 0
+        mask_end = min(all_mask_shape, mask_shape)
+    elif offset > 0:
+        all_start = min(offset, all_mask_shape)
+        all_end = min(offset + mask_shape, all_mask_shape)
+
+        mask_start = 0
+        mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
+    elif offset < 0:
+        if abs(offset) > mask_shape:
+            all_start = 0
+            all_end = 0
+        else:
+            all_start = 0
+            if mask_shape - abs(offset) > all_mask_shape:
+                all_end = min(mask_shape - abs(offset), all_mask_shape)
+            else:
+                all_end = mask_shape - abs(offset)
+
+        if abs(offset) > mask_shape:
+            mask_start = mask_shape
+            mask_end = mask_shape
+        else:
+            mask_start = abs(offset)
+            if mask_shape - abs(offset) >= all_mask_shape:
+                mask_end = all_mask_shape + abs(offset)
+            else:
+                mask_end = mask_shape
+    return all_start, all_end, mask_start, mask_end
+
+
+# @RunTime
+def synthesis(data, size):
+    # Create the base canvas
+    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
+    try:
+
+        all_mask_shape = (size[1], size[0])
+        top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
+        bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
+
+        top = True
+        bottom = True
+        i = len(data)
+        while i:
+            i -= 1
+            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
+                top = False
+                mask_shape = data[i]['mask'].shape
+                y_offset, x_offset = data[i]['position']
+                # Initialise the start/end bounds of the overlay region
+                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
+                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
+                # Copy the corresponding pixels into the overlay region
+                top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]:
+                bottom = False
+                mask_shape = data[i]['mask'].shape
+                y_offset, x_offset = data[i]['position']
+                # Initialise the start/end bounds of the overlay region
+                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
+                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
+                # Copy the corresponding pixels into the overlay region
+                bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+            elif bottom is False and top is False:
+                break
+
+        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
+
+        for layer in data:
+            if layer['image'] is not None:
+                if layer['name'] != "body":
+                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
+                    test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
+                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
+                    mask_alpha = Image.fromarray(mask_data)
+                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
+                    base_image.paste(cropped_image, (0, 0), cropped_image)
+                else:
+                    base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
+
+        result_image = base_image
+
+        with io.BytesIO() as output:
+            result_image.save(output, format='PNG')
+            data = output.getvalue()
+
+        image_data = io.BytesIO()
+        result_image.save(image_data, format='PNG')
+        image_data.seek(0)
+
+        # oss upload
+        image_bytes = image_data.read()
+        bucket_name = 'aida-results'
+        object_name = f'result_{generate_uuid()}.png'
+        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+        return f"{bucket_name}/{object_name}"
+        # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
+
+        # object_name = f'result_{generate_uuid()}.png'
+        # response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
+        # object_url = f"aida-results/{object_name}"
+        # if response['ResponseMetadata']['HTTPStatusCode'] == 200:
+        #     return object_url
+        # else:
+        #     return ""
+
+    except Exception as e:
+        logging.warning(f"synthesis runtime exception : {e}")
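The `positioning` helper above computes matching slice windows so that `canvas[all_start:all_end] = mask[mask_start:mask_end]` always assigns equal-length ranges along one axis. A few worked cases, verifiable by hand:

```python
from app.service.design.utils.synthesis_item import positioning

# canvas length 10, mask length 4
assert positioning(all_mask_shape=10, mask_shape=4, offset=3) == (3, 7, 0, 4)    # fully inside
assert positioning(all_mask_shape=10, mask_shape=4, offset=8) == (8, 10, 0, 2)   # clipped at the far edge
assert positioning(all_mask_shape=10, mask_shape=4, offset=-2) == (0, 2, 2, 4)   # clipped at the near edge
```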
+def synthesis_single(front_image, back_image):
+    # Use whichever image is present, so a missing front image does not crash
+    result_image = front_image if front_image else back_image
+    if front_image and back_image:
+        result_image.paste(back_image, (0, 0), back_image)
+
+    # with io.BytesIO() as output:
+    #     result_image.save(output, format='PNG')
+    #     data = output.getvalue()
+    #     object_name = f'result_{generate_uuid()}.png'
+    #     response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
+    #     object_url = f"aida-results/{object_name}"
+    #     if response['ResponseMetadata']['HTTPStatusCode'] == 200:
+    #         return object_url
+    #     else:
+    #         return ""
+    image_data = io.BytesIO()
+    result_image.save(image_data, format='PNG')
+    image_data.seek(0)
+    image_bytes = image_data.read()
+    # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
+    # oss upload
+    bucket_name = 'aida-results'
+    object_name = f'result_{generate_uuid()}.png'
+    req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+    return f"{bucket_name}/{object_name}"
diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py
new file mode 100644
index 0000000..3571816
--- /dev/null
+++ b/app/service/design/utils/upload_image.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project    :trinity_client
+@File       :upload_image.py
+@Author     :周成融
+@Date       :2023/8/28 13:49:20
+@detail     :
+"""
+import io
+import logging
+
+import cv2
+
+from app.core.config import *
+from app.service.utils.oss_client import oss_upload_image
+
+
+# @RunTime
+def upload_png_mask(front_image, object_name, mask=None):
+    try:
+        mask_url = None
+        if mask is not None:
+            mask_inverted = cv2.bitwise_not(mask)
+            # Convert the 3-channel mask to 4 channels: white stays opaque, black becomes transparent
+            rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
+            rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
+            # image_bytes = io.BytesIO()
+            # image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
+            # image_bytes.seek(0)
+            # mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}"
+            # oss upload ####################
+            req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
+            mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
+
+        image_data = io.BytesIO()
+        front_image.save(image_data, format='PNG')
+        image_data.seek(0)
+        image_bytes = image_data.read()
+        # image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
+        req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
+        image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
+        return front_image, image_url, mask_url
+    except Exception as e:
+        logging.warning(f"upload_png_mask runtime exception : {e}")
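The alpha trick in `upload_png_mask` (invert, add an alpha channel, then zero out the black pixels) can be reproduced standalone. A sketch with a synthetic single-channel mask; the service itself passes a 3-channel mask through `COLOR_BGR2BGRA`, but the idea is the same:

```python
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
cv2.circle(mask, (32, 32), 16, 255, -1)            # white blob on black
inverted = cv2.bitwise_not(mask)                   # blob becomes black
rgba = cv2.cvtColor(inverted, cv2.COLOR_GRAY2BGRA)
rgba[inverted == 0] = (0, 0, 0, 0)                 # black pixels -> fully transparent
cv2.imwrite("mask_rgba.png", rgba)                 # PNG keeps the alpha channel
```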
diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py
new file mode 100644
index 0000000..f69c3ee
--- /dev/null
+++ b/app/service/design_pre_processing/service.py
@@ -0,0 +1,374 @@
+import logging
+import time
+
+import cv2
+import numpy as np
+import torch
+import tritonclient.grpc as grpcclient
+from urllib3.exceptions import ResponseError
+
+from app.core.config import *
+from app.schemas.pre_processing import DesignPreProcessingModel
+from app.service.design.utils.design_ensemble import get_keypoint_result
+from app.service.utils.oss_client import oss_get_image, oss_upload_image
+
+
+class DesignPreprocessing:
+    # def __init__(self):
+    #     self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+    # @ RunTime
+    def pipeline(self, image_list):
+        sketches_list = self.read_image(image_list)
+        logging.info("read image success")
+
+        bounding_box_sketches_list = self.bounding_box(sketches_list)
+        logging.info("bounding box image success")
+
+        super_resolution_list = self.super_resolution(bounding_box_sketches_list)
+        logging.info("super_resolution_list image success")
+
+        infer_sketches_list = self.infer_image(super_resolution_list)
+        logging.info("infer image success")
+
+        result = self.composing_image(infer_sketches_list)
+        logging.info("Replenish white edge image success")
+
+        for d in result:
+            if 'image_obj' in d:
+                del d['image_obj']
+            if 'obj' in d:
+                del d['obj']
+            if 'keypoint_result' in d:
+                del d['keypoint_result']
+        return result
+
+    def read_image(self, image_list):
+        for obj in image_list:
+            # file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
+            # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
+            image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2")
+            if len(image.shape) == 2:
+                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+            elif image.shape[2] == 4:  # a four-channel mask
+                image = image[:, :, :3]
+            obj["image_obj"] = image
+        return image_list
+
+    # @ RunTime
+    def bounding_box(self, image_list):
+        for item in image_list:
+            image = item['image_obj']
+            # Detect object outlines with Canny edge detection
+            edges = cv2.Canny(image, 50, 150)
+            # Find the contours
+            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            # Initialise the coordinates of the rectangle enclosing all bounding boxes
+            x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1
+            # Walk all bounding rectangles and grow the enclosing rectangle
+            for contour in contours:
+                x, y, w, h = cv2.boundingRect(contour)
+                x_min = min(x_min, x)
+                y_min = min(y_min, y)
+                x_max = max(x_max, x + w)
+                y_max = max(y_max, y + h)
+
+            if IF_DEBUG_SHOW:
+                image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
+                cv2.imshow("bounding_box image", image_with_big_rect)
+                cv2.waitKey(0)
+
+            # Crop the original image to the enclosing rectangle
+            if len(contours) > 0:
+                cropped_image = image[y_min:y_max, x_min:x_max]
+                item['obj'] = cropped_image  # the newly cropped image
+                # Direct overwrite disabled; a size check was added instead
+                # try:
+                #     # overwrite the object in minio
+                #     image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes()
+                #     self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
+                #     print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
+                # except ResponseError as err:
+                #     print(f"Error: {err}")
+            else:
+                item['obj'] = image
+        return image_list
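The `bounding_box` step reduces to a crop-to-content operation; the same logic as a standalone sketch:

```python
import cv2
import numpy as np

def crop_to_content(image: np.ndarray) -> np.ndarray:
    """Crop to the union of all contour bounding boxes, mirroring bounding_box() above."""
    edges = cv2.Canny(image, 50, 150)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return image
    boxes = [cv2.boundingRect(c) for c in contours]
    x0 = min(x for x, _, _, _ in boxes)
    y0 = min(y for _, y, _, _ in boxes)
    x1 = max(x + w for x, _, w, _ in boxes)
    y1 = max(y + h for _, y, _, h in boxes)
    return image[y0:y1, x0:x1]
```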
+    def super_resolution(self, image_list):
+        for item in image_list:
+            # Only when both sides are at most 512, because the model upscales 4x
+            if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
+                # Upscale whenever either side is at most 256
+                if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
+                    # super-resolution
+                    img = item['obj'].astype(np.float32) / 255.
+                    sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
+                    sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
+                    inputs = [
+                        grpcclient.InferInput("input", sample.shape, datatype="FP32")
+                    ]
+                    inputs[0].set_data_from_numpy(sample)
+                    triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
+                    result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
+                    result_image = result.as_numpy(f'output')[0]
+                    sr_output = torch.tensor(result_image)
+                    output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                    if output.ndim == 3:
+                        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
+                    output = (output * 255.0).round().astype(np.uint8)
+                    item['obj'] = output
+                    try:
+                        # overwrite the stored object
+                        image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
+                        # self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
+                        bucket_name = item['image_url'].split("/", 1)[0]
+                        object_name = item['image_url'].split("/", 1)[1]
+                        oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+                        print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
+                    except ResponseError as err:
+                        print(f"Error: {err}")
+        return image_list
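The tensor juggling around the Triton call above is the usual HWC-BGR float to NCHW-RGB round trip; isolated for clarity:

```python
import numpy as np

def to_nchw_rgb(img_hwc_bgr: np.ndarray) -> np.ndarray:
    """HxWx3 BGR float image -> 1x3xHxW RGB batch, as fed to the SR model."""
    chw = np.transpose(img_hwc_bgr[:, :, [2, 1, 0]], (2, 0, 1))
    return chw[np.newaxis].astype(np.float32)

def to_hwc_bgr(out_chw_rgb: np.ndarray) -> np.ndarray:
    """3xHxW RGB float in [0, 1] -> HxWx3 BGR uint8, mirroring the post-processing."""
    bgr = np.transpose(out_chw_rgb[[2, 1, 0], :, :], (1, 2, 0))
    return (bgr.clip(0, 1) * 255.0).round().astype(np.uint8)
```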
f"{bucket_name}/{object_name}" + else: + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + bucket_name = image['image_url'].split('/', 1)[0] + object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.') + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + image['show_image_url'] = f"{bucket_name}/{object_name}" + + # if image['site'] == 'down': + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # 
+        return image_list
+
+    @staticmethod
+    def select_seg_result(image_id, image_obj):
+        try:
+            # Return False if the cached shape does not match
+            result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
+            if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
+                return result
+            else:
+                return False
+        except FileNotFoundError as e:
+            logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
+            return False
+
+    @staticmethod
+    def search_seg_result(image_id, ori_shape):
+        try:
+            # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
+            # collection = Collection(MILVUS_TABLE_SEG)  # Get an existing collection.
+            # collection.load()
+            # start_time = time.time()
+            # res = collection.query(
+            #     expr=f"seg_id == {image_id}",
+            #     offset=0,
+            #     limit=10,
+            #     output_fields=["seg_cache"],
+            # )
+            # logging.info(f"search seg cache time : {time.time() - start_time}")
+
+            # if len(res):
+            #     vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
+            #     array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
+            #     array_2d_exact = array_2d_exact.squeeze().numpy()
+            #     return array_2d_exact
+            # else:
+            return False
+        except Exception as e:
+            logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
+            return False
+
+    def keypoint_cache(self, sketch):
+        try:
+            # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
+            # collection = Collection(MILVUS_TABLE_KEYPOINT)  # Get an existing collection.
+            # collection.load()
+            start_time = time.time()
+            # res = collection.query(
+            #     expr=f"keypoint_id == {sketch['image_id']}",
+            #     offset=0,
+            #     limit=1,
+            #     output_fields=["keypoint_cache", "keypoint_site"],
+            # )
+            res = []
+            logging.info(f"search keypoint time : {time.time() - start_time}")
+            if len(res) == 0:
+                # No cached result: run inference and save the result
+                keypoint_infer_result = self.infer_keypoint_result(sketch)
+                return self.save_keypoint_cache(sketch, keypoint_infer_result)
+            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']:
+                # The cached site matches the requested one (or is "all"): return the cached result
+                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
+            elif res[0]["keypoint_site"] != sketch['site']:
+                # The cached site differs from the requested one: upgrade the cache entry to "all"
+                keypoint_infer_result = self.infer_keypoint_result(sketch)
+                return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector'])
+        except Exception as e:
+            logging.info(f"search keypoint cache milvus error {e}")
+            return False
+
+    # @ RunTime
+    def infer_keypoint_result(self, sketch):
+        keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site'])  # inference result
+        return keypoint_infer_result
+
+    @staticmethod
+    # @ RunTime
+    def save_keypoint_cache(sketch, keypoint_infer_result):
+        if sketch['site'] == "down":
+            zeros = np.zeros(20, dtype=int)
+            result = np.concatenate([zeros, keypoint_infer_result.flatten()])
+        else:
+            zeros = np.zeros(4, dtype=int)
+            result = np.concatenate([keypoint_infer_result.flatten(), zeros])
+        data = [
+            [int(sketch['image_id'])],
+            [sketch['site']],
+            [result.tolist()]
+        ]
+        try:
+            # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
+            start_time = time.time()
+            # collection = Collection(MILVUS_TABLE_KEYPOINT)  # Get an existing collection.
+            # mr = collection.insert(data)
+            # logging.info(f"save keypoint time : {time.time() - start_time}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+        except Exception as e:
+            logging.info(f"save keypoint cache milvus error : {e}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+
+    @staticmethod
+    def update_keypoint_cache(sketch, infer_result, search_result):
+        if sketch['site'] == "up":
+            # "up" is requested, so the inferred points are the upper ones and the cached ones the lower ones
+            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
+        else:
+            # "down" is requested, so the inferred points are the lower ones and the cached ones the upper ones
+            result = np.concatenate([search_result[:20], infer_result.flatten()])
+        data = [
+            [int(sketch['image_id'])],
+            ["all"],
+            [result.tolist()]
+        ]
+        try:
+            # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
+            start_time = time.time()
+            # collection = Collection(MILVUS_TABLE_KEYPOINT)  # Get an existing collection.
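As an aside here: the cache packs every entry into a fixed 24-slot vector (12 points x 2 coordinates). Judging by the zero-padding sizes, upper-body results fill slots 0-19 and lower-body results slots 20-23; the field names come from `KEYPOINT_RESULT_TABLE_FIELD_SET` in the config. A sketch of the packing:

```python
import numpy as np

def pack_keypoints(site: str, infer_result: np.ndarray) -> np.ndarray:
    if site == "down":                     # 2 lower-body points -> last 4 slots
        return np.concatenate([np.zeros(20, dtype=int), infer_result.flatten()])
    return np.concatenate([infer_result.flatten(), np.zeros(4, dtype=int)])

up = pack_keypoints("up", np.arange(20).reshape(10, 2))     # 10 upper-body points
down = pack_keypoints("down", np.array([[1, 2], [3, 4]]))   # 2 lower-body points
assert up.shape == down.shape == (24,)                      # both reshape to (12, 2)
```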
+ # mr = collection.upsert(data) + # logging.info(f"save keypoint time : {time.time() - start_time}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + except Exception as e: + logging.info(f"save keypoint cache milvus error : {e}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + + +if __name__ == '__main__': + data = { + "sketches": [ + { + "image_category": "dress", + "image_id": "107903", + "image_url": "aida-sys-image/images/female/dress/0628000000.jpg" + } + ] + } + request_data = DesignPreProcessingModel(sketches=data["sketches"]) + server = DesignPreprocessing() + data = server.pipeline(image_list=request_data.sketches) + print(data) diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service_generate_image.py similarity index 86% rename from app/service/generate_image/service.py rename to app/service/generate_image/service_generate_image.py index 6f8d092..dac211c 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service_generate_image.py @@ -10,21 +10,19 @@ import json import logging import time -from io import BytesIO import cv2 import minio +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np -from minio import Minio from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.adjust_contrast import adjust_contrast -from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd +from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -36,22 +34,23 @@ class GenerateImage: self.channel = self.connection.channel() # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) if request_data.mode == "img2img": # cv2 读图片是BGR PIL读图片是RGB self.image = self.get_image(request_data.image_url) - self.prompt = request_data.prompt else: self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - self.prompt = request_data.prompt + self.prompt = request_data.prompt self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.mode = request_data.mode self.batch_size = 1 self.category = request_data.category + if self.category == "sketch": + self.prompt = f"{self.category},{self.prompt}" self.index = 0 self.gender = request_data.gender self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''} @@ -63,10 +62,13 @@ class GenerateImage: # Read 
data from response. # read image use cv2 try: - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_file = BytesIO(response.data) - image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + # image_file = BytesIO(response.data) + # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + + image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="cv2") image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: @@ -104,7 +106,7 @@ class GenerateImage: image_result = not_smudge_image if is_smudge: # 无污点 # image_result = adjust_contrast(image_result) - image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") # logger.info(f"upload image SUCCESS : {image_url}") self.generate_data['status'] = "SUCCESS" self.generate_data['message'] = "success" @@ -121,13 +123,6 @@ class GenerateImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: prompts = [self.prompt] * self.batch_size @@ -147,7 +142,7 @@ class GenerateImage: input_mode.set_data_from_numpy(mode_obj) inputs = [input_text, input_image, input_mode] - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 generate_data = None while time_out > 0: @@ -187,9 +182,10 @@ if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', - image_url="", + image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", mode='txt2img', - category="test" + category="test", + gender="male" ) server = GenerateImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py new file mode 100644 index 0000000..5ea6f83 --- /dev/null +++ b/app/service/generate_image/service_generate_product_image.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import cv2 +import numpy as np +import redis +import tritonclient.grpc as grpcclient +from PIL import Image, ImageOps +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.generate_image import GenerateProductImageModel +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class GenerateProductImage: + def __init__(self, request_data): + if DEBUG is False: + 
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.category = "product_image" + self.image_strength = request_data.image_strength + self.batch_size = 1 + self.product_type = request_data.product_type + self.prompt = request_data.prompt + self.image, self.image_size = pre_processing_image(request_data.image_url) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.redis_client.expire(self.tasks_id, 600) + + def callback(self, result, error): + if error: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + else: + # pil图像转成numpy数组 + if self.product_type == "single": + image = result.as_numpy("generated_cnet_image") + else: + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + self.gen_product_data['status'] = "SUCCESS" + self.gen_product_data['message'] = "success" + self.gen_product_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + prompts = [self.prompt] * self.batch_size + self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + self.image = cv2.resize(self.image, (512, 768)) + images = [self.image.astype(np.uint8)] * self.batch_size + + if self.product_type == "single": + text_obj = np.array(prompts, dtype="object").reshape(-1, 1) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) + else: + text_obj = np.array(prompts, dtype="object").reshape(1) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) + + # 假设 prompts、images 和 self.image_strength 已经定义 + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + inputs = [input_text, input_image, input_image_strength] + input_image_strength.set_data_from_numpy(image_strength_obj) + + if self.product_type == "single": + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, 
inputs=inputs, callback=self.callback)
+            else:
+                ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
+
+            time_out = 600
+            while time_out > 0:
+                gen_product_data, _ = self.read_tasks_status()
+                if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
+                    ctx.cancel()
+                    break
+                elif gen_product_data['status'] == "SUCCESS":
+                    break
+                time_out -= 1
+                time.sleep(0.1)
+            gen_product_data, _ = self.read_tasks_status()
+            return gen_product_data
+        except Exception as e:
+            self.gen_product_data['status'] = "FAILURE"
+            self.gen_product_data['message'] = str(e)
+            self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
+            raise Exception(str(e))
+        finally:
+            dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
+            if DEBUG is False:
+                self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
+                logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
+
+
+def infer_cancel(tasks_id):
+    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
+    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
+    gen_product_data = json.dumps(data)
+    redis_client.set(tasks_id, gen_product_data)
+    return data
+
+
+def pre_processing_image(image_url):
+    image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
+    # original image size
+    width, height = image.size
+
+    # compute the new size for a 2:3 (w:h) aspect ratio
+    # e.g. a 900x600 input keeps width 900, grows to height 1350 and is pasted at (0, 375)
+    desired_ratio = 2 / 3
+    current_ratio = width / height
+
+    if current_ratio > desired_ratio:
+        # the original image is too wide: pad top and bottom
+        new_width = width
+        new_height = int(width / desired_ratio)
+    else:
+        # the original image is too tall, or already 2:3
+        new_height = height
+        new_width = int(height * desired_ratio)
+
+    # create a new transparent canvas with the padded size
+    pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
+
+    # paste the original image into the centre of the canvas
+    left = (new_width - width) // 2
+    top = (new_height - height) // 2
+    pad_image.paste(image, (left, top))
+
+    # resize the canvas to 500 x 750
+    resized_image = pad_image.resize((500, 750))
+    image_size = (512, 768)
+
+    if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
+        # create a white background
+        background = Image.new("RGB", image_size, (255, 255, 255))
+        # paste the image onto the white background
+        background.paste(resized_image, mask=resized_image.split()[3])
+        image = np.array(background)
+    else:
+        # pad_image is always created as RGBA, so this branch is defensive only
+        image = np.array(resized_image.convert("RGB"))
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    return image, image_size
+
+
+if __name__ == '__main__':
+    rd = GenerateProductImageModel(
+        tasks_id="123-89",
+        # prompt="",
+        image_strength=0.9,
+        prompt=" the best quality, masterpiece.
detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + product_type="overall" + ) + server = GenerateProductImage(rd) + print(server.get_result()) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py new file mode 100644 index 0000000..e0729ba --- /dev/null +++ b/app/service/generate_image/service_generate_relight_image.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import cv2 +import numpy as np +import redis +import tritonclient.grpc as grpcclient +from PIL import Image +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.generate_image import GenerateRelightImageModel +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class GenerateRelightImage: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.category = "relight_image" + self.batch_size = 1 + self.prompt = request_data.prompt + self.seed = "1" + self.product_type = request_data.product_type + self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' + self.direction = request_data.direction + self.image_url = request_data.image_url + self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.redis_client.expire(self.tasks_id, 600) + + def callback(self, result, error): + if error: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + else: + # pil图像转成numpy数组 + if self.product_type == 'single': + image = result.as_numpy("generated_relight_image") + else: + image = result.as_numpy("generated_inpaint_image") + + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + self.gen_product_data['status'] = "SUCCESS" + self.gen_product_data['message'] = "success" + self.gen_product_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), 
status_data
+
+    def get_result(self):
+        try:
+            prompts = [self.prompt] * self.batch_size
+            image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
+            image = cv2.resize(image, (512, 768))
+            images = [image.astype(np.uint8)] * self.batch_size
+            seeds = [self.seed] * self.batch_size
+            negative_prompts = [self.negative_prompt] * self.batch_size
+            directions = [self.direction] * self.batch_size
+
+            if self.product_type == 'single':
+                text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
+                image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
+                na_text_obj = np.array(negative_prompts, dtype="object").reshape((-1, 1))
+                seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
+                direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
+            else:
+                text_obj = np.array(prompts, dtype="object").reshape((1))
+                image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
+                na_text_obj = np.array(negative_prompts, dtype="object").reshape((1))
+                seed_obj = np.array(seeds, dtype="object").reshape((1))
+                direction_obj = np.array(directions, dtype="object").reshape((1))
+
+            input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
+            input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
+            input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
+            input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
+            input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
+
+            input_text.set_data_from_numpy(text_obj)
+            input_image.set_data_from_numpy(image_obj)
+            input_natext.set_data_from_numpy(na_text_obj)
+            input_seed.set_data_from_numpy(seed_obj)
+            input_direction.set_data_from_numpy(direction_obj)
+
+            inputs = [input_text, input_natext, input_image, input_seed, input_direction]
+            if self.product_type == 'single':
+                ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
+            else:
+                ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
+
+            time_out = 600
+            while time_out > 0:
+                gen_product_data, _ = self.read_tasks_status()
+                if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
+                    ctx.cancel()
+                    break
+                elif gen_product_data['status'] == "SUCCESS":
+                    break
+                time_out -= 1
+                time.sleep(0.1)
+            gen_product_data, _ = self.read_tasks_status()
+            return gen_product_data
+        except Exception as e:
+            self.gen_product_data['status'] = "FAILURE"
+            self.gen_product_data['message'] = str(e)
+            self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
+            raise Exception(str(e))
+        finally:
+            dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
+            if DEBUG is False:
+                self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
+                logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
+
+
+def infer_cancel(tasks_id):
+    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
+    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
+    gen_product_data = json.dumps(data)
+    redis_client.set(tasks_id, gen_product_data)
+    return data
+
+
+if __name__ == '__main__':
+    rd = GenerateRelightImageModel(
+        tasks_id="123-89",
+        # prompt="beautiful woman, detailed face, sunshine, outdoor,
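All of these services share the same pattern: fire `async_infer`, then poll the Redis task record until it settles. Note that 600 iterations with a 0.1 s sleep bound the wait to roughly 60 s, not 600 s. The pattern in isolation (the callables here are illustrative):

```python
import time

def wait_for_task(read_status, cancel, iterations=600, poll_s=0.1):
    """Poll a task record until SUCCESS/FAILURE/REVOKED, cancelling on abort."""
    for _ in range(iterations):
        status = read_status()
        if status["status"] in ("REVOKED", "FAILURE"):
            cancel()                 # stop the in-flight Triton request
            break
        if status["status"] == "SUCCESS":
            break
        time.sleep(poll_s)
    return read_status()
```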
warm atmosphere", + prompt="Colorful black", + image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png', + direction="Right Light", + product_type="single" + ) + server = GenerateRelightImage(rd) + print(server.get_result()) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py new file mode 100644 index 0000000..e3def3e --- /dev/null +++ b/app/service/generate_image/service_generate_single_logo.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import cv2 +import numpy as np +import redis +from PIL import Image +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +import tritonclient.grpc as grpcclient +from app.schemas.generate_image import GenerateSingleLogoImageModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image + +logger = logging.getLogger() + + +class GenerateSingleLogoImage: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.batch_size = 1 + self.category = "single_logo" + self.negative_prompts = "bad, ugly" + self.seed = request_data.seed + self.tasks_id = request_data.tasks_id + self.prompt = request_data.prompt + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + self.redis_client.expire(self.tasks_id, 600) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def callback(self, result, error): + if error: + self.gen_single_logo_data['status'] = "FAILURE" + self.gen_single_logo_data['message'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + else: + image = result.as_numpy("generated_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + self.gen_single_logo_data['status'] = "SUCCESS" + self.gen_single_logo_data['message'] = "success" + self.gen_single_logo_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + + def get_result(self): + try: + # prompt + prompts = [self.prompt] * self.batch_size + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + + text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1)) + input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)) + 
input_text_neg.set_data_from_numpy(text_obj_neg) + + seed = np.array(self.seed, dtype="object").reshape((-1, 1)) + input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype)) + input_seed.set_data_from_numpy(seed) + inputs = [input_text, input_text_neg, input_seed] + ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback) + time_out = 600 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + return generate_data + except Exception as e: + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateSingleLogoImageModel( + tasks_id="123-89", + prompt='an apple', + seed="2", + ) + server = GenerateSingleLogoImage(rd) + print(server.get_result()) diff --git a/app/service/generate_image/utils/image_processing.py b/app/service/generate_image/utils/image_processing.py index f6d87f2..dd64ace 100644 --- a/app/service/generate_image/utils/image_processing.py +++ b/app/service/generate_image/utils/image_processing.py @@ -381,7 +381,7 @@ if __name__ == '__main__': remove_bg_img = remove_background(luminance) # cv2.imwrite("remove_bg_img.png", remove_bg_img) - print(1) + # print(1) cv2.imshow("source", img) cv2.imshow("levels", equAuto) cv2.imshow("luminance", luminance) diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index 0e8e542..a63488c 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -10,26 +10,61 @@ import io import logging +# import boto3 import cv2 from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.oss_client import oss_upload_image minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) -def upload_png_sd(image, user_id, category, object_name): +# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + + +# def upload_single_logo(image, user_id, category, object_name): +# with io.BytesIO() as output: +# image.save(output, format='PNG') +# data = output.getvalue() +# # 创建一个 S3 客户端 +# try: +# key = f'{user_id}/{category}/{object_name}' +# image_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=GSL_MINIO_BUCKET, Key=key, Body=data, ContentType='image/png') +# return image_url +# except Exception as e: +# print(f'上传到 S3 失败: {e}') + +def upload_SDXL_image(image, user_id, category, file_name): + try: + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + + # minio_req = minio_client.put_object( + # GI_MINIO_BUCKET, + # 
f'{user_id}/{category}/{file_name}',
+        #     io.BytesIO(image_bytes),
+        #     len(image_bytes),
+        #     content_type='image/jpeg'
+        # )
+        object_name = f'{user_id}/{category}/{file_name}'
+        req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
+        image_url = f"aida-users/{object_name}"
+        return image_url
+    except Exception as e:
+        logging.warning(f"upload_SDXL_image runtime exception : {e}")
+
+
+def upload_png_sd(image, user_id, category, file_name):
     try:
         _, img_byte_array = cv2.imencode('.jpg', image)
-        minio_req = minio_client.put_object(
-            GI_MINIO_BUCKET,
-            f'{user_id}/{category}/{object_name}',
-            io.BytesIO(img_byte_array),
-            len(img_byte_array),
-            content_type='image/jpeg'
-        )
-        image_url = f"aida-users/{minio_req.object_name}"
+        object_name = f'{user_id}/{category}/{file_name}'
+        req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array)
+        image_url = f"aida-users/{object_name}"
         return image_url
     except Exception as e:
-        logging.warning(f"upload_png_mask runtime exception : {e}")
+        logging.warning(f"upload_png_sd runtime exception : {e}")
diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py
new file mode 100644
index 0000000..4ade635
--- /dev/null
+++ b/app/service/prompt_generation/chatgpt_for_translation.py
@@ -0,0 +1,77 @@
+import logging
+
+from langchain.chains import LLMChain
+from langchain.chat_models import ChatOpenAI
+from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
+
+from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
+
+# os.environ["http_proxy"] = "http://127.0.0.1:7890"
+# os.environ["https_proxy"] = "http://127.0.0.1:7890"
+
+
+llm = ChatOpenAI(model_name=OPENAI_MODEL,
+                 openai_api_key=OPENAI_API_KEY,
+                 temperature=0)
+
+
+def translate_to_en(text):
+    template = (
+        """You are a translation expert, proficient in various languages.
+        And can translate various languages into English.
+        Please translate to grammatically correct English regardless of the input language.
+        If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1",
+        output the input text exactly as it is without any modifications or additions.
+        If there are grammatical errors, correct them and then output the sentence."""
+    )
+    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
+
+    # The text to translate is supplied in the Human role
+    human_template = "User input : {text}"
+    human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template)
+
+    # Build the ChatPromptTemplate from the System and Human role templates
+    chat_prompt_template = ChatPromptTemplate.from_messages(
+        [system_message_prompt, human_message_prompt]
+    )
+    translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template)
+
+    result = translate_chain.invoke(text)
+
+    logging.info("translate result : " + result.get('text'))
+    # print("translate result : " + result.get('text'))
+    return result.get('text')
+
+    # template = (
+    #     """
+    #     Input sentence:
+    #     {translate}
+    #     1. Based on the input, adjust the input sentence to make it more suitable for prompts for generating images,
+    #     ensuring all key nouns or adjectives related to the image are retained.
+    #     2. Simplify complex sentence structures and clarify ambiguous expressions.
+    #     3. Only Output the adjusted English sentence.
+    #
+    #     Output :
+    #     """
+    # )
+    # # "Based on the input sentence, extract key adjectives and nouns. Only Output extracted key words."
+    # # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding.
+    #
+    # prompt_template = PromptTemplate(input_variables=["translate"], template=template)
+    # prompt_chain = LLMChain(llm=llm, prompt=prompt_template)
+    #
+    # from langchain.chains import SimpleSequentialChain
+    # overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True)
+    #
+    # response = overall_chain.run(text)
+    # return response
+
+
+def main():
+    """Main function"""
+    text = translate_to_en("fire")
+    print(text)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py
index e87f1a7..c2cf39d 100644
--- a/app/service/super_resolution/service.py
+++ b/app/service/super_resolution/service.py
@@ -1,17 +1,17 @@
-import io
+import json
 import logging
 import time
-import minio.error
-import redis
-import json
+
 import cv2
+import minio.error
 import numpy as np
+import redis
 import torch
 import tritonclient.grpc as grpcclient
-from minio import Minio
+
 from app.core.config import *
 from app.schemas.super_resolution import SuperResolutionModel
-from app.service.utils.decorator import RunTime
+from app.service.utils.oss_client import oss_get_image, oss_upload_image
 
 logger = logging.getLogger()
@@ -24,7 +24,7 @@ class SuperResolution:
         self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
         self.sr_image_url = data.sr_image_url
         self.sr_xn = data.sr_xn
-        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+        # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
         self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
         self.redis_client.expire(self.tasks_id, 600)
         self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
@@ -33,36 +33,37 @@ class SuperResolution:
     # @RunTime
     def read_image(self):
         try:
-            image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
+            img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2")
+            img = img.astype(np.float32) / 255.  # decode and normalise
         except minio.error.S3Error as e:
             sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
             self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
             logger.info(f" [x] Sent {sr_data}")
             raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
-        img = np.frombuffer(image_data.data, np.uint8)  # to an 8-bit unsigned integer array
-        img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255.  # decode
         return img
+        # try:
+        #     image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
+        # except minio.error.S3Error as e:
+        #     sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
+        #     self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
+        #     logger.info(f" [x] Sent {sr_data}")
+        #     raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
+        # img = np.frombuffer(image_data.data, np.uint8)  # to an 8-bit unsigned integer array
+        # img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255.  # decode
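Usage of `translate_to_en` is a single call (requires a valid `OPENAI_API_KEY` and makes a network request; `LLMChain.invoke` returns a dict whose `'text'` key holds the completion):

```python
print(translate_to_en("红色连衣裙"))   # -> "Red dress"
print(translate_to_en("cat"))          # English input passes through unchanged
```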
+        # return img
+
     def read_tasks_status(self):
         status_data = json.loads(self.redis_client.get(self.tasks_id))
         logging.info(f"{self.tasks_id} ===> {status_data}")
         return status_data
 
-    # @RunTime
-    def infer(self, inputs):
-        return self.triton_client.async_infer(
-            model_name=SR_MODEL_NAME,
-            inputs=inputs,
-            callback=self.callback
-        )
-
     # @RunTime
     def sr_result(self):
         sample = self.read_image()
         if self.sr_xn == 2:
             new_shape = (sample.shape[0] // self.sr_xn, sample.shape[1] // self.sr_xn)
             sample = cv2.resize(sample, new_shape)
-            print(new_shape)
         sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
         sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
         inputs = [
@@ -72,13 +73,16 @@ class SuperResolution:
             # , binary_data=True
         )
 
-        ctx = self.infer(inputs)
+        ctx = self.triton_client.async_infer(
+            model_name=SR_MODEL_NAME,
+            inputs=inputs,
+            callback=self.callback
+        )
         time_out = 60
         while time_out > 0:
             generate_data = self.read_tasks_status()
             if generate_data['status'] in ["REVOKED", "FAILURE"]:
                 ctx.cancel()
-                # noinspection PyTypeChecker
                 self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data))
                 logger.info(f" [x] Sent {generate_data}")
                 break
@@ -88,28 +92,19 @@ class SuperResolution:
             time.sleep(1)
         return self.read_tasks_status()
 
-        # results = self.triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
-
-        # sr_output = torch.from_numpy(results.as_numpy(f"output"))
-        # output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        # if output.ndim == 3:
-        #     output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
-        # output = (output * 255.0).round().astype(np.uint8)
-        # output_url = self.upload_img_sr(output, generate_uuid())
-        # return output_url
-
     def upload_img_sr(self, image):
         try:
             image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
-            res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
-            image_url = f"aida-users/{res.object_name}"
+            # res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
+            object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
+            oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
+            image_url = f"{SR_MINIO_BUCKET}/{object_name}"
             return image_url
         except Exception as e:
-            logger.warning(f"upload_png_mask runtime exception : {e}")
+            logger.warning(f"upload_img_sr runtime exception : {e}")
 
     def callback(self, result, error):
         if error:
-            print(error)
             sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
             self.redis_client.set(self.tasks_id, sr_info_data)
         else:
@@ -135,6 +130,6 @@ def infer_cancel(tasks_id):
 
 if __name__ == '__main__':
-    request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
+    request_data = SuperResolutionModel(sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png", sr_xn=2, sr_tasks_id="12341556")
     service = SuperResolution(request_data)
     result_url = service.sr_result()
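The SR model behind `SR_MODEL_NAME` upscales 4x, so an `sr_xn == 2` request is served by halving the input first, giving a net factor of 2. One caveat worth noting: `cv2.resize` expects a `(width, height)` tuple, while `(sample.shape[0] // 2, sample.shape[1] // 2)` is `(height, width)`, so non-square inputs come out with transposed dimensions. The shape math itself:

```python
h, w = 1024, 768
sr_xn = 2
if sr_xn == 2:
    h, w = h // 2, w // 2                       # 512 x 384 fed to the 4x model
assert (h * 4, w * 4) == (1024 * 2, 768 * 2)    # net 2x upscale
```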
diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py
new file mode 100644
index 0000000..c2bb82c
--- /dev/null
+++ b/app/service/utils/oss_client.py
@@ -0,0 +1,74 @@
+import io
+import logging
+from io import BytesIO
+
+import boto3
+import cv2
+import numpy as np
+from PIL import Image
+from minio import Minio
+
+from app.core.config import *
+
+logger = logging.getLogger()
+
+
+# fetch an image from object storage
+def oss_get_image(bucket, object_name, data_type):
+    # cv2 reads all channels by default
+    image_object = None
+    try:
+        if OSS == "minio":
+            oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+            image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
+        else:
+            oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
+            image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body']
+        if data_type == "cv2":
+            image_bytes = image_data.read()
+            image_array = np.frombuffer(image_bytes, np.uint8)  # convert to 8-bit unsigned integers
+            image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
+        else:
+            data_bytes = BytesIO(image_data.read())
+            image_object = Image.open(data_bytes)
+    except Exception as e:
+        logger.warning(f"{OSS} | exception while fetching image ######: {e}")
+    return image_object
+
+
+def oss_upload_image(bucket, object_name, image_bytes):
+    req = None
+    try:
+        if OSS == "minio":
+            oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+            req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
+        else:
+            oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
+            req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=io.BytesIO(image_bytes), ContentType='image/png')
+    except Exception as e:
+        logger.warning(f"{OSS} | exception while uploading image ######: {e}")
+    return req
+
+
+if __name__ == '__main__':
+    # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
+    # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
+    # url = "aida-sys-image/images/female/outwear/0628000054.jpg"
+    # url = "aida-users/89/product_image/string-89.png"
+    # url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png"
+    # url = 'aida-users/89/relight_image/123-89.png'
+    # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
+    # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
+    # url = "aida-users/89/single_logo/123-89.png"
+    url = "aida-results/result_c6520ce7-33a1-11ef-a8d3-b0dcefbff887.png"
+    read_type = "PIL"
+    if read_type == "cv2":
+        img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
+        cv2.imshow("", img)
+        cv2.waitKey(0)
+    else:
+        img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
+        img.show()
diff --git a/logging_env.py b/logging_env.py
index d1ac9bc..08873b0 100644
--- a/logging_env.py
+++ b/logging_env.py
@@ -1,51 +1,51 @@
 from app.core.config import LOGS_PATH
 
 LOGGER_CONFIG_DICT = {
-    "version": 1,
-    "disable_existing_loggers": False,
-    "formatters": {
-        "simple": {"format": "%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s"}
+    'version': 1,
+    'disable_existing_loggers': False,
+    'formatters': {
+        'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'}
     },
-    "handlers": {
-        "console": {
-            "class": "logging.StreamHandler",
-            "level": "DEBUG",
-            "formatter": "simple",
-            "stream": "ext://sys.stdout",
+    'handlers': {
+        'console': {
+            'class': 'logging.StreamHandler',
+            'level': 'INFO',
+            'formatter': 'simple',
+            'stream': 'ext://sys.stdout',
         },
-        "info_file_handler": {
-            "class": "logging.handlers.RotatingFileHandler",
-            "level": "INFO",
-            "formatter": "simple",
-            "filename": f"{LOGS_PATH}info.log",
-            "maxBytes": 10485760,
-            "backupCount": 50,
-            "encoding": "utf8",
+        'info_file_handler': {
+            'class': 'logging.handlers.RotatingFileHandler',
+            'level': 'INFO',
+            'formatter': 'simple',
+            'filename': f'{LOGS_PATH}info.log',
+            'maxBytes': 10485760,
+            'backupCount': 50,
+            'encoding': 'utf8',
         },
-        "error_file_handler": {
-            "class": "logging.handlers.RotatingFileHandler",
-            "level": "ERROR",
-            "formatter": "simple",
-            "filename": f"{LOGS_PATH}error.log",
-            "maxBytes": 10485760,
-            "backupCount": 20,
-            "encoding": "utf8",
+        'error_file_handler': {
+            'class': 'logging.handlers.RotatingFileHandler',
+            'level': 'ERROR',
+            'formatter': 'simple',
+            'filename': f'{LOGS_PATH}error.log',
+            'maxBytes': 10485760,
+            'backupCount': 20,
+            'encoding': 'utf8',
         },
-        "debug_file_handler": {
-            "class": "logging.handlers.RotatingFileHandler",
-            "level": "DEBUG",
-            "formatter": "simple",
-            "filename": f"{LOGS_PATH}debug.log",
-            "maxBytes": 10485760,
-            "backupCount": 50,
-            "encoding": "utf8",
+        'debug_file_handler': {
+            'class': 'logging.handlers.RotatingFileHandler',
+            'level': 'DEBUG',
+            'formatter': 'simple',
+            'filename': f'{LOGS_PATH}debug.log',
+            'maxBytes': 10485760,
+            'backupCount': 50,
+            'encoding': 'utf8',
        },
     },
-    "loggers": {
-        "my_module": {"level": "INFO", "handlers": ["console"], "propagate": "no"}
+    'loggers': {
+        'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
     },
-    "root": {
-        "level": "INFO",
-        "handlers": ["error_file_handler", "info_file_handler", "debug_file_handler", "console"],
+    'root': {
+        'level': 'INFO',
+        'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
     },
 }
diff --git a/requirements.txt b/requirements.txt
index 1529082..7b3fa73 100644
Binary files a/requirements.txt and b/requirements.txt differ