From a9dcd444c82bbf4fb158211d19d578c0421b663f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 28 May 2024 15:22:11 +0800 Subject: [PATCH 001/108] =?UTF-8?q?feat=20design=20=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 28 + app/api/api_route.py | 2 + app/core/config.py | 7 +- app/schemas/design.py | 50 ++ app/service/design/core/__init__.py | 0 app/service/design/core/layer.py | 116 +++ app/service/design/core/priority.py | 45 ++ app/service/design/design_request.json | 101 +++ app/service/design/design_request_2.json | 684 ++++++++++++++++++ app/service/design/fastapi_request.json | 69 ++ app/service/design/items/__init__.py | 16 + app/service/design/items/accessories.py | 59 ++ app/service/design/items/bag.py | 44 ++ app/service/design/items/body.py | 35 + app/service/design/items/bottom.py | 38 + app/service/design/items/builder.py | 9 + app/service/design/items/clothing.py | 96 +++ .../design/items/pipelines/__init__.py | 19 + app/service/design/items/pipelines/compose.py | 36 + .../items/pipelines/contour_detection.py | 59 ++ .../design/items/pipelines/keypoints.py | 148 ++++ app/service/design/items/pipelines/loading.py | 143 ++++ .../design/items/pipelines/painting.py | 498 +++++++++++++ app/service/design/items/pipelines/scale.py | 54 ++ .../design/items/pipelines/segmentation.py | 14 + app/service/design/items/pipelines/split.py | 115 +++ app/service/design/items/shoes.py | 126 ++++ app/service/design/items/top.py | 46 ++ app/service/design/service.py | 130 ++++ app/service/design/utils/__init__.py | 0 app/service/design/utils/conversion_image.py | 23 + app/service/design/utils/design_ensemble.py | 138 ++++ app/service/design/utils/redis_utils.py | 99 +++ app/service/design/utils/synthesis_item.py | 174 +++++ app/service/design/utils/upload_image.py | 160 ++++ 35 files changed, 3378 insertions(+), 3 deletions(-) create mode 
100644 app/api/api_design.py create mode 100644 app/schemas/design.py create mode 100644 app/service/design/core/__init__.py create mode 100644 app/service/design/core/layer.py create mode 100644 app/service/design/core/priority.py create mode 100644 app/service/design/design_request.json create mode 100644 app/service/design/design_request_2.json create mode 100644 app/service/design/fastapi_request.json create mode 100644 app/service/design/items/__init__.py create mode 100644 app/service/design/items/accessories.py create mode 100644 app/service/design/items/bag.py create mode 100644 app/service/design/items/body.py create mode 100644 app/service/design/items/bottom.py create mode 100644 app/service/design/items/builder.py create mode 100644 app/service/design/items/clothing.py create mode 100644 app/service/design/items/pipelines/__init__.py create mode 100644 app/service/design/items/pipelines/compose.py create mode 100644 app/service/design/items/pipelines/contour_detection.py create mode 100644 app/service/design/items/pipelines/keypoints.py create mode 100644 app/service/design/items/pipelines/loading.py create mode 100644 app/service/design/items/pipelines/painting.py create mode 100644 app/service/design/items/pipelines/scale.py create mode 100644 app/service/design/items/pipelines/segmentation.py create mode 100644 app/service/design/items/pipelines/split.py create mode 100644 app/service/design/items/shoes.py create mode 100644 app/service/design/items/top.py create mode 100644 app/service/design/service.py create mode 100644 app/service/design/utils/__init__.py create mode 100644 app/service/design/utils/conversion_image.py create mode 100644 app/service/design/utils/design_ensemble.py create mode 100644 app/service/design/utils/redis_utils.py create mode 100644 app/service/design/utils/synthesis_item.py create mode 100644 app/service/design/utils/upload_image.py diff --git a/app/api/api_design.py b/app/api/api_design.py new file mode 100644 index 
0000000..0c48c81 --- /dev/null +++ b/app/api/api_design.py @@ -0,0 +1,28 @@ +import logging +import time + +from fastapi import APIRouter + +from app.schemas.design import DesignModel +from app.service.design.service import generate + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/design") +def design(request_data: DesignModel): + try: + logger.info(f"design request item is : @@@@@@:{request_data}") + code = 200 + message = "access" + start_time = time.time() + data = generate(request_data=request_data) + logger.info(f"design Run time is @@@@@@:{time.time() - start_time}") + except Exception as e: + code = 400 + message = str(e) + data = str(e) + logger.warning(f"design Run Exception @@@@@@:{e}") + logger.info({"code": code, "message": message, "data": data}) + return {"code": code, "message": message, "data": data} \ No newline at end of file diff --git a/app/api/api_route.py b/app/api/api_route.py index 2513204..ff21b34 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -4,6 +4,7 @@ from app.api import api_test from app.api import api_super_resolution from app.api import api_generate_image from app.api import api_attribute_retrieve +from app.api import api_design router = APIRouter() @@ -11,3 +12,4 @@ router.include_router(api_test.router, tags=["test"], prefix="/test") router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api") router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api") router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api") +router.include_router(api_design.router, tags=['design'], prefix="/api") diff --git a/app/core/config.py b/app/core/config.py index b029c1e..6e22adc 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -80,15 +80,17 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' SEGMENTATION = { + "new_model_name": 
"seg_knet", "name": "seg_ocrnet_hr18", "input": "seg_input__0", "output": "seg_output__0", } # DESIGN config -DESIGN_MODEL_URL = '10.1.1.240:9000' - +DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing" +KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', + 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') # 优先级 PRIORITY_DICT = { @@ -116,4 +118,3 @@ PRIORITY_DICT = { 'bag_back': -98, 'earring_back': -99, } - diff --git a/app/schemas/design.py b/app/schemas/design.py new file mode 100644 index 0000000..b203970 --- /dev/null +++ b/app/schemas/design.py @@ -0,0 +1,50 @@ +from pydantic import BaseModel + + +# class BodyPointModel(BaseModel): +# waistband_right: list[int] +# hand_point_right: list[int] +# waistband_left: list[int] +# hand_point_left: list[int] +# shoulder_left: list[int] +# shoulder_right: list[int] +# +# +# class BasicModel(BaseModel): +# body_point: BodyPointModel +# layer_order: bool +# scale_bag: float +# scale_earrings: float +# self_template: bool +# single_overall: str +# switch_category: str +# body_path: str +# +# +# class PrintModel(BaseModel): +# if_single: bool +# print_path_list: list[str] +# +# +# class ItemModel(BaseModel): +# color: str +# image_id: str +# offset: list[int] +# path: str +# print: PrintModel +# resize_scale: float +# type: str +# +# +# class CollocationModel(BaseModel): +# basic: BasicModel +# item: list[ItemModel] +# +# +# class DesignModel(BaseModel): +# object: list[CollocationModel] +# process_id: str + +class DesignModel(BaseModel): + objects: list[dict] + process_id: str diff --git a/app/service/design/core/__init__.py b/app/service/design/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/service/design/core/layer.py b/app/service/design/core/layer.py new file mode 100644 index 0000000..0628851 --- /dev/null +++ 
b/app/service/design/core/layer.py @@ -0,0 +1,116 @@ +import logging + +import numpy as np +import cv2 +from matplotlib import pyplot as plt + +from PIL import Image + + +def show(img, win_name="temp"): + cv2.imshow(win_name, img) + cv2.waitKey(0) + + +def crop(img): + mid_point_h, mid_point_w = int(img.shape[0] / 2 + 30), int(img.shape[1] / 2) + img_roi = img[mid_point_h - 520: mid_point_h + 520, mid_point_w - 340: mid_point_w + 340] + return img_roi + + +class Layer(object): + def __init__(self): + self._layer = [] + + @property + def layer(self): + return self._layer + + def insert(self, layer_instance): + if layer_instance['name'] == 'body': + self._body = layer_instance + self._layer.append(layer_instance) + + def sort(self, priority): + self._layer.sort(key=lambda x: priority[x['name']]) + + # def merge(self, cfg): + # """ + # opencv shape order (height, width, channel) + # image coordinate system: + # |------------->x (width) + # | + # | + # | + # y (height) + # Returns: + # + # + # """ + # base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0)) + # for layer in self._layer: + # y, x = layer['position'] + # base_image.paste(layer['image'], (x, y), layer['image']) + # # base_image.show() + # + # for x in self._layer: + # if np.all(x['mask'] == 0): + # continue + # # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask) + # roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x) + # temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2) + # tmp1 = (roi * (temp_bg / 255)).astype(np.uint8) + # temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2) + # tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8) + # + # roi[:] = cv2.add(tmp1, tmp2) + # # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA), + # # win_name=x.get('name')) + # # crop image and get the central part + # if cfg.get('basic')['self_template'] == False: + # dst_roi = 
crop(dst) + # else: + # dst_roi = dst + # return dst_roi, signal + # + # @staticmethod + # def get_roi(dst, image): + # signal = False + # dst_y, dst_x = dst.shape[:2] + # roi_height, roi_width = image['mask'].shape + # roi_y0, roi_x0 = image['position'] + # + # if roi_y0 < 0: + # roi_yin = 0 + # mask_yin = -roi_y0 + # signal = True + # else: + # roi_yin = roi_y0 + # mask_yin = 0 + # if roi_y0 + roi_height > dst_y: + # roi_yout = dst_y + # mask_yout = dst_y - roi_y0 + # signal = True + # else: + # roi_yout = roi_height + roi_y0 + # mask_yout = roi_height + # # x part + # if roi_x0 < 0: + # roi_xin = 0 + # mask_xin = -roi_x0 + # signal = True + # else: + # roi_xin = roi_x0 + # mask_xin = 0 + # if roi_x0 + roi_width > dst_x: + # roi_xout = dst_x + # mask_xout = dst_x - roi_x0 + # signal = True + # else: + # roi_xout = roi_width + roi_x0 + # mask_xout = roi_width + # + # roi = dst[roi_yin: roi_yout, roi_xin: roi_xout] + # roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout] + # roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout] + # return roi, roi_mask, roi_image, signal diff --git a/app/service/design/core/priority.py b/app/service/design/core/priority.py new file mode 100644 index 0000000..dc111ea --- /dev/null +++ b/app/service/design/core/priority.py @@ -0,0 +1,45 @@ +class Priority(object): + """Item layer priority levels. 
+ """ + + def __init__(self, item_list): + self._priority = dict( + earring_front=99, + bag_front=98, + hairstyle_front=97, + outwear_front=20, + bottoms_front=19, + dress_front=18, + blouse_front=17, + skirt_front=16, + trousers_front=15, + tops_front=14, + shoes_right=1, + shoes_left=1, + body=0, + tops_back=-14, + trousers_back=-15, + skirt_back=-16, + blouse_back=-17, + dress_back=-18, + bottoms_back=-19, + outwear_back=-20, + hairstyle_back=-97, + bag_back=-98, + earring_back=-99, + ) + self.clothing_start_num = 10 + if not isinstance(item_list, list): + raise ValueError('item_list must be a list!') + for cate in item_list: + cate = cate.lower() + if cate not in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): + raise ValueError(f'Item type error. Cannot recognize {cate}') + for i, cate in enumerate(item_list): + cate = cate.lower() + self._priority[f'{cate}_front'] = self.clothing_start_num - i + self._priority[f'{cate}_back'] = -(self.clothing_start_num - i) + + @property + def priority(self): + return self._priority diff --git a/app/service/design/design_request.json b/app/service/design/design_request.json new file mode 100644 index 0000000..5551b82 --- /dev/null +++ b/app/service/design/design_request.json @@ -0,0 +1,101 @@ +{ + "objects": [ + { + "basic": { + "body_point": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 67315, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0628000325.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, 
+ "type": "Trousers" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 92912, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0825001943.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 91430, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0825000856.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Outwear" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + ], + "process_id": "7296013643475027" +} \ No newline at end of file diff --git a/app/service/design/design_request_2.json b/app/service/design/design_request_2.json new file mode 100644 index 0000000..51b607a --- /dev/null +++ b/app/service/design/design_request_2.json @@ -0,0 +1,684 @@ +{ + "objects": [ + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 67315, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0628000325.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Trousers" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 92912, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0825001943.jpg", + "print": { + "IfSingle": false, + 
"print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 91430, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0825000856.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Outwear" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + , + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 92913, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/dress/826000033.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Dress" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + , + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 92914, + 
"offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/0902001788.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Skirt" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 92915, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902003817.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + , + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 92916, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/skirt_p4_838.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Skirt" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 84210, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0916000703.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + , + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + 
"waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 62041, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0902000232.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Outwear" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 67039, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902002591.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 78016, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/trousers_p4_302.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Trousers" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + , + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 92917, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0902001403.jpg", + "print": { + "IfSingle": 
false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Trousers" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 92306, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902001766.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + , + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 86564, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0916000038.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 92918, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0628001561.jpeg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Trousers" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 92919, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/outwear_p3186.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Outwear" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + 
"type": "Body" + } + ] + } + , + { + "basic": { + "body_point_test": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "151 78 78", + "icon": "none", + "image_id": 67009, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902002051.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Blouse" + }, + { + "color": "151 78 78", + "icon": "none", + "image_id": 85028, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/903000142.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Skirt" + }, + { + "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "resize_scale": 1.0, + "type": "Body" + } + ] + } + ], + "process_id": "7296013643475027" +} \ No newline at end of file diff --git a/app/service/design/fastapi_request.json b/app/service/design/fastapi_request.json new file mode 100644 index 0000000..f578079 --- /dev/null +++ b/app/service/design/fastapi_request.json @@ -0,0 +1,69 @@ +{ + "basic": { + "body_point": { + "waistband_right": [ + 1081, + 1318 + ], + "hand_point_right": [ + 1200, + 1857 + ], + "waistband_left": [ + 639, + 1315 + ], + "hand_point_left": [ + 493, + 1808 + ], + "shoulder_left": [ + 556, + 582 + ], + "shoulder_right": [ + 1130, + 576 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "single", + "switch_category": "Trousers", + "body_path": 
"aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png" + }, + "item": [ + { + "color": "151 78 78", + "image_id": "67315", + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0628000325.jpg", + "print": { + "if_single": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Trousers" + }, + { + "color": "151 78 78", + "path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", + "image_id": 69331, + "offset": [ + 1, + 1 + ], + "print": { + "if_single": false, + "print_path_list": [] + }, + "resize_scale": 1.0, + "type": "Body" + } + ] +} \ No newline at end of file diff --git a/app/service/design/items/__init__.py b/app/service/design/items/__init__.py new file mode 100644 index 0000000..23f35bf --- /dev/null +++ b/app/service/design/items/__init__.py @@ -0,0 +1,16 @@ +from .builder import ITEMS, build_item +from .clothing import Clothing # 4.0 sec +from .body import Body +from .top import Top, Blouse, Outwear, Dress +from .bottom import Bottom, Trousers, Skirt +from .shoes import Shoes +from .bag import Bag +from .accessories import Hairstyle, Earring + +__all__ = [ + 'ITEMS', 'build_item', + 'Clothing', 'Body', + 'Top', 'Blouse', 'Outwear', 'Dress', + 'Bottom', 'Trousers', 'Skirt', + 'Shoes', 'Bag', 'Hairstyle', 'Earring' +] diff --git a/app/service/design/items/accessories.py b/app/service/design/items/accessories.py new file mode 100644 index 0000000..5cb5796 --- /dev/null +++ b/app/service/design/items/accessories.py @@ -0,0 +1,59 @@ +from .builder import ITEMS +from .clothing import Clothing + + +@ITEMS.register_module() +class Hairstyle(Clothing): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']), + ] + 
kwargs.update(pipeline=pipeline) + super(Hairstyle, self).__init__(**kwargs) + + @staticmethod + def calculate_start_point(keypoint_type, scale, clothes_point, body_point): + """ + align up + Args: + keypoint_type: string, "head_point" + scale: float + clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]} + body_point: dict, containing keypoint data of body figure + + Returns: + start_point: tuple (x', y') + x' = y_body - y1 * scale + y' = x_body - x1 * scale + """ + side_indicator = f'{keypoint_type}_up' + # clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()} + # logging.info(clothes_point[side_indicator]) + + start_point = ( + int(body_point[side_indicator][1] - int(clothes_point[side_indicator].split("_")[1] * scale)), + int(body_point[side_indicator][0] - int(clothes_point[side_indicator].split("_")[0] * scale)) + ) + return start_point + + +@ITEMS.register_module() +class Earring(Clothing): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']), + ] + kwargs.update(pipeline=pipeline) + super(Earring, self).__init__(**kwargs) diff --git a/app/service/design/items/bag.py b/app/service/design/items/bag.py new file mode 100644 index 0000000..c171e75 --- /dev/null +++ b/app/service/design/items/bag.py @@ -0,0 +1,44 @@ +from .builder import ITEMS +from .clothing import Clothing +import random + + +@ITEMS.register_module() +class Bag(Clothing): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 
'pattern_image']), + ] + kwargs.update(pipeline=pipeline) + super(Bag, self).__init__(**kwargs) + + @staticmethod + def calculate_start_point(keypoint_type, scale, clothes_point, body_point): + """ + align left + Args: + keypoint_type: string, "hand_point" + scale: float + clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]} + body_point: dict, containing keypoint data of body figure + + Returns: + start_point: tuple (y', x') + x' = y_body - y1 * scale + y' = x_body - x1 * scale + """ + location = random.choice(seq=['left', 'right']) + if location == 'left': + side_indicator = f'{keypoint_type}_left' + else: + side_indicator = f'{keypoint_type}_right' + # clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()} + start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale), + body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale)) + return start_point diff --git a/app/service/design/items/body.py b/app/service/design/items/body.py new file mode 100644 index 0000000..69e8b36 --- /dev/null +++ b/app/service/design/items/body.py @@ -0,0 +1,35 @@ +import cv2 +from .builder import ITEMS +from .pipelines import Compose + + +@ITEMS.register_module() +class Body(object): + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']), + # dict(type='ImageShow', key=['body_image', "body_mask"]) + ] + self.pipeline = Compose(pipeline) + self.result = dict() + + def process(self): + self.pipeline(self.result) + pass + + def organize(self, layer): + body_layer = dict(priority=0, + name=type(self).__name__.lower(), + image=self.result['body_image'], + image_url=self.result['image_url'], + mask_image=None, + mask_url=None, + sacle=1, + # mask=self.result['body_mask'], + position=(0, 0)) + layer.insert(body_layer) + + @staticmethod + def show(img): + cv2.imshow('', img) + cv2.waitKey(0) diff 
--git a/app/service/design/items/bottom.py b/app/service/design/items/bottom.py new file mode 100644 index 0000000..eb575fb --- /dev/null +++ b/app/service/design/items/bottom.py @@ -0,0 +1,38 @@ +from .builder import ITEMS +from .clothing import Clothing + + +@ITEMS.register_module() +class Bottom(Clothing): + def __init__(self, pipeline, **kwargs): + if pipeline is None: + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting', painting_flag=True), + dict(type='PrintPainting', print_flag=True), + dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']), + ] + kwargs.update(pipeline=pipeline) + super(Bottom, self).__init__(**kwargs) + + +@ITEMS.register_module() +class Trousers(Bottom): + def __init__(self, pipeline=None, **kwargs): + super(Trousers, self).__init__(pipeline, **kwargs) + + +@ITEMS.register_module() +class Skirt(Bottom): + def __init__(self, pipeline=None, **kwargs): + super(Skirt, self).__init__(pipeline, **kwargs) + + +@ITEMS.register_module() +class Bottoms(Bottom): + def __init__(self, pipeline=None, **kwargs): + super(Bottoms, self).__init__(pipeline, **kwargs) diff --git a/app/service/design/items/builder.py b/app/service/design/items/builder.py new file mode 100644 index 0000000..26e04f1 --- /dev/null +++ b/app/service/design/items/builder.py @@ -0,0 +1,9 @@ +from mmcv.utils import Registry, build_from_cfg + +ITEMS = Registry('item') +PIPELINES = Registry('pipeline') + + +def build_item(cfg, default_args=None): + item = build_from_cfg(cfg, ITEMS, default_args) + return item diff --git a/app/service/design/items/clothing.py b/app/service/design/items/clothing.py new file mode 100644 index 0000000..5adcc70 --- /dev/null +++ b/app/service/design/items/clothing.py @@ -0,0 +1,96 @@ +import cv2 + +from app.core.config import 
@ITEMS.register_module()
class Clothing(object):
    """Base class for wearable items.

    Runs a configurable processing pipeline over ``self.result`` and then
    organizes the processed front/back images into composition layers.
    """

    def __init__(self, pipeline, **kwargs):
        self.pipeline = Compose(pipeline)
        self.result = dict(name=type(self).__name__.lower(), **kwargs)

    def process(self):
        """Run the configured pipeline over ``self.result`` in place."""
        self.pipeline(self.result)

    def apply_scale(self, img):
        """Return *img* resized by ``self.result['scale']`` (aspect preserved).

        BUG FIX (minor): the original unpacked ``img.shape[0:2]`` twice (once
        unconditionally, once behind ``len(img.shape) > 2``); ``shape[:2]`` is
        (height, width) for both grayscale and colour images, so a single
        unpack suffices.
        """
        scale = self.result['scale']
        height, width = img.shape[0: 2]
        scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)),
                                interpolation=cv2.INTER_AREA)
        return scaled_img

    def organize(self, layer):
        """Insert the processed front and back layers into *layer*.

        Front/back share the same start point; the back layer gets a negated
        priority so it stacks behind the mannequin.
        """
        start_point = self.calculate_start_point(
            self.result['keypoint'], self.result['scale'],
            self.result['clothes_keypoint'], self.result['body_point_test'],
            self.result["offset"], self.result["resize_scale"])

        # NOTE(review): 'sacle' (sic) is the key the rest of the system reads;
        # kept misspelled for compatibility. ``front_image.size`` is used as a
        # cv2.resize dsize, which assumes front_image is a PIL image whose
        # .size is (width, height) -- TODO confirm against the Split stage.
        front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
                           name=f'{type(self).__name__.lower()}_front',
                           image=self.result["front_image"],
                           image_url=self.result['front_image_url'],
                           mask_url=self.result['front_mask_url'],
                           sacle=self.result['scale'],
                           clothes_keypoint=self.result['clothes_keypoint'],
                           position=start_point,
                           resize_scale=self.result["resize_scale"],
                           mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                           gradient_string=self.result.get('gradient_string', "")
                           )
        layer.insert(front_layer)

        back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
                          name=f'{type(self).__name__.lower()}_back',
                          image=self.result["back_image"],
                          image_url=self.result['back_image_url'],
                          mask_url=self.result['back_mask_url'],
                          sacle=self.result['scale'],
                          clothes_keypoint=self.result['clothes_keypoint'],
                          position=start_point,
                          resize_scale=self.result["resize_scale"],
                          # NOTE(review): back mask is also sized to the FRONT
                          # image, matching the original behaviour -- confirm
                          # front/back images always share dimensions.
                          mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                          gradient_string=self.result.get('gradient_string', "")
                          )
        layer.insert(back_layer)

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
        """Compute the (y, x) paste origin that aligns the garment's LEFT
        anchor keypoint with the corresponding body keypoint.

        Args:
            keypoint_type: string, e.g. "waistband" | "shoulder" | "ear_point"
            scale: float, garment scale factor
            clothes_point: dict keyed by '<type>_left'/'<type>_right', each an
                [x, y] pair (Milvus keypoint-cache format)
            body_point: dict of body-figure keypoints, each (x, y)
            offset: (dy, dx) manual adjustment
            resize_scale: unused here; kept for interface compatibility

        Returns:
            start_point: tuple (y', x') where
                y' = y_body + offset_y - x_clothes * scale
                x' = x_body + offset_x - y_clothes * scale
        """
        side_indicator = f'{keypoint_type}_left'

        # Keypoints come from the Milvus cache as [x, y] int pairs.
        start_point = (
            int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale),  # y
            int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale)   # x
        )
        return start_point
# --- app/service/design/items/pipelines/__init__.py ---
from .compose import Compose
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
from .keypoints import KeypointDetection
from .segmentation import Segmentation
from .painting import Painting, PrintPainting
from .scale import Scaling
from .contour_detection import ContourDetection
from .split import Split

__all__ = [
    'Compose',
    'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
    'KeypointDetection',
    'Segmentation',
    'Painting', 'PrintPainting',
    'Scaling',
    'ContourDetection',
    # BUG FIX: was 'split' (lowercase); the class imported above is 'Split',
    # so `from ...pipelines import *` raised AttributeError.
    'Split',
]


# --- app/service/design/items/pipelines/compose.py ---
import collections

from mmcv.utils import build_from_cfg

from ..builder import PIPELINES


@PIPELINES.register_module()
class Compose(object):
    """Chain pipeline stages; each entry is a PIPELINES config dict or a callable."""

    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                # Build the stage object from its registry config.
                transform = build_from_cfg(transform, PIPELINES)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')

    def __call__(self, data):
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict contains the data to transform.

        Returns:
            dict: Transformed data, or None if any stage returned None.
        """
        for t in self.transforms:
            data = t(data)
            if data is None:
                return None
        return data
+ """ + + for t in self.transforms: + data = t(data) + if data is None: + return None + return data diff --git a/app/service/design/items/pipelines/contour_detection.py b/app/service/design/items/pipelines/contour_detection.py new file mode 100644 index 0000000..df6c7b2 --- /dev/null +++ b/app/service/design/items/pipelines/contour_detection.py @@ -0,0 +1,59 @@ +import logging + +from ..builder import PIPELINES +import cv2 +import numpy as np + + +@PIPELINES.register_module() +class ContourDetection(object): + def __init__(self): + # logging.info("ContourDetection run ") + pass + + #@ RunTime + def __call__(self, result): + # shoe diff + if result['name'] == 'shoes': + Contour = self.get_contours(result['image']) + Mask = np.zeros(result['image'].shape[:2], np.uint8) + for i in range(2): + Max_contour = Contour[i] + Epsilon = 0.001 * cv2.arcLength(Max_contour, True) + Approx = cv2.approxPolyDP(Max_contour, Epsilon, True) + cv2.drawContours(Mask, [Approx], -1, 255, -1) + if result['pre_mask'] is None: + result['mask'] = Mask + else: + result['mask'] = cv2.bitwise_and(Mask, result['pre_mask']) + else: + Contour = self.get_contours(result['image']) + Mask = np.zeros(result['image'].shape[:2], np.uint8) + if len(Contour): + Max_contour = Contour[0] + Epsilon = 0.001 * cv2.arcLength(Max_contour, True) + Approx = cv2.approxPolyDP(Max_contour, Epsilon, True) + cv2.drawContours(Mask, [Approx], -1, 255, -1) + else: + Mask = np.ones(result['image'].shape[:2], np.uint8) * 255 + # TODO 修复部分图片出现透明的情况 下版本上线 + # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY) + # ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY) + # Mask = cv2.bitwise_not(Mask) + if result['pre_mask'] is None: + result['mask'] = Mask + else: + result['mask'] = cv2.bitwise_and(Mask, result['pre_mask']) + + return result + + @staticmethod + def get_contours(image): + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + Edge = cv2.Canny(gray, 10, 150) + kernel = np.ones((5, 5), np.uint8) + 
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1) + Edge = cv2.erode(Edge, kernel=kernel, iterations=1) + Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + Contour = sorted(Contour, key=cv2.contourArea, reverse=True) + return Contour diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py new file mode 100644 index 0000000..fc59b61 --- /dev/null +++ b/app/service/design/items/pipelines/keypoints.py @@ -0,0 +1,148 @@ +import logging +import time +import numpy as np +from pymilvus import MilvusClient + +from app.core.config import * +from ..builder import PIPELINES +from ...utils.design_ensemble import get_keypoint_result + + +@PIPELINES.register_module() +class KeypointDetection(object): + """ + path here: abstract path + """ + + def __init__(self): + self.client = MilvusClient( + uri="http://10.1.1.240:19530", + token="root:Milvus", + db_name=MILVUS_ALIAS + ) + + def __del__(self): + # start_time = time.time() + self.client.close() + # print(f"client close time : {time.time() - start_time}") + + # @ RunTime + def __call__(self, result): + # logging.info("KeypointDetection run ") + if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新 + # result['clothes_keypoint'] = self.infer_keypoint_result(result) + site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down' + # keypoint_cache = search_keypoint_cache(result["image_id"], site) + + # 取消向量查询 直接过模型推理 + keypoint_cache = self.keypoint_cache(result, site) + # keypoint_cache = False + + if keypoint_cache is False: + keypoint_infer_result, site = self.infer_keypoint_result(result) + result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site) + else: + result['clothes_keypoint'] = keypoint_cache + return result + + @staticmethod + def infer_keypoint_result(result): + site = 'up' if result['name'] 
@PIPELINES.register_module()
class KeypointDetection(object):
    """Resolve garment keypoints, using Milvus as a cache.

    A cached vector is 24 ints = 12 (x, y) points; the first 10 points belong
    to the 'up' (top-garment) site and the last 2 to the 'down' site, padded
    with zeros for the missing site. On a cache miss the keypoint model is
    run and the result written back.
    """

    def __init__(self):
        # NOTE(review): SECURITY/config -- the Milvus URI and credential token
        # are hard-coded; they should come from app.core.config like
        # MILVUS_ALIAS does.
        self.client = MilvusClient(
            uri="http://10.1.1.240:19530",
            token="root:Milvus",
            db_name=MILVUS_ALIAS
        )

    def __del__(self):
        # Best-effort close of the long-lived query client.
        self.client.close()

    # @ RunTime
    def __call__(self, result):
        # Only garment categories carry keypoints; other items pass through.
        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:
            site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'

            # Query the cache first; fall back to model inference on a miss.
            keypoint_cache = self.keypoint_cache(result, site)

            if keypoint_cache is False:
                # Cache lookup errored out: infer and persist.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
            else:
                result['clothes_keypoint'] = keypoint_cache
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run the keypoint model on the garment image; returns (points, site)."""
        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
        keypoint_infer_result = get_keypoint_result(result["image"], site)
        return keypoint_infer_result, site

    @staticmethod
    # @ RunTime
    def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None):
        """Upsert the inferred points into Milvus and return them as a dict.

        BUG FIX: this keyword parameter shadows the config constant of the
        same name; callers never pass it, so the original code reached
        ``dict(zip(None, ...))`` and raised TypeError on every cache write.
        Fall back to the config constant when the argument is omitted.
        """
        if KEYPOINT_RESULT_TABLE_FIELD_SET is None:
            from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET
        if site == "down":
            # 'down' garments fill the trailing 4 slots; zero-pad the front.
            zeros = np.zeros(20, dtype=int)
            result = np.concatenate([zeros, cache.flatten()])
        else:
            zeros = np.zeros(4, dtype=int)
            result = np.concatenate([cache.flatten(), zeros])
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": site,
             "keypoint_vector": result.tolist()
             }
        ]
        client = MilvusClient(
            uri="http://10.1.1.240:19530",
            token="root:Milvus",
            db_name=MILVUS_ALIAS
        )
        try:
            client.upsert(
                collection_name=MILVUS_TABLE_KEYPOINT,
                data=data,
            )
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        except Exception as e:
            # Cache write failures are non-fatal: the caller still gets points.
            logging.info(f"save keypoint cache milvus error : {e}")
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        finally:
            client.close()

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge freshly inferred points with the cached opposite site and
        upsert the combined vector as site "all"."""
        if site == "up":
            # We inferred 'up'; keep the cached 'down' tail (last 4 ints).
            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # We inferred 'down'; keep the cached 'up' head (first 20 ints).
            result = np.concatenate([search_result[:20], infer_result.flatten()])
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": "all",
             "keypoint_vector": result.tolist()
             }
        ]
        client = MilvusClient(
            uri="http://10.1.1.240:19530",
            token="root:Milvus",
            db_name=MILVUS_ALIAS
        )
        try:
            client.upsert(
                collection_name=MILVUS_TABLE_KEYPOINT,
                data=data
            )
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        except Exception as e:
            logging.info(f"save keypoint cache milvus error : {e}")
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        finally:
            # BUG FIX: the original leaked this client (no close on any path).
            client.close()

    # @ RunTime
    def keypoint_cache(self, result, site):
        """Query the cache; infer+persist on a miss; upgrade partial entries.

        Returns a {field_name: [x, y]} dict, or False when the Milvus query
        itself failed (the caller then runs inference).
        """
        try:
            keypoint_id = result['image_id']
            res = self.client.query(
                collection_name=MILVUS_TABLE_KEYPOINT,
                filter=f"keypoint_id == {keypoint_id}",
                output_fields=['keypoint_vector', 'keypoint_site']
            )
            if len(res) == 0:
                # Miss: infer and persist.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
                # Cached entry covers the requested site: return it directly.
                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
            elif res[0]["keypoint_site"] != site:
                # Cached entry is for the other site only: infer ours and
                # upgrade the cached entry to site "all".
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
        except Exception as e:
            logging.info(f"search keypoint cache milvus error {e}")
            return False
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Fetch a garment image from MinIO and seed the pipeline result dict."""

    def __init__(self, path, color=None, print_dict=None):
        self.path = path
        self.color = color
        self.print_dict = print_dict
        self.minio_client = Minio(
            f"{MINIO_URL}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    def __call__(self, result):
        result['image'], result['pre_mask'] = self.read_image(self.path)
        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['path'] = self.path
        result['img_shape'] = result['image'].shape
        result['ori_shape'] = result['image'].shape
        result['color'] = self.color if self.color is not None else None
        result['print_dict'] = self.print_dict
        return result

    @staticmethod
    def get_keypoint(name):
        """Map an item category to the body keypoint it is anchored to."""
        keypoint_by_category = {
            'blouse': 'shoulder', 'outwear': 'shoulder', 'dress': 'shoulder', 'tops': 'shoulder',
            'trousers': 'waistband', 'skirt': 'waistband', 'bottoms': 'waistband',
            'bag': 'hand_point',
            'shoes': 'toe',
            'hairstyle': 'head_point',
            'earring': 'ear_point',
        }
        try:
            return keypoint_by_category[name]
        except KeyError:
            raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
                           f"bag, shoes, hairstyle, earring.")

    def read_image(self, image_path):
        """Download ``bucket/key`` *image_path* and decode it.

        Returns:
            (image, image_mask): BGR image and the alpha channel as a
            single-channel mask, or None when the source has no alpha.
        """
        image_mask = None
        bucket, key = image_path.split("/", 1)
        file = self.minio_client.get_object(bucket, key).data
        # BUG FIX: the original decoded with flag 1 (IMREAD_COLOR), which
        # always yields 3 channels -- both the grayscale and the 4-channel
        # alpha branches below were dead code and pre_mask was always None.
        # IMREAD_UNCHANGED preserves the alpha channel so ContourDetection
        # can intersect its mask with it as intended.
        image = cv2.imdecode(np.frombuffer(file, np.uint8), cv2.IMREAD_UNCHANGED)

        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:  # four channels: split out the alpha mask
            image_mask = image[:, :, 3]
            image = image[:, :, :3]
        return image, image_mask
f"{MINIO_URL}", + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + + # response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png") + + # @ RunTime + def __call__(self, result): + result["image_url"] = result['body_path'] = self.body_path + result["name"] = "mannequin" + if not result['image_url'].lower().endswith(".png"): + logging.info(1) + bucket = self.body_path.split("/", 1)[0] + object_name = self.body_path.split("/", 1)[1] + new_object_name = f'{object_name[:object_name.rfind(".")]}.png' + image = self.minioClient.get_object(bucket, object_name) + image = Image.open(io.BytesIO(image.data)) + image = image.convert("RGBA") + data = image.getdata() + # + new_data = [] + for item in data: + if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + new_data.append((255, 255, 255, 0)) + else: + new_data.append(item) + image.putdata(new_data) + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + self.body_path = image_path + result["image_url"] = result['body_path'] = self.body_path + response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) + # put_image_time = time.time() + result['body_image'] = Image.open(io.BytesIO(response.read())) + # logging.info(f"Image.open time is : {time.time() - put_image_time}") + return result + + +@PIPELINES.register_module() +class ImageShow(object): + def __init__(self, key): + self.key = key + + # @ RunTime + def __call__(self, result): + import matplotlib.pyplot as plt + if isinstance(self.key, list): + for key in self.key: + plt.imshow(result[key]) + plt.title(key) + plt.show() + elif isinstance(self.key, str): + img = self._resize_img(result[self.key]) + cv2.imshow(self.key, img) + 
# --- app/service/design/items/pipelines/painting.py ---
import random
from io import BytesIO
import boto3
import cv2
import numpy as np
from PIL import Image
from ..builder import PIPELINES

# NOTE(review): SECURITY -- AWS credentials are hard-coded in source control.
# They must be moved to configuration / environment variables and these keys
# rotated. Left in place only because the config mechanism is not visible here.
s3 = boto3.client(
    's3',
    aws_access_key_id="AKIAVD3OJIMF6UJFLSHZ",
    aws_secret_access_key="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",
    region_name="ap-east-1"
)


@PIPELINES.register_module()
class Painting(object):
    """Recolour the garment with a flat colour or a gradient texture,
    modulated by the grayscale shading of the original image."""

    def __init__(self, painting_flag=True):
        self.painting_flag = painting_flag

    # @ RunTime
    def __call__(self, result):
        # Hairstyles/earrings keep their original colours; 'none' skips paint.
        if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
            dim_image_h, dim_image_w = result['image'].shape[0:2]
            if "gradient" in result.keys() and result['gradient'] != "":
                # Gradient texture fetched from object storage ('bucket/key').
                bucket_name = result['gradient'].split('/')[0]
                object_name = result['gradient'][result['gradient'].find('/') + 1:]
                pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
                resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
            else:
                pattern = self.get_pattern(result['color'])
                resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
            # Multiply pattern by mask and by the grayscale shading so folds
            # and shadows of the source garment survive the recolour.
            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
            get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
            result['pattern_image'] = get_image_fir.astype(np.uint8)
            result['final_image'] = result['pattern_image']
            # Compose the garment over a white canvas for the standalone view.
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)
            # NOTE(review): meaning of this fixed alpha (100/255) is not
            # visible here -- presumably a downstream blend weight; confirm.
            result['alpha'] = 100 / 255.0
        else:
            # No repaint: just mask the original image.
            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            get_image_fir = result['image'] * (closed_mo / 255)
            result['pattern_image'] = get_image_fir.astype(np.uint8)
            result['final_image'] = result['pattern_image']
        return result

    @staticmethod
    def get_gradient(bucket_name, object_name):
        """Download and decode a gradient texture image from S3."""
        image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
        image_bytes = image_data.read()
        image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
        image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        return image

    @staticmethod
    def crop_image(image, image_size_h, image_size_w):
        """Take a random (h, w) crop of *image* with a jittered origin.

        NOTE(review): assumes image_size_h/5 and image_size_w/5 both exceed 6,
        otherwise randint gets high <= 0 -- confirm input sizes upstream.
        """
        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_pattern(single_color):
        """Build a 1x1 BGR swatch from an 'R G B' colour string.

        Raises:
            ValueError: if *single_color* is None.
        """
        if single_color is None:
            # BUG FIX: the original executed `raise False`, which itself
            # raises "TypeError: exceptions must derive from BaseException".
            raise ValueError('single_color must not be None')
        R, G, B = single_color.split(' ')
        pattern = np.zeros([1, 1, 3], np.uint8)
        pattern[0, 0, 0] = int(B)
        pattern[0, 0, 1] = int(G)
        pattern[0, 0, 2] = int(R)
        return pattern

    @staticmethod
    def gradient(image, angle_degrees, start_color, end_color):
        """Render a linear colour gradient the size of *image* at the given angle."""
        height, width = image.shape[0], image.shape[1]

        gradient_image = np.zeros((height, width, 3), dtype=np.uint8)

        # Clamp the angle to [0, 360] then convert to radians.
        angle_degrees = np.clip(angle_degrees, 0, 360)
        angle_radians = np.radians(angle_degrees)

        # Direction vector of the gradient axis.
        dx = np.cos(angle_radians)
        dy = np.sin(angle_radians)

        x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))

        # Signed distance of each pixel along the gradient axis.
        distance_along_gradient = (x_grid * dx + y_grid * dy) / np.sqrt(dx ** 2 + dy ** 2)

        # Normalized blend weight in [0, 1].
        weight = np.clip(distance_along_gradient / max(width, height), 0, 1)

        gradient_image[:, :, 0] = (1 - weight) * start_color[0] + weight * end_color[0]
        gradient_image[:, :, 1] = (1 - weight) * start_color[1] + weight * end_color[1]
        gradient_image[:, :, 2] = (1 - weight) * start_color[2] + weight * end_color[2]

        return gradient_image


@PIPELINES.register_module()
class PrintPainting(object):
    """Apply print (logo/pattern) artwork on top of the painted garment."""

    def __init__(self, print_flag=True):
        self.print_flag = print_flag
(int(image.width * result['print']['print_scale_list'][i]), int(image.height * result['print']['print_scale_list'][i])) + + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) + + rotated_resized_source = resized_source.rotate(result['print']['print_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(result['print']['print_angle_list'][i]) + + source_image_pil = Image.fromarray(print_background) + source_image_pil_mask = Image.fromarray(mask_background) + + source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask) + + print_background = np.array(source_image_pil) + mask_background = np.array(source_image_pil_mask) + + print(1) + else: + mask = self.get_mask_inv(image) + mask = np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) + x, y = int(result['print']['location'][i][0] - rotated_new_size[0]), int(result['print']['location'][i][1] - rotated_new_size[1]) + + image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] + + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, 
:x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] + + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 + + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 + else: + start_x = x + + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y + + # ------------------ + # 如果print-size大于image-size 则需要裁剪print + + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + rotate_mask = rotate_mask[:, :image_x - x] + + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + + # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) + # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + + print_mask = 
cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) + img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) + img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask)) + mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + result['final_image'] = cv2.add(img_bg, img_fg) + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) + return result + else: + painting_dict = {} + painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2] + + # no print + if len(result['print_dict']['print_path_list']) == 0 or not self.print_flag: + result['print_image'] = result['pattern_image'] + # print + else: + painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) + result['print_image'] = self.printpaint(result, painting_dict, print_=True) + result['final_image'] = result['print_image'] + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) + return result + + @staticmethod + def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x): + temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), 
dtype=np.uint8) + + temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + + img2gray = cv2.cvtColor(print_background, cv2.COLOR_BGR2GRAY) + + ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY) + + mask_inv = cv2.bitwise_not(mask_) + + img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_) + + img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_inv) + + print_background = img1_bg + img2_fg + + return print_background + + def painting_collection(self, painting_dict, result, print_trigger=False): + if print_trigger: + print_ = self.get_print(result['print_dict']) + painting_dict['Trigger'] = not print_['IfSingle'] + painting_dict['location'] = print_['location'] if 'location' in print_.keys() else None + single_mask_inv_print = self.get_mask_inv(print_['image']) + dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w']) + dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) + if not print_['IfSingle']: + self.random_seed = random.randint(0, 1000) + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) + else: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location']) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location']) + painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern + return painting_dict + + def tile_image(self, 
pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False): + if not trigger: + tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA) + else: + resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA) + if len(pattern.shape) == 2: + tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4)) + if len(pattern.shape) == 3: + tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1)) + tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape) + return tile + + def get_mask_inv(self, print_): + if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255: + bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0] + print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB) + bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2] + bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True) + bg_a_high, bg_a_low = self.get_low_high_lab(bg_a) + bg_b_high, bg_b_low = self.get_low_high_lab(bg_b) + lower = np.array([bg_L_low, bg_a_low, bg_b_low]) + upper = np.array([bg_L_high, bg_a_high, bg_b_high]) + mask_inv = cv2.inRange(print_tile, lower, upper) + return mask_inv + else: + # bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0] + # print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB) + # bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2] + # bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True) + # bg_a_high, bg_a_low = self.get_low_high_lab(bg_a) + # bg_b_high, bg_b_low = self.get_low_high_lab(bg_b) + # lower = np.array([bg_L_low, bg_a_low, bg_b_low]) + # upper = np.array([bg_L_high, bg_a_high, bg_b_high]) + + # print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB) + # mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY) + + # mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY) + mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8) + return mask_inv + + @staticmethod + def printpaint(result, painting_dict, print_=False): + 
+ if print_ and painting_dict['Trigger']: + print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print'])) + img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask) + else: + print_mask = result['mask'] + img_fg = result['final_image'] + if print_ and not painting_dict['Trigger']: + try: + index_ = len(painting_dict['location']) + except: + assert f'there must be parameter of location if choose IfSingle' + + for i in range(index_): + start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0]) + + length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0]) + length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1]) + + change_region = img_fg[start_h: length_h, start_w: length_w, :] + # problem in change_mask + change_mask = print_mask[start_h: length_h, start_w: length_w] + # get real part into change mask + _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY) + mask = cv2.bitwise_not(painting_dict['mask_inv_print']) + img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region + + clothes_mask_print = cv2.bitwise_not(print_mask) + + img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print) + mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + print_image = cv2.add(img_bg, img_fg) + return print_image + + @staticmethod + def get_print(print_dict): + if not 'print_scale_list' in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3: + print_dict['scale'] = 0.3 + else: + print_dict['scale'] = print_dict['print_scale_list'][0] + + if not 'IfSingle' in print_dict.keys(): + print_dict['IfSingle'] = False + + # data = 
minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) + data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body'] + + data_bytes = BytesIO(data.read()) + image = Image.open(data_bytes) + image_mode = image.mode + # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 + if image_mode == "RGBA": + new_background = Image.new('RGB', image.size, (255, 255, 255)) + new_background.paste(image, mask=image.split()[3]) + image = new_background + print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + + # file = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]).data + # print_dict['image'] = cv2.imdecode(np.fromstring(file, np.uint8), 1) + + # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + # return image + + return print_dict + + def crop_image(self, image, image_size_h, image_size_w, location, print_shape): + print_w = print_shape[1] + print_h = print_shape[0] + + random.seed(self.random_seed) + # logging.info(f'overall print location : {location}') + # x_offset = random.randint(0, image.shape[0] - image_size_h) + # y_offset = random.randint(0, image.shape[1] - image_size_w) + + # 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量 + x_offset = print_w - int(location[0][1] % print_w) + y_offset = print_w - int(location[0][0] % print_h) + + # y_offset = int(location[0][0]) + # x_offset = int(location[0][1]) + + if len(image.shape) == 2: + image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w] + elif len(image.shape) == 3: + image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :] + return image + + @staticmethod + def get_low_high_lab(Lab_value, L=False): + if L: + high = Lab_value + 30 if Lab_value + 30 < 255 else 255 + low = Lab_value - 30 if Lab_value - 30 > 0 else 0 + else: + high = Lab_value + 
30 if Lab_value + 30 < 255 else 255 + low = Lab_value - 30 if Lab_value - 30 > 0 else 0 + return high, low + + @staticmethod + def img_rotate(image, angel, scale): + """顺时针旋转图像任意角度 + + Args: + image (np.array): [原始图像] + angel (float): [逆时针旋转的角度] + + Returns: + [array]: [旋转后的图像] + """ + + h, w = image.shape[:2] + center = (w // 2, h // 2) + # if type(angel) is not int: + # angel = 0 + M = cv2.getRotationMatrix2D(center, -angel, scale) + # 调整旋转后的图像长宽 + rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0])))) + rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0])))) + M[0, 2] += (rotated_w - w) // 2 + M[1, 2] += (rotated_h - h) // 2 + # 旋转图像 + rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h)) + + return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2) + # return rotated_img, (0, 0) + + @staticmethod + def read_image(image_url): + # data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) + data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] + + data_bytes = BytesIO(data.read()) + image = Image.open(data_bytes) + image_mode = image.mode + # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 + if image_mode == "RGBA": + # new_background = Image.new('RGB', image.size, (255, 255, 255)) + # new_background.paste(image, mask=image.split()[3]) + # image = new_background + return image, image_mode + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + return image, image_mode + + # @staticmethod + # def read_image(image_url): + # response = requests.get(image_url) + # image_data = np.frombuffer(response.content, np.uint8) + # + # # 解码图像 + # image = cv2.imdecode(image_data, 3) + # return image diff --git a/app/service/design/items/pipelines/scale.py b/app/service/design/items/pipelines/scale.py new file mode 100644 index 0000000..6e0cf87 --- /dev/null +++ b/app/service/design/items/pipelines/scale.py @@ -0,0 +1,54 @@ 
+from ..builder import PIPELINES +import math +import cv2 + + +@PIPELINES.register_module() +class Scaling(object): + def __init__(self): + pass + + # @ RunTime + def __call__(self, result): + if result['keypoint'] in ['waistband', 'shoulder', 'head_point']: + # milvus_db_keypoint_cache + distance_clo = math.sqrt( + (int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2 + + + (int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2) + + distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1) + # distance_clo = math.sqrt( + # (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[0])) ** 2 + # + + # (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[1])) ** 2) + # + # distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1) + if distance_clo == 0: + result['scale'] = 10 + else: + result['scale'] = distance_bdy / distance_clo + elif result['keypoint'] == 'toe': + distance_bdy = math.sqrt( + (int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2 + + + (int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2 + ) + + Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0) + Edge = cv2.Canny(Blur, 10, 200) + Edge = cv2.dilate(Edge, None) + Edge = cv2.erode(Edge, None) + Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + Contours = sorted(Contour, key=cv2.contourArea, 
import logging
import cv2
import numpy as np
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.service.utils.generate_uuid import generate_uuid
from ..builder import PIPELINES
from PIL import Image
from ...utils.conversion_image import rgb_to_rgba
from ...utils.upload_image import upload_png_mask


@PIPELINES.register_module()
class Split(object):
    """
    Split image into front and back layer according to the segmentation result
    """

    # KNet
    def __call__(self, result):
        # Pipeline stage: derive front/back masks from the segmentation
        # output, cut the recoloured garment into front/back RGBA layers,
        # resize them by scale * resize_scale, and upload image + mask.
        try:
            if 'mask' not in result.keys():
                raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
            if 'seg_result' not in result.keys():  # segmentation model was not run
                # Without segmentation everything is "front"; back stays empty.
                result['front_mask'] = result['mask'].copy()
                result['back_mask'] = np.zeros_like(result['mask'])
            else:
                # Segmentation labels: 1 = front layer, 2 = back layer.
                temp_front = result['seg_result'] == 1
                result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
                temp_back = result['seg_result'] == 2
                result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))

            if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
                # Masks may carry a leading channel axis; squeeze to H x W.
                if len(result['front_mask'].shape) > 2:
                    front_mask = result['front_mask'][0]
                else:
                    front_mask = result['front_mask']

                if len(result['back_mask'].shape) > 2:
                    back_mask = result['back_mask'][0]
                else:
                    back_mask = result['back_mask']

                # Full-garment RGBA image; alpha comes from the overall mask.
                rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask'])
                result_front_image = np.zeros_like(rgba_image)
                result_front_image[front_mask != 0] = rgba_image[front_mask != 0]

                result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
                front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"]))
                result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS)
                front_mask = cv2.resize(front_mask, front_new_size)
                result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask)

                # Only upper-body garments get a separate back layer.
                if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
                    result_back_image = np.zeros_like(rgba_image)
                    result_back_image[back_mask != 0] = rgba_image[back_mask != 0]

                    result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
                    back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"]))
                    result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS)
                    back_mask = cv2.resize(back_mask, back_new_size)
                    result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask)
                else:
                    result['back_image'] = None
                    result["back_image_url"] = None
                    result["back_mask_url"] = None
            return result
        except Exception as e:
            # NOTE(review): this swallows every error and implicitly returns
            # None, which downstream stages must tolerate — consider
            # re-raising after logging.
            logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask) + # + # if result["name"] in ('blouse', 'dress', 'outwear', 'tops'): + # result_back_image = np.zeros_like(rgba_image) + # result_back_image[result['back_mask'] != 0] = rgba_image[result['back_mask'] != 0] + # + # result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA)) + # back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"])) + # result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS) + # back_mask = cv2.resize(result['back_mask'], back_new_size) + # result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask) + # else: + # result['back_image'] = None + # result["back_image_url"] = None + # result["back_mask_url"] = None + # return result + # except Exception as e: + # logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") diff --git a/app/service/design/items/shoes.py b/app/service/design/items/shoes.py new file mode 100644 index 0000000..f4e17f2 --- /dev/null +++ b/app/service/design/items/shoes.py @@ -0,0 +1,126 @@ +import io +import logging +import time + +import cv2 +import numpy as np + +from .builder import ITEMS +from .clothing import Clothing +from PIL import Image + +from ..utils.conversion_image import rgb_to_rgba +from ..utils.upload_image import upload_png_mask +from ...utils.generate_uuid import generate_uuid + + +@ITEMS.register_module() +class Shoes(Clothing): + # TODO location of shoes has little mismatch + def __init__(self, **kwargs): + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Painting'), + 
dict(type='Scaling'), + dict(type='Split'), + # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']), + ] + kwargs.update(pipeline=pipeline) + super(Shoes, self).__init__(**kwargs) + + def organize(self, layer): + left_shoe_mask, right_shoe_mask = self.cut() + + left_layer = dict(name=f'{type(self).__name__.lower()}_left', + image=self.result['shoes_left'], + image_url=self.result['left_image_url'], + mask_url=self.result['left_mask_url'], + sacle=self.result['scale'], + clothes_keypoint=self.result['clothes_keypoint'], + position=self.calculate_start_point(self.result['keypoint'], + self.result['scale'], + self.result['clothes_keypoint'], + self.result['body_point'], + 'left')) + layer.insert(left_layer) + + right_layer = dict(name=f'{type(self).__name__.lower()}_right', + image=self.result['shoes_right'], + image_url=self.result['right_image_url'], + mask_url=self.result['right_mask_url'], + sacle=self.result['scale'], + clothes_keypoint=self.result['clothes_keypoint'], + position=self.calculate_start_point(self.result['keypoint'], + self.result['scale'], + self.result['clothes_keypoint'], + self.result['body_point'], + 'right')) + + layer.insert(right_layer) + + def cut(self): + """ + Cut shoes mask into two pieces + Returns: + """ + contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours = sorted(contour, key=cv2.contourArea, reverse=True) + + bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]] + (contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False)) + + epsilon_left = 0.001 * cv2.arcLength(contours[0], True) + + approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True) + mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8) + cv2.drawContours(mask_left, [approx_left], -1, 255, -1) + item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0) + + rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], 
self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left) + result_image = np.zeros_like(rgba_image) + result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0] + result_left_image_pil = Image.fromarray(result_image, 'RGBA') + result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS) + self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}") + + epsilon_right = 0.001 * cv2.arcLength(contours[1], True) + approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True) + mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8) + cv2.drawContours(mask_right, [approx_right], -1, 255, -1) + item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0) + + rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right) + result_image = np.zeros_like(rgba_image) + result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0] + result_right_image_pil = Image.fromarray(result_image, 'RGBA') + result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS) + self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}") + + return item_mask_left, item_mask_right + + @staticmethod + def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location): + """ + left shoes align left + right shoes align right + Args: + keypoint_type: string, "toe" + scale: float + clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]} + body_point: dict, containing 
keypoint data of body figure + location: string, indicates whether the start point belongs to right or left shoe + + Returns: + start_point: tuple (x', y') + x' = y_body - y1 * scale + y' = x_body - x1 * scale + """ + if location not in ['left', 'right']: + raise KeyError(f'location value must be left or right but got {location}') + side_indicator = f'{keypoint_type}_{location}' + # clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()} + start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale), + body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale)) + return start_point diff --git a/app/service/design/items/top.py b/app/service/design/items/top.py new file mode 100644 index 0000000..135328f --- /dev/null +++ b/app/service/design/items/top.py @@ -0,0 +1,46 @@ +from .builder import ITEMS +from .clothing import Clothing + + +@ITEMS.register_module() +class Top(Clothing): + def __init__(self, pipeline, **kwargs): + if pipeline is None: + pipeline = [ + dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']), + dict(type='KeypointDetection'), + dict(type='ContourDetection'), + dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']), + dict(type='Painting', painting_flag=True), + dict(type='PrintPainting', print_flag=True), + # dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']), + dict(type='Scaling'), + dict(type='Split'), + ] + kwargs.update(pipeline=pipeline) + super(Top, self).__init__(**kwargs) + + +@ITEMS.register_module() +class Blouse(Top): + def __init__(self, pipeline=None, **kwargs): + super(Blouse, self).__init__(pipeline, **kwargs) + + +@ITEMS.register_module() +class Outwear(Top): + def __init__(self, pipeline=None, **kwargs): + super(Outwear, self).__init__(pipeline, **kwargs) + + +@ITEMS.register_module() +class Dress(Top): 
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
import concurrent.futures


def process_item(item, layers):
    """Run one item's pipeline and insert its layers into *layers*.

    Returns:
        The mannequin body image size for the mannequin item, None otherwise.
    """
    item.process()
    item.organize(layers)
    if item.result['name'] == "mannequin":
        return item.result['body_image'].size


def update_progress(process_id, total):
    """Advance the Redis-backed progress counter for *process_id*.

    Each finished object adds int(100 / total) percent; with a single object
    the counter jumps straight to 100.

    Returns:
        The previous progress value (a string, or None when the key was
        missing).
    """
    r = Redis()
    progress = r.read(key=process_id)
    if total == 1:
        r.write(key=process_id, value=100)
    elif progress:
        # BUGFIX: cap the stored value at 100 — the original only compared
        # the *previous* value against the cap, so it could store > 100.
        r.write(key=process_id, value=min(100, int(progress) + int(100 / total)))
    else:
        # First update: no previous progress recorded yet.
        r.write(key=process_id, value=int(100 / total))
    return progress


def final_progress(process_id):
    """Force the progress counter to 100 and return the previous value."""
    r = Redis()
    progress = r.read(key=process_id)
    r.write(key=process_id, value=100)
    return progress


def generate(request_data):
    """Process every object of a design request concurrently.

    Args:
        request_data: request model exposing .dict() with 'objects' and
            'process_id'.

    Returns:
        Dict mapping each object's index to its processed layer response.
    """
    return_response = {}
    request_data = request_data.dict()
    # NOTE(review): assert is stripped under -O; consider raising ValueError.
    assert "process_id" in request_data.keys(), "Need process_id parameters"

    objects = request_data['objects']
    process_id = request_data['process_id']
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit one task per object; remember each future's object index.
        futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj
                   for obj, cfg in enumerate(objects)}
        # Collect results as they complete, keyed by object index.
        for future in concurrent.futures.as_completed(futures):
            obj = futures[future]
            result = future.result()
            return_response[obj] = result
    final_progress(process_id)
    return return_response
'image_category': f"{item.result['name']}_front", + 'image_size': item.result['back_image'].size if item.result['back_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item.result['front_image_url'], + 'mask_url': item.result['front_mask_url'], + "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "" + + }) + items_response['layers'].append({ + 'image_category': f"{item.result['name']}_back", + 'image_size': item.result['front_image'].size if item.result['front_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item.result['back_image_url'], + 'mask_url': item.result['back_mask_url'], + "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "" + + }) + items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image']) + break + update_progress(process_id, total) + return items_response diff --git a/app/service/design/utils/__init__.py b/app/service/design/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/service/design/utils/conversion_image.py b/app/service/design/utils/conversion_image.py new file mode 100644 index 0000000..0915070 --- /dev/null +++ b/app/service/design/utils/conversion_image.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :conversion_image.py +@Author :周成融 +@Date :2023/8/21 10:40:29 +@detail : +""" +import numpy as np + + +def rgb_to_rgba(rgb_size, rgb_image, mask): + alpha_channel = np.full(rgb_size, 255, dtype=np.uint8) + # 创建四通道的结果图像 + rgba_image = np.dstack((rgb_image, alpha_channel)) + alpha_channel = np.where(mask > 0, 255, 0) + # 更新RGBA图像的透明度通道 + rgba_image[:, :, 3] = alpha_channel + return rgba_image + +if __name__ == '__main__': + image = open("") \ No newline at end of file diff --git a/app/service/design/utils/design_ensemble.py b/app/service/design/utils/design_ensemble.py new file 
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :design_ensemble.py
@Author  :周成融
@Date    :2023/8/16 19:36:21
@detail  :Issue inference requests against the Triton server and post-process
          keypoint / segmentation model outputs.
"""
import logging
import cv2
import mmcv
import numpy as np
import tritonclient.http as httpclient
import torch
import torch.nn.functional as F
from app.core.config import *

"""
    keypoint
    预处理 推理 后处理
"""


def keypoint_preprocess(img_path):
    # Resize to the fixed 256x256 model input, normalise with ImageNet-style
    # mean/std, and convert HWC -> NCHW with a leading batch axis.
    img = mmcv.imread(img_path)
    img_scale = (256, 256)
    img, w_scale, h_scale = mmcv.imresize(img, img_scale, return_scale=True)
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, (w_scale, h_scale)


# @ RunTime
# Inference: run the keypoint model for the given body site ("up"/"down"...).
def get_keypoint_result(image, site):
    # Returns the post-processed keypoint coordinates, or None on any
    # failure (the error is logged, not raised).
    keypoint_result = None
    try:
        image, scale_factor = keypoint_preprocess(image)
        client = httpclient.InferenceServerClient(url=KEYPOINT_MODEL_URL)
        transformed_img = image.astype(np.float32)
        inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)]
        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
        inference_output = torch.from_numpy(results.as_numpy(f'output'))
        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
    except Exception as e:
        logging.warning(f"get_keypoint_result : {e}")
    return keypoint_result


def keypoint_postprocess(output, scale_factor):
    # Argmax over the flattened heatmaps -> (row, col) per keypoint channel.
    max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
    max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
    segment_result = max_coords.numpy()
    # Undo the preprocessing resize; a zero scale would give inf -> set to 0.
    scale_factor = [1 / x for x in scale_factor[::-1]]
    scale_matrix = np.diag(scale_factor)
    nan = np.isinf(scale_matrix)
    scale_matrix[nan] = 0
    # * 4: heatmap appears to be 4x smaller than the model input —
    # NOTE(review): confirm against the model configuration.
    return np.ceil(np.dot(segment_result, scale_matrix) * 4)


"""
    seg
    预处理 推理 后处理
"""


# KNet
def seg_preprocess(img_path):
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    # NOTE(review): ori_shape is (h, w) but is unpacked as (w, h) here and
    # then passed to mmcv.imresize, which expects (w, h) — for non-square
    # images the resize dimensions are transposed. seg_postprocess resizes
    # back to ori_shape, so this may be tolerated; confirm intent.
    img_scale_w, img_scale_h = ori_shape
    if ori_shape[0] > 1024:
        img_scale_w = 1024
    if ori_shape[1] > 1024:
        img_scale_h = 1024
    scale_factor = []
    img, x, y = mmcv.imresize(img, (img_scale_w, img_scale_h), return_scale=True)
    scale_factor.append(x)
    scale_factor.append(y)
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, ori_shape


# @ RunTime
def get_seg_result(image_id, image):
    # Run the segmentation model on Triton and return the per-pixel class
    # logits resized back to the original image shape.
    image, ori_shape = seg_preprocess(image)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    transformed_img = image.astype(np.float32)
    # Input set
    inputs = [
        httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
    ]
    inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
    # Output set
    outputs = [
        httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
    ]
    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
    # Fetch the inference result
    inference_output1 = results.as_numpy(SEGMENTATION['output'])
    seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
    return seg_result


# no cache
def seg_postprocess(image_id, output, ori_shape):
    # Bilinearly resize the logits back to the original image size and drop
    # the batch axis. `image_id` is unused here.
    seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
    seg_pred = seg_logit.cpu().numpy()
    return seg_pred[0]
import redis

from app.core.config import REDIS_HOST, REDIS_PORT


class Redis(object):
    """
    Thin wrapper around redis-py with string-decoding helpers.

    write()/expire() apply a default TTL of 100 seconds when none is given.
    """

    @staticmethod
    def _get_r():
        """Open a new StrictRedis connection against db 0."""
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Set *key* to *value* with a TTL (defaults to 100 seconds)."""
        cls._get_r().set(key, value, ex=expire if expire else 100)

    @classmethod
    def read(cls, key):
        """Return the value of *key* decoded to str, or None when missing."""
        raw = cls._get_r().get(key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hset(cls, name, key, value):
        """Set one field of the hash *name*."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return one field of the hash *name*, decoded to str when present."""
        raw = cls._get_r().hget(name, key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hgetall(cls, name):
        """Return the whole hash *name* as a dict of raw bytes."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete one field of the hash *name*."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """(Re)set the TTL of *name* (defaults to 100 seconds)."""
        cls._get_r().expire(name, expire if expire else 100)


if __name__ == '__main__':
    redis_client = Redis()
    redis_client.write(key="1230", value=10)
all_end = 0 + else: + all_start = 0 + if mask_shape - abs(offset) > all_mask_shape: + all_end = min(mask_shape - abs(offset), all_mask_shape) + else: + all_end = mask_shape - abs(offset) + + if abs(offset) > mask_shape: + mask_start = mask_shape + mask_end = mask_shape + else: + mask_start = abs(offset) + if mask_shape - abs(offset) >= all_mask_shape: + mask_end = all_mask_shape + abs(offset) + else: + mask_end = mask_shape + return all_start, all_end, mask_start, mask_end + + +@RunTime +def synthesis(data, size): + # 创建底图 + base_image = Image.new('RGBA', size, (0, 0, 0, 0)) + try: + + all_mask_shape = (size[1], size[0]) + top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8) + bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8) + + top = True + bottom = True + i = len(data) + while i: + i -= 1 + if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]: + top = False + mask_shape = data[i]['mask'].shape + y_offset, x_offset = data[i]['position'] + # 初始化叠加区域的起始和结束位置 + all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset) + all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset) + # 将叠加区域赋值为相应的像素值 + top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end] + elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]: + bottom = False + mask_shape = data[i]['mask'].shape + y_offset, x_offset = data[i]['position'] + # 初始化叠加区域的起始和结束位置 + all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset) + all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset) + # 将叠加区域赋值为相应的像素值 + 
bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end] + elif bottom is False and top is False: + break + + all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask) + + for layer in data: + if layer['image'] is not None: + if layer['name'] != "body": + test_image = Image.new('RGBA', size, (0, 0, 0, 0)) + test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image']) + mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8) + mask_alpha = Image.fromarray(mask_data) + cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha) + base_image.paste(cropped_image, (0, 0), cropped_image) + else: + base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image']) + + result_image = base_image + + with io.BytesIO() as output: + result_image.save(output, format='PNG') + data = output.getvalue() + + # image_data = io.BytesIO() + # result_image.save(image_data, format='PNG') + # image_data.seek(0) + # image_bytes = image_data.read() + # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + + object_name = f'result_{generate_uuid()}.png' + response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png') + object_url = f"aida-results/{object_name}" + if response['ResponseMetadata']['HTTPStatusCode'] == 200: + return object_url + else: + return "" + + except Exception as e: + logging.warning(f"synthesis runtime exception : {e}") + + +def synthesis_single(front_image, back_image): + result_image = None + if front_image: + result_image = front_image + if back_image: + result_image.paste(back_image, (0, 0), back_image) + + with io.BytesIO() as output: + result_image.save(output, format='PNG') + data = output.getvalue() + + # 
image_data = io.BytesIO() + # result_image.save(image_data, format='PNG') + # image_data.seek(0) + # image_bytes = image_data.read() + # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + + object_name = f'result_{generate_uuid()}.png' + response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png') + object_url = f"aida-results/{object_name}" + if response['ResponseMetadata']['HTTPStatusCode'] == 200: + return object_url + else: + return "" diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py new file mode 100644 index 0000000..f945b02 --- /dev/null +++ b/app/service/design/utils/upload_image.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :upload_image.py +@Author :周成融 +@Date :2023/8/28 13:49:20 +@detail : +""" +import io +import logging +import time + +import boto3 +import cv2 +from minio import Minio + +from app.core.config import * +from app.service.utils.decorator import RunTime + +minio_client = Minio( + f"{MINIO_URL}", + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + +"""S3 上传""" +s3 = boto3.client( + 's3', + aws_access_key_id="AKIAVD3OJIMF6UJFLSHZ", + aws_secret_access_key="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", + region_name="ap-east-1" +) + + +@RunTime +def upload_png_mask(front_image, object_name, mask=None): + start_time = time.time() + mask_url = None + if mask is not None: + # 反转掩模 + mask_inverted = cv2.bitwise_not(mask) + # 将掩模转换为 RGBA 格式 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + # 将图像数据保存到内存中的 BytesIO 对象中 + image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + image_bytes.seek(0) + try: + key = f"mask/mask_{object_name}.png" + 
mask_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') + except Exception as e: + print(f'上传到 S3 失败: {e}') + with io.BytesIO() as output: + front_image.save(output, format='PNG') + data = output.getvalue() + # 创建一个 S3 客户端 + try: + key = f"image/image_{object_name}.png" + image_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') + return front_image, image_url, mask_url + except Exception as e: + print(f'上传到 S3 失败: {e}') + + +@RunTime +def upload_layer_image(image, object_name): + with io.BytesIO() as output: + image.save(output, format='PNG') + data = output.getvalue() + # 创建一个 S3 客户端 + try: + key = f"image/image_{object_name}.png" + image_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') + return image_url + except Exception as e: + print(f'上传到 S3 失败: {e}') + + +@RunTime +def upload_mask_image(mask, object_name): + # 反转掩模 + mask_inverted = cv2.bitwise_not(mask) + # 将掩模转换为 RGBA 格式 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + # 将图像数据保存到内存中的 BytesIO 对象中 + image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + image_bytes.seek(0) + try: + key = f"mask/mask_{object_name}.png" + mask_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') + return mask_url + except Exception as e: + print(f'上传到 S3 失败: {e}') + + +"""minio 上传""" + +# @RunTime +# def upload_png_mask(front_image, object_name, mask=None): +# start_time = time.time() +# try: +# mask_url = None +# if mask is not None: +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# image_bytes = io.BytesIO() +# 
image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# +# image_bytes.seek(0) +# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" +# +# image_data = io.BytesIO() +# front_image.save(image_data, format='PNG') +# image_data.seek(0) +# image_bytes = image_data.read() +# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" +# # print(f"upload_png_mask {object_name} = {time.time() - start_time}") +# return front_image, image_url, mask_url +# except Exception as e: +# logging.warning(f"upload_png_mask runtime exception : {e}") +# +# +# @RunTime +# def upload_layer_image(image, object_name): +# try: +# image_data = io.BytesIO() +# image.save(image_data, format='PNG') +# image_data.seek(0) +# image_bytes = image_data.read() +# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" +# return image_url +# except Exception as e: +# logging.warning(f"upload_png_mask runtime exception : {e}") +# +# +# @RunTime +# def upload_mask_image(mask, object_name): +# try: +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# image_bytes = io.BytesIO() +# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# +# image_bytes.seek(0) +# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" +# return mask_url +# except Exception as e: +# logging.warning(f"upload_png_mask runtime exception : {e}") 
From 13fec64125b07eb1e53dab6420a02339041eca9e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 29 May 2024 11:12:59 +0800 Subject: [PATCH 002/108] =?UTF-8?q?feat=20chat=20robot=20=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_chat_robot.py | 27 +++ app/api/api_prompt_generation.py | 28 +++ app/api/api_route.py | 5 + app/core/config.py | 28 ++- app/schemas/chat_robot.py | 8 + app/schemas/prompt_generation.py | 5 + .../chat_robot/script/agents/__init__.py | 7 + .../script/agents/agent_executor.py | 132 ++++++++++++ .../agents/conversational_functions_agent.py | 198 ++++++++++++++++++ .../chat_robot/script/callbacks/__init__.py | 6 + .../callbacks/openai_token_record_callback.py | 46 ++++ app/service/chat_robot/script/database.py | 79 +++++++ app/service/chat_robot/script/main.py | 114 ++++++++++ .../chat_robot/script/memory/__init__.py | 3 + .../script/memory/user_buffer_window.py | 93 ++++++++ app/service/chat_robot/script/prompt.py | 52 +++++ .../chat_robot/script/tools/__init__.py | 10 + .../chat_robot/script/tools/sql_tools.py | 183 ++++++++++++++++ .../chat_robot/script/tools/tutorial_tool.py | 19 ++ .../chat_robot/script/utils/__init__.py | 1 + app/service/chat_robot/script/utils/logger.py | 26 +++ .../chatgpt_for_translation.py | 70 +++++++ requirements.txt | Bin 814 -> 1160 bytes 23 files changed, 1139 insertions(+), 1 deletion(-) create mode 100644 app/api/api_chat_robot.py create mode 100644 app/api/api_prompt_generation.py create mode 100644 app/schemas/chat_robot.py create mode 100644 app/schemas/prompt_generation.py create mode 100644 app/service/chat_robot/script/agents/__init__.py create mode 100644 app/service/chat_robot/script/agents/agent_executor.py create mode 100644 app/service/chat_robot/script/agents/conversational_functions_agent.py create mode 100644 app/service/chat_robot/script/callbacks/__init__.py create mode 100644 
app/service/chat_robot/script/callbacks/openai_token_record_callback.py create mode 100644 app/service/chat_robot/script/database.py create mode 100644 app/service/chat_robot/script/main.py create mode 100644 app/service/chat_robot/script/memory/__init__.py create mode 100644 app/service/chat_robot/script/memory/user_buffer_window.py create mode 100644 app/service/chat_robot/script/prompt.py create mode 100644 app/service/chat_robot/script/tools/__init__.py create mode 100644 app/service/chat_robot/script/tools/sql_tools.py create mode 100644 app/service/chat_robot/script/tools/tutorial_tool.py create mode 100644 app/service/chat_robot/script/utils/__init__.py create mode 100644 app/service/chat_robot/script/utils/logger.py create mode 100644 app/service/prompt_generation/chatgpt_for_translation.py diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py new file mode 100644 index 0000000..c394046 --- /dev/null +++ b/app/api/api_chat_robot.py @@ -0,0 +1,27 @@ +import logging +import time +from fastapi import APIRouter + +from app.schemas.chat_robot import ChatRobotModel +from app.service.chat_robot.script.main import chat + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/chat_robot") +def chat_robot(request_data: ChatRobotModel): + try: + logger.info(f"chat_robot request item is : @@@@@@:{request_data}") + code = 200 + message = "access" + start_time = time.time() + data = chat(post_data=request_data) + logger.info(f"chat_robot Run time is @@@@@@:{time.time() - start_time}") + except Exception as e: + code = 400 + message = str(e) + data = str(e) + logger.warning(f"chat_robot Run Exception @@@@@@:{e}") + logger.info({"code": code, "message": message, "data": data}) + return {"code": code, "message": message, "data": data} diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py new file mode 100644 index 0000000..5e71eec --- /dev/null +++ b/app/api/api_prompt_generation.py @@ -0,0 +1,28 @@ +import logging 
+import time + +from fastapi import APIRouter + +from app.schemas.prompt_generation import PromptGenerationImageModel +from app.service.prompt_generation.chatgpt_for_translation import translate_to_en + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/translateToEN") +def prompt_generation(request_data: PromptGenerationImageModel): + try: + logger.info(f"prompt_translate to English request data : @@@@@@:{request_data}") + code = 200 + message = "access" + start_time = time.time() + data = translate_to_en(request_data.text) + logger.info(f"prompt_generation Run time is @@@@@@:{time.time() - start_time}") + except Exception as e: + code = 400 + message = str(e) + data = str(e) + logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") + logger.info({"code": code, "message": message, "data": data}) + return {"code": code, "message": message, "data": data} diff --git a/app/api/api_route.py b/app/api/api_route.py index ff21b34..c1add93 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -5,6 +5,9 @@ from app.api import api_super_resolution from app.api import api_generate_image from app.api import api_attribute_retrieve from app.api import api_design +from app.api import api_chat_robot +from app.api import api_prompt_generation + router = APIRouter() @@ -13,3 +16,5 @@ router.include_router(api_super_resolution.router, tags=["super_resolution"], pr router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api") router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api") router.include_router(api_design.router, tags=['design'], prefix="/api") +router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api") +router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api") diff --git a/app/core/config.py b/app/core/config.py index 6e22adc..5744dec 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class 
Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -61,6 +61,32 @@ MILVUS_PORT = "19530" MILVUS_TABLE_KEYPOINT = "keypoint_cache" MILVUS_TABLE_SEG = "seg_cache" +# Mysql 配置 +DB_HOST = '18.167.251.121' # 数据库主机地址 +# DB_PORT = int( 33006) +DB_PORT = 33008 # 数据库端口 +DB_USERNAME = 'aida_con_python' # 数据库用户名 +DB_PASSWORD = '123456' # 数据库密码 +DB_NAME = 'aida' # 数据库库名 + +# openai +os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c" +OPENAI_STREAM = True +BUFFER_THRESHOLD = 6 # must be even number +SINGLE_TOKEN_THRESHOLD = 200 +TOKEN_THRESHOLD = 600 +OPENAI_TEMPERATURE = 0 + +# OPENAI_API_KEY = "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" +OPENAI_API_KEY = "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" +OPENAI_MODEL = "gpt-3.5-turbo-0613" +OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", } + # attribute service config ATT_TRITON_URL = "10.1.1.240:10000" diff --git a/app/schemas/chat_robot.py b/app/schemas/chat_robot.py new file mode 100644 index 0000000..cebf74a --- /dev/null +++ b/app/schemas/chat_robot.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class ChatRobotModel(BaseModel): + gender: str + message: str + session_id: str + user_id: int diff --git a/app/schemas/prompt_generation.py b/app/schemas/prompt_generation.py new file mode 100644 index 0000000..195291b --- /dev/null +++ b/app/schemas/prompt_generation.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class PromptGenerationImageModel(BaseModel): + text: str diff --git a/app/service/chat_robot/script/agents/__init__.py b/app/service/chat_robot/script/agents/__init__.py new file mode 100644 index 0000000..30c40f9 --- /dev/null +++ 
b/app/service/chat_robot/script/agents/__init__.py @@ -0,0 +1,7 @@ +from .agent_executor import CustomAgentExecutor +from .conversational_functions_agent import ConversationalFunctionsAgent + +__all__ = [ + "CustomAgentExecutor", + "ConversationalFunctionsAgent" +] diff --git a/app/service/chat_robot/script/agents/agent_executor.py b/app/service/chat_robot/script/agents/agent_executor.py new file mode 100644 index 0000000..cc69936 --- /dev/null +++ b/app/service/chat_robot/script/agents/agent_executor.py @@ -0,0 +1,132 @@ +import inspect +import json +import logging +from typing import Any, Dict, List, Optional, Union, Tuple + +from langchain.agents import AgentExecutor +from langchain.callbacks.manager import Callbacks, CallbackManager +from langchain.load.dump import dumpd +from langchain.schema import RUN_KEY, RunInfo +from langchain_core.agents import AgentAction, AgentFinish + + +class CustomAgentExecutor(AgentExecutor): + def __call__( + self, + inputs: Union[Dict[str, Any], Any], + return_only_outputs: bool = False, + callbacks: Callbacks = None, + session_key: str = "", + *, + tags: Optional[List[str]] = None, + include_run_info: bool = False, + ) -> Dict[str, Any]: + """Run the logic of this chain and add to output if desired. + + Args: + inputs: Dictionary of inputs, or single input if chain expects + only one param. + return_only_outputs: boolean for whether to return only outputs in the + response. If True, only new keys generated by this chain will be + returned. If False, both input keys and new keys generated by this + chain will be returned. Defaults to False. + callbacks: Callbacks to use for this chain run. If not provided, will + use the callbacks provided to the chain. + include_run_info: Whether to include run info in the response. Defaults + to False. 
+ """ + inputs = self.prep_inputs(inputs, session_key) + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose, tags, self.tags + ) + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") + run_manager = callback_manager.on_chain_start( + dumpd(self), + inputs, + ) + try: + outputs = ( + self._call(inputs, run_manager=run_manager) + if new_arg_supported + else self._call(inputs) + ) + except (KeyboardInterrupt, Exception) as e: + logging.exception(e) + run_manager.on_chain_error(e) + raise e + run_manager.on_chain_end(outputs) + final_outputs: Dict[str, Any] = self.prep_outputs( + inputs, outputs, return_only_outputs, session_key + ) + if include_run_info: + final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) + return final_outputs + + def prep_outputs( + self, + inputs: Dict[str, str], + outputs: Dict[str, str], + return_only_outputs: bool = False, + session_key: str = "" + ) -> Dict[str, str]: + """Validate and prep outputs.""" + self._validate_outputs(outputs) + if self.memory is not None and outputs['need_record']: + self.memory.save_context(inputs, outputs, session_key) + if return_only_outputs: + return outputs + else: + return {**inputs, **outputs} + + def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]: + """Validate and prep inputs.""" + if not isinstance(inputs, dict): + _input_keys = set(self.input_keys) + if self.memory is not None: + # If there are multiple input keys, but some get set by memory so that + # only one is not set, we can still figure out which key it is. + _input_keys = _input_keys.difference(self.memory.memory_variables) + if len(_input_keys) != 1: + raise ValueError( + f"A single string input was passed in, but this chain expects " + f"multiple inputs ({_input_keys}). 
When a chain expects " + f"multiple inputs, please call it by passing in a dictionary, " + "eg `chain({'foo': 1, 'bar': 2})`" + ) + inputs = {list(_input_keys)[0]: inputs} + if self.memory is not None: + external_context = self.memory.load_memory_variables(inputs, session_key) + inputs = dict(inputs, **external_context) + self._validate_inputs(inputs) + return inputs + + def _get_tool_return( + self, next_step_output: Tuple[AgentAction, str] + ) -> Optional[AgentFinish]: + """Check if the tool is a returning tool.""" + agent_action, observation = next_step_output + name_to_tool_map = {tool.name: tool for tool in self.tools} + return_value_key = "output" + + if len(self.agent.return_values) > 0: + return_value_key = self.agent.return_values[0] + + try: + observation_list = json.loads(observation) + if agent_action.tool == "sql_db_query" and isinstance(observation_list, + list) and observation_list.__len__() != 0: + return AgentFinish( + {return_value_key: observation}, + "", + ) + except: + pass + + # Invalid tools won't be in the map, so we return False. 
+ if agent_action.tool in name_to_tool_map: + if name_to_tool_map[agent_action.tool].return_direct: + return AgentFinish( + {return_value_key: observation}, + "", + ) + return None diff --git a/app/service/chat_robot/script/agents/conversational_functions_agent.py b/app/service/chat_robot/script/agents/conversational_functions_agent.py new file mode 100644 index 0000000..eb362a7 --- /dev/null +++ b/app/service/chat_robot/script/agents/conversational_functions_agent.py @@ -0,0 +1,198 @@ +import json +import re +from json import JSONDecodeError +from typing import List, Tuple, Any, Union +from dataclasses import dataclass + +from langchain.callbacks.manager import Callbacks +from langchain.agents import ( + OpenAIFunctionsAgent, +) +from langchain.schema import ( + AgentAction, + AgentFinish, + BaseMessage, + OutputParserException +) +from langchain.schema.messages import ( + AIMessage, + FunctionMessage +) +from langchain.tools import BaseTool, StructuredTool +# from langchain.tools.convert_to_openai import FunctionDescription +from langchain.utils.openai_functions import FunctionDescription + + +@dataclass +class _FunctionsAgentAction(AgentAction): + """Add message_log to AgentAction class for the _FunctionAgentAction + """ + message_log: List[BaseMessage] + + def __init__( + self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any + ): + """Override init to support instantiation by position for backward compat.""" + super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) + + +def _convert_agent_action_to_messages( + agent_action: AgentAction, observation: str +) -> List[BaseMessage]: + """Convert an agents action to a message. + + This code is used to reconstruct the original AI message from the agents action. + + Args: + agent_action: Agent action to convert. + + Returns: + AIMessage that corresponds to the original tools invocation. 
+ """ + if isinstance(agent_action, _FunctionsAgentAction): + return agent_action.message_log + [ + _create_function_message(agent_action, observation) + ] + else: + return [AIMessage(content=agent_action.log)] + + +def _create_function_message( + agent_action: AgentAction, observation: str +) -> FunctionMessage: + """Convert agents action and observation into a function message. + Args: + agent_action: the tools invocation request from the agents + observation: the result of the tools invocation + Returns: + FunctionMessage that corresponds to the original tools invocation + """ + if not isinstance(observation, str): + try: + content = json.dumps(observation, ensure_ascii=False) + except Exception: + content = str(observation) + else: + content = observation + return FunctionMessage( + name=agent_action.tool, + content=content, + ) + + +def _format_intermediate_steps( + intermediate_steps: List[Tuple[AgentAction, str]], +) -> List[BaseMessage]: + """Format intermediate steps. + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + Returns: + list of messages to send to the LLM for the next prediction + """ + messages = [] + + for intermediate_step in intermediate_steps: + agent_action, observation = intermediate_step + messages.extend(_convert_agent_action_to_messages(agent_action, observation)) + + return messages + + +def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: + """Format tools into the OpenAI function API.""" + parameters = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": tool.param_description if hasattr(tool, 'param_description') else "", + }, + }, + "required": ["query"], + } + + return { + "name": tool.name, + "description": tool.description, + "parameters": parameters, + } + + +def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]: + if not isinstance(message, AIMessage): + raise TypeError(f"Expected an AI message but got 
{type(message)}") + + function_call = message.additional_kwargs.get("function_call", {}) + + if function_call: + function_call = message.additional_kwargs["function_call"] + function_name = function_call["name"] + try: + _tool_input = json.loads(function_call["arguments"]) + except JSONDecodeError: + raise OutputParserException( + f"Could not parse tools input: {function_call} because" + f"the `arguments` is not valid JSON." + ) + + if "query" in _tool_input: + tool_input = _tool_input["query"] + else: + tool_input = _tool_input + + return _FunctionsAgentAction( + tool=function_name, + tool_input=tool_input, + log=f"\nInvoking: `{function_name}` with `{tool_input}`\n", + message_log=[message] + ) + + # pattern = r'\((.*?)\)' + # matches = re.findall(pattern, message.content) + # result = [] + # + # for match in matches: + # result.append(match) + # + # if result: + # output = result + # else: + # output = message.content + + return AgentFinish(return_values={"output": message.content}, log=message.content) + + +class ConversationalFunctionsAgent(OpenAIFunctionsAgent): + @property + def functions(self) -> List[dict]: + return [dict(_format_tool_to_openai_function(t)) for t in self.tools] + + def plan(self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any + ) -> Union[AgentAction, AgentFinish]: + """Decide how agents should move after receiving an input. The difference between + OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser + into it. + + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + **kwargs: User inputs. + **kwargs: Including user's input string + + Returns: + Action specifying what tools to use. 
+ """ + agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps) + selected_inputs = { + k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" + } + full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) + prompt = self.prompt.format_prompt(**full_inputs) + messages: List[BaseMessage] = prompt.to_messages() + predicted_message = self.llm.predict_messages( + messages, functions=self.functions, callbacks=callbacks + ) + agent_decision = _parse_ai_message(predicted_message) + return agent_decision diff --git a/app/service/chat_robot/script/callbacks/__init__.py b/app/service/chat_robot/script/callbacks/__init__.py new file mode 100644 index 0000000..8f644bd --- /dev/null +++ b/app/service/chat_robot/script/callbacks/__init__.py @@ -0,0 +1,6 @@ +from .openai_token_record_callback import OpenAITokenRecordCallbackHandler + + +__all__ = [ + 'OpenAITokenRecordCallbackHandler' +] diff --git a/app/service/chat_robot/script/callbacks/openai_token_record_callback.py b/app/service/chat_robot/script/callbacks/openai_token_record_callback.py new file mode 100644 index 0000000..64ed7f4 --- /dev/null +++ b/app/service/chat_robot/script/callbacks/openai_token_record_callback.py @@ -0,0 +1,46 @@ +"""Callback Handler that add on_chain_end function to record Token usage.""" +from typing import Any, Dict + +from langchain.callbacks import OpenAICallbackHandler +from langchain.schema import LLMResult +from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model + + +class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler): + need_record: bool = True + response_type: str = "string" + """Callback Handler that tracks OpenAI info and write to redis after agent finish""" + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Collect token usage.""" + if response.llm_output is None: + return None + self.successful_requests += 1 + if 
"token_usage" not in response.llm_output: + return None + if "function_call" in response.generations[0][0].message.additional_kwargs: + if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]: + self.need_record = False + if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query": + self.response_type = "image" + token_usage = response.llm_output["token_usage"] + completion_tokens = token_usage.get("completion_tokens", 0) + prompt_tokens = token_usage.get("prompt_tokens", 0) + model_name = standardize_model_name(response.llm_output.get("model_name", "")) + if model_name in MODEL_COST_PER_1K_TOKENS: + completion_cost = get_openai_token_cost_for_model( + model_name, completion_tokens, is_completion=True + ) + prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens) + self.total_cost += prompt_cost + completion_cost + self.total_tokens += token_usage.get("total_tokens", 0) + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + + def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None: + """Write token usage to redis.""" + outputs["total_tokens"] = self.total_tokens + outputs["total_cost"] = self.total_cost + outputs["prompt_tokens"] = self.prompt_tokens + outputs["completion_tokens"] = self.completion_tokens + outputs["need_record"] = self.need_record + outputs["response_type"] = self.response_type diff --git a/app/service/chat_robot/script/database.py b/app/service/chat_robot/script/database.py new file mode 100644 index 0000000..8a5dfdb --- /dev/null +++ b/app/service/chat_robot/script/database.py @@ -0,0 +1,79 @@ +from typing import Optional, List +import json + +from sqlalchemy import text +# from langchain import SQLDatabase +from langchain.utilities import SQLDatabase + + +class CustomDatabase(SQLDatabase): + def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: + # 
def get_table_info(self, table_names: Optional[List[str]] = None) -> str: + connection = self._engine.connect() + all_table_names = self.get_usable_table_names() + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + # raise ValueError(f"table_names {missing_tables} not found in database") + return f"Table {','.join(missing_tables)} can not be found in the database" + all_table_names = table_names + meta_tables = [ + tbl + for tbl in self._metadata.sorted_tables + if tbl.name in set(all_table_names) + ] + + tables = [] + for table in meta_tables: + table_name = table.name + column_names = table.columns.keys() + table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n" + for column_name in column_names: + if column_name not in ["ID", "img_name"]: + query = text(f"SELECT DISTINCT {column_name} FROM {table_name}") + result = connection.execute(query) + enum_values: List[str] = [row[0] for row in result.fetchall()] + column_info = f"{column_name}: {', '.join(enum_values)}\n" + table_info += column_info + + # table_info = f"Table: {table_name}\n" + # + # if self._sample_rows_in_table_info: + # table_info += f"{self._get_sample_rows(table)}\n" + tables.append(table_info) + final_str = "\n\n".join(tables) + return final_str + + def run(self, command: str, fetch: str = "all") -> str: + """Execute a SQL command and return a string representing the results. + + If the statement returns rows, a string of the results is returned. + If the statement returns no rows, an empty string is returned. 
+ + """ + with self._engine.begin() as connection: + if self._schema is not None: + if self.dialect == "snowflake": + connection.exec_driver_sql( + f"ALTER SESSION SET search_path='{self._schema}'" + ) + elif self.dialect == "bigquery": + connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'") + else: + connection.exec_driver_sql(f"SET search_path TO {self._schema}") + cursor = connection.execute(text(command)) + if cursor.rowcount: + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone() # type: ignore + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") + + # Convert columns values to string to avoid issues with sqlalchmey + # trunacating text + if isinstance(result, list): + return json.dumps([r[0] for r in result]) + + return json.dumps([result[0]]) + return "" diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py new file mode 100644 index 0000000..2a62664 --- /dev/null +++ b/app/service/chat_robot/script/main.py @@ -0,0 +1,114 @@ +import logging +from loguru import logger +from langchain.agents import Tool +from langchain.utilities import SerpAPIWrapper +from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder +from langchain.schema import SystemMessage, AIMessage +from langchain.chat_models import ChatOpenAI +from langchain.llms.openai import OpenAI +from langchain.callbacks import FileCallbackHandler +from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent +from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler +from app.service.chat_robot.script.database import CustomDatabase +from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX +from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool) +from 
app.service.chat_robot.script.memory import UserConversationBufferWindowMemory +from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool +from app.core.config import * + +import os + +# os.environ["http_proxy"] = "http://127.0.0.1:7890" +# os.environ["https_proxy"] = "http://127.0.0.1:7890" +# log callbacks +logfile = "logs/chat_debug.log" +logger.add(logfile, colorize=True, enqueue=True) +log_handler = FileCallbackHandler(logfile) + +# Initiate our LLM 'gpt-3.5-turbo' +llm = ChatOpenAI(temperature=0.1, + openai_api_key=OPENAI_API_KEY, + # callbacks=[OpenAICallbackHandler()] + ) + +search = SerpAPIWrapper() +db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3', + include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress', + 'female_outwear', 'male_bottom', 'male_top', 'male_outwear'], + engine_args={"pool_recycle": 7200}) +tools = [ + Tool( + name="internet_search", + description="Can be used to perform Internet searches", + func=search.run + ), + QuerySQLDataBaseTool(db=db, return_direct=False), + InfoSQLDatabaseTool(db=db), + ListSQLDatabaseTool(db=db), + QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)), + Tool( + name="tutorial_tool", + description="Utilize this tool to retrieve specific statements related to user guidance tutorials." + "Input is an empty string", + func=CustomTutorialTool(), + return_direct=True + ) +] + +messages = [ + SystemMessage(content=FASHION_CHAT_BOT_PREFIX), + MessagesPlaceholder(variable_name="history"), + HumanMessagePromptTemplate.from_template( + "{input} " + "Question from a {gender}." 
+ ), + AIMessage(content=TOOLS_FUNCTIONS_SUFFIX), + MessagesPlaceholder(variable_name="agent_scratchpad"), +] + +prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages) +agent = ConversationalFunctionsAgent( + llm=llm, + tools=tools, + prompt=prompt +) + +memory = UserConversationBufferWindowMemory.from_redis( + return_messages=True, k=2, input_key='input', output_key='output' +) +agent_executor = CustomAgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + verbose=True, + memory=memory, +) + + +def chat(post_data): + user_id = post_data.user_id + session_id = post_data.session_id + input_message = post_data.message + gender = post_data.gender + + final_outputs = agent_executor( + {"input": input_message, "gender": gender}, + callbacks=[OpenAITokenRecordCallbackHandler(), log_handler], + session_key=f"buffer:{user_id}:{session_id}", + ) + api_response = { + 'user_id': user_id, + 'session_id': session_id, + # 'message_id': message_id, + # 'create_time': created_time, + 'input': final_outputs['input'], + # 'conversion': messages, + 'output': final_outputs['output'], + # 'gpt_response_time': gpt_response_time, + 'total_tokens': final_outputs['total_tokens'], + 'total_cost': final_outputs['total_cost'], + 'prompt_tokens': final_outputs['prompt_tokens'], + 'completion_tokens': final_outputs['completion_tokens'], + 'response_type': final_outputs['response_type'] + } + logging.info(api_response) + return api_response diff --git a/app/service/chat_robot/script/memory/__init__.py b/app/service/chat_robot/script/memory/__init__.py new file mode 100644 index 0000000..9586157 --- /dev/null +++ b/app/service/chat_robot/script/memory/__init__.py @@ -0,0 +1,3 @@ +from .user_buffer_window import UserConversationBufferWindowMemory + +__all__ = ['UserConversationBufferWindowMemory'] diff --git a/app/service/chat_robot/script/memory/user_buffer_window.py 
b/app/service/chat_robot/script/memory/user_buffer_window.py new file mode 100644 index 0000000..9fbc2d6 --- /dev/null +++ b/app/service/chat_robot/script/memory/user_buffer_window.py @@ -0,0 +1,93 @@ +import logging +from typing import Any, Dict, List, Tuple +import json + +import redis +from redis import Redis +from langchain.memory import RedisChatMessageHistory +from langchain.memory.chat_memory import BaseChatMemory +from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage +from langchain.schema.messages import _message_to_dict, messages_from_dict +from langchain.memory.utils import get_prompt_input_key + +from app.core.config import * + + +class UserConversationBufferWindowMemory(BaseChatMemory): + """Buffer for storing conversation memory.""" + + redis_client: Redis + human_prefix: str = "Human" + ai_prefix: str = "AI" + memory_key: str = "history" #: :meta private: + k: int = 5 + + @classmethod + def from_redis( + cls, + host: str = REDIS_HOST, + port: int = REDIS_PORT, + db: int = 3, + **kwargs + ): + redis_client = Redis(host=host, port=port, db=db) + try: + response = redis_client.ping() + if response: + print("Connect to redis server successfully.") + logging.info("Connect to redis server successfully.") + else: + print("Fail to connect to redis server") + logging.info("Fail to connect to redis server") + except redis.RedisError as e: + logging.info(f"Error occurs when connecting to redis server: {str(e)}") + return cls(redis_client=redis_client, **kwargs) + + @property + def memory_variables(self) -> List[str]: + """Will always return list of memory variables. 
+ + :meta private: + """ + return [self.memory_key] + + def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]: + """Return history buffer.""" + _items: Any = self.redis_client.lrange(key, 0, self.k * 2) if self.k > 0 else [] + items = [json.loads(m.decode("utf-8")) for m in _items[::-1]] + buffer = messages_from_dict(items) + if not self.return_messages: + buffer = get_buffer_string( + buffer, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, + ) + return {self.memory_key: buffer} + + def _get_input_output( + self, inputs: Dict[str, Any], outputs: Dict[str, str] + ) -> Tuple[str, str]: + if self.input_key is None: + prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) + else: + prompt_input_key = self.input_key + if self.output_key is None: + if len(outputs) != 1: + raise ValueError(f"One output key expected, got {outputs.keys()}") + output_key = list(outputs.keys())[0] + else: + output_key = self.output_key + return inputs[prompt_input_key], outputs[output_key] + + def add_message(self, key: str, message: BaseMessage) -> None: + self.redis_client.lpush(key, json.dumps(_message_to_dict(message))) + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None: + """Save context from this conversation to buffer.""" + input_str, output_str = self._get_input_output(inputs, outputs) + self.add_message(key, HumanMessage(content=input_str)) + self.add_message(key, AIMessage(content=output_str)) + + # def clear(self, key) -> None: + # """Clear memory contents.""" + # self.redis_client.delete(key) diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py new file mode 100644 index 0000000..a88044d --- /dev/null +++ b/app/service/chat_robot/script/prompt.py @@ -0,0 +1,52 @@ +FASHION_CHAT_BOT_PREFIX = """ +You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. 
+The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. +Remember your answer should be very precise and the final output answer should not exceed 20 words. + +You may encounter the following types of questions: +1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. +Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables. +Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. +If the question does not seem related to the database, just return "I don't know" as the answer. + +2) If the query related to current events, you should use internet_search to seek help from the internet. + +3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant. + +Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential. +""" + +TOOL_SELECT_SUFFIX = """ +Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly. +For database-related questions, use SQL tools to identify relevant tables and query their schemas. +The use of online resources should be limited to inquiry pertaining to current subjects. +""" + +SQL_FUNCTIONS_SUFFIX = """ +For database-related questions, use SQL tools to identify relevant tables and query their schemas. 
+""" + +INTERNET_SEARCH_SUFFIX = """ +If the question should be answered using internet search tools, I should seek help from the internet. +""" + +ANSWER_FORMAT_SUFFIX = """ +My final answer are limited to 20 words and be as much precise as possible. +""" + +TOOLS_FUNCTIONS_SUFFIX = ( + "If the input involves clothing queries," + "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables." + "All SQL statements must use 'ORDER BY RAND()', for example:" + "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" + "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'" + "If the input does not involve clothing queries, " + "I should engage in conversation as an assistant or search from internet with internet_search tool." + "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'" + "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " +) + +TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now." 
diff --git a/app/service/chat_robot/script/tools/__init__.py b/app/service/chat_robot/script/tools/__init__.py new file mode 100644 index 0000000..4a40a33 --- /dev/null +++ b/app/service/chat_robot/script/tools/__init__.py @@ -0,0 +1,10 @@ +from .sql_tools import ( + QuerySQLDataBaseTool, + InfoSQLDatabaseTool, + ListSQLDatabaseTool, + QuerySQLCheckerTool +) + +__all__ = [ + "QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool" +] diff --git a/app/service/chat_robot/script/tools/sql_tools.py b/app/service/chat_robot/script/tools/sql_tools.py new file mode 100644 index 0000000..92b8003 --- /dev/null +++ b/app/service/chat_robot/script/tools/sql_tools.py @@ -0,0 +1,183 @@ +# flake8: noqa +"""Tools for interacting with a SQL database.""" +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.chains.llm import LLMChain +from langchain.prompts import PromptTemplate +# from langchain.sql_database import SQLDatabase +from langchain.utilities import SQLDatabase +from langchain.tools.base import BaseTool +from langchain.tools.sql_database.prompt import QUERY_CHECKER + + +class BaseSQLDatabaseTool(BaseModel): + """Base tools for interacting with a SQL database.""" + + db: SQLDatabase = Field(exclude=True) + param_description: str = "" + + # Override BaseTool.Config to appease mypy + # See https://github.com/pydantic/pydantic/issues/4173 + class Config(BaseTool.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.forbid + + +class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for querying a SQL database.""" + + name = "sql_db_query" + # description = """ + # Before use this tool, another tool named sql_db_schema must be used first to find 
the schema of interested tables. + # This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database. + # You should always use ‘order by rand()’ to randomly select data. + # If the query is not correct, an error message will be returned. + # If an error is returned, rewrite the query, check the query, and try again. + # Always limit your query to at most 4 results. + # Never query for all the columns from a specific table, only ask for the relevant columns given the question. + # You MUST double check your query before executing it. + # DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + # """ + + description: str = ( + "The input of this tool is a detailed and correct SQL select query statement, " + "and the output is the result of the database, and it can only return up to 4 results." + "If the query is not correct, an error message will be returned." + "If an error is returned, rewrite the query, check the query, and try again." + "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist," + "use sql_db_schema to query the correct table fields." 
+ + "Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() " + "LIMIT 1'" + "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' " + "order by rand() LIMIT 2'" + ) + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Execute the query, return the results or an error message.""" + result = self.db.run_no_throw(query) + return result + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("QuerySqlDbTool does not support async") + + +class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting metadata about a SQL database.""" + + name = "sql_db_schema" + # description = """ + # The database contains information of lots of fashion items, such as item name, their fashion attributes. + # There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear. + # Find the most relevant tables with the query, and output the schema of these tables. + # """ + + description: str = ( + "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables." + "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress," + "female_skirt, female_outwear, male_bottom, male_top, and male_outwear." 
+ + "Example Input: 'female_outwear, male_top'" + ) + + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + return self.db.get_table_info_no_throw(table_names.split(", ")) + + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("SchemaSqlDbTool does not support async") + + +class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting tables names.""" + + name = "sql_db_list_tables" + description = "Input is an empty string, output is a comma separated list of tables in the database." + + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for a specific table.""" + return ", ".join(self.db.get_usable_table_names()) + + async def _arun( + self, + tool_input: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("ListTablesSqlDbTool does not support async") + + +class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): + """Use an LLM to check if a query is correct. + Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" + + template: str = QUERY_CHECKER + llm: BaseLanguageModel + llm_chain: LLMChain = Field(init=False) + name = "sql_db_query_checker" + description = ( + "Use this tools to double check if your query is correct before executing it." + "Always use this tools before executing a query with sql_db_query!" 
+ ) + + @root_validator(pre=True) + def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "llm_chain" not in values: + values["llm_chain"] = LLMChain( + llm=values.get("llm"), + prompt=PromptTemplate( + template=QUERY_CHECKER, + input_variables=["query", "dialect"] + ), + ) + + if values["llm_chain"].prompt.input_variables != ["dialect", "query"]: + # if values["llm_chain"].prompt.input_variables != ["query", "dialect"]: + raise ValueError( + "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" + ) + + return values + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the LLM to check the query.""" + return self.llm_chain.predict(query=query, dialect=self.db.dialect) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + return await self.llm_chain.apredict(query=query, dialect=self.db.dialect) diff --git a/app/service/chat_robot/script/tools/tutorial_tool.py b/app/service/chat_robot/script/tools/tutorial_tool.py new file mode 100644 index 0000000..c08eb9d --- /dev/null +++ b/app/service/chat_robot/script/tools/tutorial_tool.py @@ -0,0 +1,19 @@ +from typing import Any + +from langchain.tools.base import BaseTool + +from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN + + +# 处理系统引导教程相关的输入 +class CustomTutorialTool(BaseTool): + name = "tutorial_tool" + + description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials." 
+ "Input is an empty string") + + def _run(self, tool_input, **kwargs: Any) -> str: + return TUTORIAL_TOOL_RETURN + + async def _arun(self, tool_input, **kwargs: Any) -> str: + raise NotImplementedError("CustomTutorialTool does not support async") diff --git a/app/service/chat_robot/script/utils/__init__.py b/app/service/chat_robot/script/utils/__init__.py new file mode 100644 index 0000000..92a2f16 --- /dev/null +++ b/app/service/chat_robot/script/utils/__init__.py @@ -0,0 +1 @@ +from .logger import Logger diff --git a/app/service/chat_robot/script/utils/logger.py b/app/service/chat_robot/script/utils/logger.py new file mode 100644 index 0000000..cb52c18 --- /dev/null +++ b/app/service/chat_robot/script/utils/logger.py @@ -0,0 +1,26 @@ +import logging +from logging import handlers + + +class Logger(object): + level_relations = { + 'debug': logging.DEBUG, + 'info': logging.INFO, + 'warning': logging.WARNING, + 'error': logging.ERROR, + 'crit': logging.CRITICAL + } + + def __init__(self, filename, level='info', when='D', backCount=3, + fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'): + self.logger = logging.getLogger(filename) + format_str = logging.Formatter(fmt) # set log format + self.logger.setLevel(self.level_relations.get(level)) # set log level + sh = logging.StreamHandler() # output to terminal + sh.setFormatter(format_str) # set format for terminal log + th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount, + encoding='utf-8') # log into file + + th.setFormatter(format_str) # set format for file log + self.logger.addHandler(sh) # output to terminal + self.logger.addHandler(th) # output to file diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py new file mode 100644 index 0000000..b9c2c80 --- /dev/null +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -0,0 +1,70 @@ +import os + +from 
langchain.chains import LLMChain +from langchain.chat_models import ChatOpenAI +from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, \ + PromptTemplate + +from app.core.config import OPENAI_MODEL, OPENAI_API_KEY + +# os.environ["http_proxy"] = "http://127.0.0.1:7890" +# os.environ["https_proxy"] = "http://127.0.0.1:7890" + + +llm = ChatOpenAI(model_name=OPENAI_MODEL, + openai_api_key=OPENAI_API_KEY, + temperature=0) + + +def translate_to_en(text): + template = ( + """You are a translation expert, proficient in various languages. + And can translate various languages into English. + Please translate to grammatically correct English regardless of the input language. + If the input is in English, check for grammatical errors. If there are no errors, simply output the sentence. + If there are grammatical errors, correct them and then output the sentence.""" + ) + system_message_prompt = SystemMessagePromptTemplate.from_template(template) + + # 待翻译文本由 Human 角色输入 + human_template = "User input : {text}" + human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template) + + # 使用 System 和 Human 角色的提示模板构造 ChatPromptTemplate + chat_prompt_template = ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) + translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template) + + template = ( + """ + Input sentence: + {translate} + 1. Based on the input,adjust the input sentence to make it more suitable for prompts for generating images, + ensuring all key nouns or adjectives related to the image are retained. + 2. Simplify complex sentence structures and clarify ambiguous expressions. + 3. Only Output the adjusted English sentence. + + Output : + """ + ) + # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words." + # 1. Check if the input sentence contains any grammatical errors. 
If there are errors, please correct them before proceeding. + + prompt_template = PromptTemplate(input_variables=["translate"], template=template) + prompt_chain = LLMChain(llm=llm, prompt=prompt_template) + + from langchain.chains import SimpleSequentialChain + overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True) + + response = overall_chain.run(text) + return response + + +def main(): + """Main function""" + translate_to_en("生成一件运动风格的夹克,带有拉链和口袋,适合休闲穿着") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 152908247f5c59e3953542a8479bc450084ee395..e3f2934da41cde5c502fe10afc3bdb9faf27eb36 100644 GIT binary patch delta 348 zcmZ`#!3u&<5F8bPg8rc&P|}RLbna46^aWO8k))_E=;WbG_x>Y-j{RJ-FB1%6S>CcU zJ3IUCt!w*X8kUQ}(=skbxkG>gJ!D9*W=b96h!eW#AZ90mBC0T9^xj7L+lBtfc;ixIKmY6}jA8eQOw vIdg?2W;O0&>Qv(q1?1M4BC;mrMVY?L7N6`WVg+9j_bBey|G@RP+{= Date: Thu, 30 May 2024 09:48:13 +0800 Subject: [PATCH 003/108] =?UTF-8?q?feat=20design=20=E9=A2=84=E5=A4=84?= =?UTF-8?q?=E7=90=86=E6=8E=A5=E5=8F=A3=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design_pre_processing.py | 29 ++ app/api/api_route.py | 2 + app/core/config.py | 3 + app/schemas/pre_processing.py | 5 + .../design/items/pipelines/keypoints.py | 2 +- app/service/design/utils/design_ensemble.py | 2 +- app/service/design_pre_processing/service.py | 320 ++++++++++++++++++ 7 files changed, 361 insertions(+), 2 deletions(-) create mode 100644 app/api/api_design_pre_processing.py create mode 100644 app/schemas/pre_processing.py create mode 100644 app/service/design_pre_processing/service.py diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py new file mode 100644 index 0000000..0c0089d --- /dev/null +++ b/app/api/api_design_pre_processing.py @@ -0,0 +1,29 @@ +import logging +import time + +from fastapi import APIRouter + +from app.schemas.pre_processing 
import DesignPreProcessingModel +from app.service.design_pre_processing.service import DesignPreprocessing + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/design_pre_processing") +def design_pre_processing(request_data: DesignPreProcessingModel): + try: + logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}") + code = 200 + message = "access" + start_time = time.time() + server = DesignPreprocessing() + data = server.pipeline(image_list=request_data.sketches) + logger.info(f"design_pre_processing Run time is @@@@@@:{time.time() - start_time}") + except Exception as e: + code = 400 + message = str(e) + data = str(e) + logger.warning(f"design Run Exception @@@@@@:{e}") + logger.info({"code": code, "message": message, "data": data}) + return {"code": code, "message": message, "data": data} diff --git a/app/api/api_route.py b/app/api/api_route.py index c1add93..c2bd2d2 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -7,6 +7,7 @@ from app.api import api_attribute_retrieve from app.api import api_design from app.api import api_chat_robot from app.api import api_prompt_generation +from app.api import api_design_pre_processing router = APIRouter() @@ -18,3 +19,4 @@ router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"] router.include_router(api_design.router, tags=['design'], prefix="/api") router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api") router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api") +router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") diff --git a/app/core/config.py b/app/core/config.py index 5744dec..cca1de0 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -118,6 +118,9 @@ AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') +# DESIGN 预处理 +IF_DEBUG_SHOW = False + # 优先级 PRIORITY_DICT = { 'earring_front': 99, diff --git a/app/schemas/pre_processing.py b/app/schemas/pre_processing.py new file mode 100644 index 0000000..47d9297 --- /dev/null +++ b/app/schemas/pre_processing.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class DesignPreProcessingModel(BaseModel): + sketches: list[dict] diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index fc59b61..4d0a081 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -34,8 +34,8 @@ class KeypointDetection(object): site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - # 取消向量查询 直接过模型推理 keypoint_cache = self.keypoint_cache(result, site) + # 取消向量查询 直接过模型推理 # keypoint_cache = False if keypoint_cache is False: diff --git a/app/service/design/utils/design_ensemble.py b/app/service/design/utils/design_ensemble.py index e1df56a..a1021e9 100644 --- a/app/service/design/utils/design_ensemble.py +++ b/app/service/design/utils/design_ensemble.py @@ -37,7 +37,7 @@ def get_keypoint_result(image, site): keypoint_result = None try: image, scale_factor = keypoint_preprocess(image) - client = httpclient.InferenceServerClient(url=KEYPOINT_MODEL_URL) + client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) transformed_img = image.astype(np.float32) inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")] inputs[0].set_data_from_numpy(transformed_img, binary_data=True) diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py new file mode 100644 index 0000000..e655087 --- /dev/null +++ b/app/service/design_pre_processing/service.py @@ -0,0 +1,320 @@ 
+import logging +import time + +import cv2 +import numpy as np +import torch +from minio import Minio +from pymilvus import connections, Collection +from urllib3.exceptions import ResponseError +import torch.nn.functional as F +import tritonclient.grpc as grpcclient +import io + +from app.core.config import * +from app.service.design.utils.design_ensemble import get_keypoint_result + + +class DesignPreprocessing: + def __init__(self): + self.minio_client = Minio( + MINIO_URL, + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + + # @ RunTime + def pipeline(self, image_list): + sketches_list = self.read_image(image_list) + logging.info("read image success") + + bounding_box_sketches_list = self.bounding_box(sketches_list) + logging.info("bounding box image success") + + super_resolution_list = self.super_resolution(bounding_box_sketches_list) + logging.info("super_resolution_list image success") + + infer_sketches_list = self.infer_image(super_resolution_list) + logging.info("infer image success") + + result = self.composing_image(infer_sketches_list) + logging.info("Replenish white edge image success") + + for d in result: + if 'image_obj' in d: + del d['image_obj'] + if 'obj' in d: + del d['obj'] + if 'keypoint_result' in d: + del d['keypoint_result'] + return result + + def read_image(self, image_list): + for obj in image_list: + file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data + image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + elif image.shape[2] == 4: # 如果是四通道 mask + image = image[:, :, :3] + obj["image_obj"] = image + return image_list + + # @ RunTime + def bounding_box(self, image_list): + for item in image_list: + image = item['image_obj'] + # 使用Canny边缘检测来检测物体的轮廓 + edges = cv2.Canny(image, 50, 150) + # 查找轮廓 + contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, 
cv2.CHAIN_APPROX_SIMPLE) + # 初始化包围所有外接矩形的大矩形的坐标 + x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1 + # 遍历所有外接矩形,更新大矩形的坐标 + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + x_min = min(x_min, x) + y_min = min(y_min, y) + x_max = max(x_max, x + w) + y_max = max(y_max, y + h) + + if IF_DEBUG_SHOW: + image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2) + cv2.imshow("bounding_box image", image_with_big_rect) + cv2.waitKey(0) + + # 根据大矩形的坐标来裁剪原始图像 + if len(contours) > 0: + cropped_image = image[y_min:y_max, x_min:x_max] + item['obj'] = cropped_image # 新shape图像 + # 取消直接覆盖,新增size判断 + # try: + # # 覆盖到minio + # image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes() + # self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + # print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") + # except ResponseError as err: + # print(f"Error: {err}") + else: + item['obj'] = image + return image_list + + def super_resolution(self, image_list): + for item in image_list: + # 判断 两边是否同时都小于512 因为此处做四倍超分 + if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512: + # 如果任意一边小于256则超分 + if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256: + # 超分 + img = item['obj'].astype(np.float32) / 255. 
+ sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1)) + sample = torch.from_numpy(sample).float().unsqueeze(0).numpy() + inputs = [ + grpcclient.InferInput("input", sample.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(sample) + triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL) + result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs) + result_image = result.as_numpy(f'output')[0] + sr_output = torch.tensor(result_image) + output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + if output.ndim == 3: + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR + output = (output * 255.0).round().astype(np.uint8) + item['obj'] = output + try: + # 覆盖到minio + image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes() + self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") + except ResponseError as err: + print(f"Error: {err}") + return image_list + + # @ RunTime + def infer_image(self, image_list): + for sketch in image_list: + # 小写 + image_category = sketch['image_category'].lower() + # 判断上下装 + sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down' + # 推理得到keypoint + sketch['keypoint_result'] = self.keypoint_cache(sketch) + + if IF_DEBUG_SHOW: + debug_show_image = sketch['obj'].copy() + points_list = [] + point_size = 1 + point_color = (0, 0, 255) # BGR + thickness = 4 # 可以为 0 、4、8 + for i in sketch['keypoint_result'].values(): + points_list.append((int(i[1]), int(i[0]))) + for point in points_list: + cv2.circle(debug_show_image, point, point_size, point_color, thickness) + cv2.imshow("", debug_show_image) + cv2.waitKey(0) + # # 关键点在上部则推理seg + # if sketch["site"] == "up": + # # 判断seg缓存是否存在,是否与当前图片shape一致 + # seg_result 
= self.search_seg_result(sketch["image_id"], sketch["obj"].shape) + # if seg_result is False: + # # 推理seg + 保存 + # seg_result = get_seg_result(sketch['image_id'], sketch['obj']) + return image_list + + # @ RunTime + def composing_image(self, image_list): + for image in image_list: + if image['site'] == 'down': + image_width = image['obj'].shape[1] + waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + scale = 0.4 + if waist_width / scale >= image['obj'].shape[1]: + add_width = int((waist_width / scale - image_width) / 2) + ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + if IF_DEBUG_SHOW: + cv2.imshow("composing_image", ret) + cv2.waitKey(0) + image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + else: + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + else: + scale = 0.4 + image_width = image['obj'].shape[1] + waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1] + if waist_width / scale >= image_width: + add_width = int((waist_width / scale - image_width) / 2) + ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + if IF_DEBUG_SHOW: + cv2.imshow("composing_image", ret) + cv2.waitKey(0) + image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + image['show_image_url'] = 
f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + else: + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + return image_list + + @staticmethod + def select_seg_result(image_id, image_obj): + try: + # 如果shape不匹配 返回false + result = np.load(f"seg_result/{image_id}.npy").astype(np.int64) + if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]: + return result + else: + return False + except FileNotFoundError as e: + logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") + return False + + @staticmethod + def search_seg_result(image_id, ori_shape): + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + # collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection. 
+ # collection.load() + # start_time = time.time() + # res = collection.query( + # expr=f"seg_id == {image_id}", + # offset=0, + # limit=10, + # output_fields=["seg_cache"], + # ) + # logging.info(f"search seg cache time : {time.time() - start_time}") + + # if len(res): + # vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224)) + # array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False) + # array_2d_exact = array_2d_exact.squeeze().numpy() + # return array_2d_exact + # else: + return False + except Exception as e: + logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") + return False + + def keypoint_cache(self, sketch): + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. + # collection.load() + start_time = time.time() + # res = collection.query( + # expr=f"keypoint_id == {sketch['image_id']}", + # offset=0, + # limit=1, + # output_fields=["keypoint_cache", "keypoint_site"], + # ) + res = [] + logging.info(f"search keypoint time : {time.time() - start_time}") + if len(res) == 0: + # 没有结果 直接推理拿结果 并保存 + keypoint_infer_result = self.infer_keypoint_result(sketch) + return self.save_keypoint_cache(sketch, keypoint_infer_result) + elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']: + # 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果 + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist())) + elif res[0]["keypoint_site"] != sketch['site']: + # 需要的类型和查询到的不一致,则更新类型为all + keypoint_infer_result = self.infer_keypoint_result(sketch) + return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector']) + except Exception as e: + logging.info(f"search keypoint cache milvus error {e}") + return False + + # @ RunTime + def 
infer_keypoint_result(self, sketch): + keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site']) # 推理结果 + return keypoint_infer_result + + @staticmethod + # @ RunTime + def save_keypoint_cache(sketch, keypoint_infer_result): + if sketch['site'] == "down": + zeros = np.zeros(20, dtype=int) + result = np.concatenate([zeros, keypoint_infer_result.flatten()]) + else: + zeros = np.zeros(4, dtype=int) + result = np.concatenate([keypoint_infer_result.flatten(), zeros]) + data = [ + [int(sketch['image_id'])], + [sketch['site']], + [result.tolist()] + ] + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + start_time = time.time() + # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. + # mr = collection.insert(data) + # logging.info(f"save keypoint time : {time.time() - start_time}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + except Exception as e: + logging.info(f"save keypoint cache milvus error : {e}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + + @staticmethod + def update_keypoint_cache(sketch, infer_result, search_result): + if sketch['site'] == "up": + # 需要的是up 即推理出来的是up 那么查询的就是down + result = np.concatenate([infer_result.flatten(), search_result[-4:]]) + else: + # 需要的是down 即推理出来的是down 那么查询的就是up + result = np.concatenate([search_result[:20], infer_result.flatten()]) + data = [ + [int(sketch['image_id'])], + ["all"], + [result.tolist()] + ] + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + start_time = time.time() + # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. 
+ # mr = collection.upsert(data) + # logging.info(f"save keypoint time : {time.time() - start_time}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + except Exception as e: + logging.info(f"save keypoint cache milvus error : {e}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) From 401b76bd95eed172e6d160966bfb4a1f0e764b3b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 30 May 2024 15:01:39 +0800 Subject: [PATCH 004/108] =?UTF-8?q?feat=20generate=20slogan=20|=20to=20pro?= =?UTF-8?q?duct=20image=20|=20slogan=20=E6=8E=A5=E5=8F=A3=E9=83=A8?= =?UTF-8?q?=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 58 ++++- app/api/api_route.py | 2 + app/api/api_slogan.py | 24 ++ app/core/config.py | 9 + app/schemas/generate_image.py | 12 + app/schemas/slogan.py | 7 + app/service/design/items/pipelines/split.py | 60 +---- app/service/design/utils/synthesis_item.py | 35 +-- app/service/design/utils/upload_image.py | 211 +++++++++--------- .../{service.py => service_generate_image.py} | 0 app/service/slogan/service.py | 36 +++ 11 files changed, 280 insertions(+), 174 deletions(-) create mode 100644 app/api/api_slogan.py create mode 100644 app/schemas/slogan.py rename app/service/generate_image/{service.py => service_generate_image.py} (100%) create mode 100644 app/service/slogan/service.py diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 78f3a66..a74eb1b 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -1,11 +1,15 @@ import logging from fastapi import APIRouter, BackgroundTasks -from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.service import GenerateImage, infer_cancel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel +from 
app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel +from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel +from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel router = APIRouter() logger = logging.getLogger() +'''generate image''' + @router.post("/generate_image") def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks): @@ -24,5 +28,53 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun @router.get("/generate_cancel/{tasks_id}>") def generate_image(tasks_id): - result = infer_cancel(tasks_id) + result = generate_image_infer_cancel(tasks_id) + return {"code": 200, "message": result['message'], "data": result['data']} + + +'''single logo''' + + +@router.post("/generate_single_logo") +def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_tasks: BackgroundTasks): + try: + logger.info(f"request data ### : {request_item}") + service = GenerateSingleLogoImage(request_item) + background_tasks.add_task(service.get_result) + code = 200 + message = "access" + except Exception as e: + code = 400 + message = e + logger.warning(e) + return {"code": code, "message": message} + + +@router.get("/generate_single_logo_cancel/{tasks_id}>") +def generate_single_logo_image(tasks_id): + result = generate_single_logo_cancel(tasks_id) + return {"code": 200, "message": result['message'], "data": result['data']} + + +'''product image''' + + +@router.post("/generate_product_image") +def generate_product_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): + try: + logger.info(f"request data ### : {request_item}") + service = GenerateProductImage(request_item) + background_tasks.add_task(service.get_result) + code = 200 + message = "access" + 
except Exception as e: + code = 400 + message = e + logger.warning(e) + return {"code": code, "message": message} + + +@router.get("/generate_product_image_cancel_cancel/{tasks_id}>") +def generate_single_logo_image(tasks_id): + result = generate_product_image_cancel(tasks_id) return {"code": 200, "message": result['message'], "data": result['data']} diff --git a/app/api/api_route.py b/app/api/api_route.py index c2bd2d2..45ce4b3 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -8,6 +8,7 @@ from app.api import api_design from app.api import api_chat_robot from app.api import api_prompt_generation from app.api import api_design_pre_processing +from app.api import api_slogan router = APIRouter() @@ -20,3 +21,4 @@ router.include_router(api_design.router, tags=['design'], prefix="/api") router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api") router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api") router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") +router.include_router(api_slogan.router, tags=['slogan'], prefix="/api") diff --git a/app/api/api_slogan.py b/app/api/api_slogan.py new file mode 100644 index 0000000..31459ba --- /dev/null +++ b/app/api/api_slogan.py @@ -0,0 +1,24 @@ +import logging +import time +from fastapi import APIRouter, BackgroundTasks + +from app.schemas.slogan import SloganModel +from app.service.slogan.service import Slogan + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/slogan") +def slogan(request_item: SloganModel, background_tasks: BackgroundTasks): + try: + logger.info(f"request data ### : {request_item}") + service = Slogan(request_item) + background_tasks.add_task(service.get_result) + code = 200 + message = "access" + except Exception as e: + code = 400 + message = e + logger.warning(e) + return {"code": code, "message": message} diff --git a/app/core/config.py b/app/core/config.py index 
cca1de0..08802a5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -103,6 +103,15 @@ GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" +# SLOGAN service config +SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") + +# Generate Single Logo service config +GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") + +# Generate Single Logo service config +GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"GenProductImage{RABBITMQ_ENV}") + # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' SEGMENTATION = { diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index b8f5441..b30e64e 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -8,3 +8,15 @@ class GenerateImageModel(BaseModel): mode: str category: str gender: str + + +class GenerateSingleLogoImageModel(BaseModel): + tasks_id: str + prompt: str + image_url: str + + +class GenerateProductImageModel(BaseModel): + tasks_id: str + prompt: str + image_url: str diff --git a/app/schemas/slogan.py b/app/schemas/slogan.py new file mode 100644 index 0000000..e80423d --- /dev/null +++ b/app/schemas/slogan.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class SloganModel(BaseModel): + prompt: str + svg: str + tasks_id: str diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index d800597..e46a3e1 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -2,6 +2,7 @@ import logging import cv2 import numpy as np from cv2 import cvtColor, COLOR_BGR2RGBA + from app.service.utils.generate_uuid import generate_uuid from ..builder import PIPELINES from PIL import Image @@ -45,8 +46,11 @@ class Split(object): 
result_front_image[front_mask != 0] = rgba_image[front_mask != 0] result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) - front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"])) + front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"][1])) result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS) + # TODO 多线程外部上传图片到minio + # result['front_mask_image'] = cv2.resize(front_mask, front_new_size) + # result['front_image'] = result_front_image_pil front_mask = cv2.resize(front_mask, front_new_size) result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask) @@ -55,61 +59,19 @@ class Split(object): result_back_image[back_mask != 0] = rgba_image[back_mask != 0] result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA)) - back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"])) + back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"][1])) result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS) + # TODO 多线程外部上传图片到minio + # result['back_mask_image'] = cv2.resize(back_mask, back_new_size) + # result['back_image'] = result_back_image_pil + back_mask = cv2.resize(back_mask, back_new_size) result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask) else: result['back_image'] = None result["back_image_url"] = None 
result["back_mask_url"] = None + result['back_mask_image'] = None return result except Exception as e: logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") - - # @ RunTime - # def __call__(self, result): - # try: - # if 'mask' not in result.keys(): - # raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.') - # if 'seg_result' not in result.keys(): # 没过seg模型 - # result['front_mask'] = result['mask'].copy() - # result['back_mask'] = np.zeros_like(result['mask']) - # else: - # temp_front = result['seg_result'] == 1 - # result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8)) - # temp_back = result['seg_result'] == 2 - # result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8)) - # - # if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): - # if len(result['front_mask'].shape) > 2: - # front_mask = result['front_mask'][0] - # else: - # front_mask = result['front_mask'] - # - # rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) - # result_front_image = np.zeros_like(rgba_image) - # result_front_image[front_mask != 0] = rgba_image[front_mask != 0] - # - # result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) - # front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"])) - # result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS) - # front_mask = cv2.resize(front_mask, front_new_size) - # result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask) - # - # if result["name"] in ('blouse', 'dress', 'outwear', 'tops'): - # result_back_image = np.zeros_like(rgba_image) 
- # result_back_image[result['back_mask'] != 0] = rgba_image[result['back_mask'] != 0] - # - # result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA)) - # back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"])) - # result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS) - # back_mask = cv2.resize(result['back_mask'], back_new_size) - # result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask) - # else: - # result['back_image'] = None - # result["back_image_url"] = None - # result["back_mask_url"] = None - # return result - # except Exception as e: - # logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py index 8792f7b..e5f5bd2 100644 --- a/app/service/design/utils/synthesis_item.py +++ b/app/service/design/utils/synthesis_item.py @@ -17,14 +17,15 @@ import numpy as np from PIL import Image from minio import Minio +from app.core.config import * from app.service.utils.decorator import RunTime from app.service.utils.generate_uuid import generate_uuid -# minio_client = Minio( -# f"{MINIO_IP}:{MINIO_PORT}", -# access_key=MINIO_ACCESS, -# secret_key=MINIO_SECRET, -# secure=MINIO_SECURE) +minio_client = Minio( + MINIO_URL, + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) s3 = boto3.client( 's3', @@ -130,19 +131,19 @@ def synthesis(data, size): result_image.save(output, format='PNG') data = output.getvalue() - # image_data = io.BytesIO() - # result_image.save(image_data, format='PNG') - # image_data.seek(0) - # image_bytes = image_data.read() - # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', 
io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + image_data = io.BytesIO() + result_image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - object_name = f'result_{generate_uuid()}.png' - response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png') - object_url = f"aida-results/{object_name}" - if response['ResponseMetadata']['HTTPStatusCode'] == 200: - return object_url - else: - return "" + # object_name = f'result_{generate_uuid()}.png' + # response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png') + # object_url = f"aida-results/{object_name}" + # if response['ResponseMetadata']['HTTPStatusCode'] == 200: + # return object_url + # else: + # return "" except Exception as e: logging.warning(f"synthesis runtime exception : {e}") diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py index f945b02..7503adc 100644 --- a/app/service/design/utils/upload_image.py +++ b/app/service/design/utils/upload_image.py @@ -33,128 +33,129 @@ s3 = boto3.client( ) -@RunTime -def upload_png_mask(front_image, object_name, mask=None): - start_time = time.time() - mask_url = None - if mask is not None: - # 反转掩模 - mask_inverted = cv2.bitwise_not(mask) - # 将掩模转换为 RGBA 格式 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - # 将图像数据保存到内存中的 BytesIO 对象中 - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - image_bytes.seek(0) - try: - key = f"mask/mask_{object_name}.png" - mask_url = f"{AIDA_CLOTHING}/{key}" - s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') - except 
Exception as e: - print(f'上传到 S3 失败: {e}') - with io.BytesIO() as output: - front_image.save(output, format='PNG') - data = output.getvalue() - # 创建一个 S3 客户端 - try: - key = f"image/image_{object_name}.png" - image_url = f"{AIDA_CLOTHING}/{key}" - s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') - return front_image, image_url, mask_url - except Exception as e: - print(f'上传到 S3 失败: {e}') - - -@RunTime -def upload_layer_image(image, object_name): - with io.BytesIO() as output: - image.save(output, format='PNG') - data = output.getvalue() - # 创建一个 S3 客户端 - try: - key = f"image/image_{object_name}.png" - image_url = f"{AIDA_CLOTHING}/{key}" - s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') - return image_url - except Exception as e: - print(f'上传到 S3 失败: {e}') - - -@RunTime -def upload_mask_image(mask, object_name): - # 反转掩模 - mask_inverted = cv2.bitwise_not(mask) - # 将掩模转换为 RGBA 格式 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - # 将图像数据保存到内存中的 BytesIO 对象中 - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - image_bytes.seek(0) - try: - key = f"mask/mask_{object_name}.png" - mask_url = f"{AIDA_CLOTHING}/{key}" - s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') - return mask_url - except Exception as e: - print(f'上传到 S3 失败: {e}') - - -"""minio 上传""" - # @RunTime # def upload_png_mask(front_image, object_name, mask=None): # start_time = time.time() +# mask_url = None +# if mask is not None: +# # 反转掩模 +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模转换为 RGBA 格式 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# # 将图像数据保存到内存中的 BytesIO 对象中 +# image_bytes = io.BytesIO() +# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# image_bytes.seek(0) +# try: +# key = 
f"mask/mask_{object_name}.png" +# mask_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') +# except Exception as e: +# print(f'上传到 S3 失败: {e}') +# with io.BytesIO() as output: +# front_image.save(output, format='PNG') +# data = output.getvalue() +# # 创建一个 S3 客户端 # try: -# mask_url = None -# if mask is not None: -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# -# image_bytes.seek(0) -# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" -# -# image_data = io.BytesIO() -# front_image.save(image_data, format='PNG') -# image_data.seek(0) -# image_bytes = image_data.read() -# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" -# # print(f"upload_png_mask {object_name} = {time.time() - start_time}") +# key = f"image/image_{object_name}.png" +# image_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') # return front_image, image_url, mask_url # except Exception as e: -# logging.warning(f"upload_png_mask runtime exception : {e}") +# print(f'上传到 S3 失败: {e}') # # # @RunTime # def upload_layer_image(image, object_name): +# with io.BytesIO() as output: +# image.save(output, format='PNG') +# data = output.getvalue() +# # 创建一个 S3 客户端 # try: -# image_data = io.BytesIO() -# image.save(image_data, format='PNG') -# image_data.seek(0) -# image_bytes = image_data.read() -# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', 
f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" +# key = f"image/image_{object_name}.png" +# image_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') # return image_url # except Exception as e: -# logging.warning(f"upload_png_mask runtime exception : {e}") +# print(f'上传到 S3 失败: {e}') # # # @RunTime # def upload_mask_image(mask, object_name): +# # 反转掩模 +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模转换为 RGBA 格式 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# # 将图像数据保存到内存中的 BytesIO 对象中 +# image_bytes = io.BytesIO() +# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# image_bytes.seek(0) # try: -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# -# image_bytes.seek(0) -# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" +# key = f"mask/mask_{object_name}.png" +# mask_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') # return mask_url # except Exception as e: -# logging.warning(f"upload_png_mask runtime exception : {e}") +# print(f'上传到 S3 失败: {e}') + + +"""minio 上传""" + + +@RunTime +def upload_png_mask(front_image, object_name, mask=None): + start_time = time.time() + try: + mask_url = None + if mask is not None: + mask_inverted = cv2.bitwise_not(mask) + # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + 
image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + + image_bytes.seek(0) + mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" + + image_data = io.BytesIO() + front_image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # print(f"upload_png_mask {object_name} = {time.time() - start_time}") + return front_image, image_url, mask_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") + + +@RunTime +def upload_layer_image(image, object_name): + try: + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + return image_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") + + +@RunTime +def upload_mask_image(mask, object_name): + try: + mask_inverted = cv2.bitwise_not(mask) + # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + + image_bytes.seek(0) + mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" + return mask_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") diff --git 
a/app/service/generate_image/service.py b/app/service/generate_image/service_generate_image.py similarity index 100% rename from app/service/generate_image/service.py rename to app/service/generate_image/service_generate_image.py diff --git a/app/service/slogan/service.py b/app/service/slogan/service.py new file mode 100644 index 0000000..5a330d6 --- /dev/null +++ b/app/service/slogan/service.py @@ -0,0 +1,36 @@ +import json +import logging + +import redis + +from app.core.config import * + +logger = logging.getLogger() + + +class Slogan: + def __init__(self, request_data): + self.tasks_id = request_data.tasks_id + self.prompt = request_data.prompt + self.svg = request_data.svg + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.slogan_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.slogan_data)) + self.redis_client.expire(self.tasks_id, 600) + + # if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.result_image_url = "test/slogan/init_img.png" + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + self.slogan_data['status'] = "SUCCESS" + self.slogan_data['message'] = "success" + self.slogan_data['image_url'] = "test/slogan/init_img.png" + dict_slogan_data, str_slogan_data = self.read_tasks_status() + self.channel.basic_publish(exchange='', routing_key=SLOGAN_RABBITMQ_QUEUES, body=str_slogan_data) + logger.info(f" [x] Sent {json.dumps(dict_slogan_data, indent=4)}") From 5092a8c7bc38c089fa6e327ecb372ad4f0e951e4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 30 May 2024 15:02:35 +0800 Subject: [PATCH 005/108] =?UTF-8?q?feat=20generate=20slogan=20|=20to=20pro?= 
=?UTF-8?q?duct=20image=20|=20slogan=20=E6=8E=A5=E5=8F=A3=E9=83=A8?= =?UTF-8?q?=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 181 ++++++++++++++++++ .../service_generate_single_logo.py | 72 +++++++ 2 files changed, 253 insertions(+) create mode 100644 app/service/generate_image/service_generate_product_image.py create mode 100644 app/service/generate_image/service_generate_single_logo.py diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py new file mode 100644 index 0000000..ea875bd --- /dev/null +++ b/app/service/generate_image/service_generate_product_image.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time +from io import BytesIO + +import cv2 +import minio +import redis +import tritonclient.grpc as grpcclient +import numpy as np +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.generate_image import GenerateImageModel +from app.service.generate_image.utils.adjust_contrast import adjust_contrast +from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic +from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd + +logger = logging.getLogger() + + +class GenerateProductImage: + def __init__(self, request_data): + # if DEBUG is False: + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = 
self.connection.channel() + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + # Get data of an object. + # Read data from response. + # read image use cv2 + try: + response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + image_file = BytesIO(response.data) + image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + image = cv2.resize(image_rbg, (1024, 1024)) + except minio.error.S3Error: + image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + return image + + def callback(self, result, error): + if error: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: + # pil图像转成numpy数组 + image = result.as_numpy("generated_image") + image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) + is_smudge = True + if self.category == "sketch": + # 色阶调整 + cutoff = 1 + levels_img = autoLevels(image_result, cutoff) + # 亮度调整 + luminance = luminance_adjust(0.3, levels_img) + # 去背景 + remove_bg_image = remove_background(luminance) + # 人脸检测 + if face_detect_pic(remove_bg_image, self.user_id, self.category, 
self.tasks_id) > 0: + is_smudge = False + else: + # 污点/ + is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) + # 类型识别 + category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) + self.generate_data['category'] = str(category) + image_result = not_smudge_image + if is_smudge: # 无污点 + # image_result = adjust_contrast(image_result) + image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: # 有污点 保存图片到本地 测试用 + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + # logger.info(f"stain_detection result : {self.generate_data}") + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def infer(self, inputs): + return self.grpc_client.async_infer( + model_name=GI_MODEL_NAME, + inputs=inputs, + callback=self.callback + ) + + def get_result(self): + try: + # prompts = [self.prompt] * self.batch_size + # modes = [self.mode] * self.batch_size + # images = [self.image.astype(np.float16)] * self.batch_size + # + # text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + # mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) + # image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) + # + # input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + # input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") + 
# input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) + # + # input_text.set_data_from_numpy(text_obj) + # input_image.set_data_from_numpy(image_obj) + # input_mode.set_data_from_numpy(mode_obj) + # + # inputs = [input_text, input_image, input_mode] + # ctx = self.infer(inputs) + # time_out = 600 + # generate_data = None + # while time_out > 0: + # generate_data, _ = self.read_tasks_status() + # # logger.info(generate_data) + # if generate_data['status'] in ["REVOKED", "FAILURE"]: + # ctx.cancel() + # break + # elif generate_data['status'] == "SUCCESS": + # break + # time_out -= 1 + # time.sleep(0.1) + # # logger.info(time_out, generate_data) + generate_data, _ = self.read_tasks_status() + return generate_data + except Exception as e: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + raise Exception(str(e)) + finally: + dict_gen_product_data, str_gen_product_data = self.read_tasks_status() + # if DEBUG is False: + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) + logger.info(f" [x] Sent {json.dumps(dict_gen_product_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateImageModel( + tasks_id="123-89", + prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', + image_url="", + ) + server = GenerateImage(rd) + print(server.get_result()) diff --git 
a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py new file mode 100644 index 0000000..0bb38a0 --- /dev/null +++ b/app/service/generate_image/service_generate_single_logo.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import redis +from minio import Minio +from app.core.config import * +from app.schemas.generate_image import GenerateSingleLogoImageModel + +logger = logging.getLogger() + + +class GenerateSingleLogoImage: + def __init__(self, request_data): + # if DEBUG is False: + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + self.redis_client.expire(self.tasks_id, 600) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + generate_data, _ = self.read_tasks_status() + return generate_data + except Exception as e: + self.gen_single_logo_data['status'] = "FAILURE" + self.gen_single_logo_data['message'] = str(e) + 
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + # if DEBUG is False: + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + self.channel.basic_publish(exchange='', routing_key=GEN_SINGLE_LOGO_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateSingleLogoImageModel( + tasks_id="123-8", + prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', + image_url="", + ) + server = GenerateSingleLogoImage(rd) + print(server.get_result()) From ae2cd25185bbb4e3621b7f985ad59fa14bdb1250 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 30 May 2024 15:51:25 +0800 Subject: [PATCH 006/108] =?UTF-8?q?feat=20generate=20slogan=20|=20to=20pro?= =?UTF-8?q?duct=20image=20|=20slogan=20=E6=8E=A5=E5=8F=A3=E9=83=A8?= =?UTF-8?q?=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1160 -> 1194 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index e3f2934da41cde5c502fe10afc3bdb9faf27eb36..4a32611647f03044c1c3554899248427435a9c74 100644 GIT binary patch delta 46 xcmeC+T*bMeib*k-A(5ekp@1QWAs Date: Thu, 30 May 2024 16:53:32 +0800 Subject: [PATCH 007/108] =?UTF-8?q?feat=20generate=20slogan=20|=20to=20pro?= =?UTF-8?q?duct=20image=20|=20slogan=20=E6=8E=A5=E5=8F=A3=E9=83=A8?= =?UTF-8?q?=E7=BD=B2?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1194 -> 1246 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4a32611647f03044c1c3554899248427435a9c74..a203fe52a4c8468fb97bf5c8c626c4b12f72e398 100644 GIT binary patch delta 60 zcmZ3*d5?3$0w( Date: Thu, 30 May 2024 17:26:37 +0800 Subject: [PATCH 008/108] =?UTF-8?q?feat=20generate=20slogan=20|=20to=20pro?= =?UTF-8?q?duct=20image=20|=20slogan=20=E6=8E=A5=E5=8F=A3=E9=83=A8?= =?UTF-8?q?=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design/items/pipelines/painting.py | 39 +++++++++++-------- app/service/design/utils/synthesis_item.py | 32 +++++++-------- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index b1d1ea7..0b48082 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -4,19 +4,24 @@ import boto3 import cv2 import numpy as np from PIL import Image +from minio import Minio + +from app.core.config import * from ..builder import PIPELINES -# minio_client = Minio( -# f"{MINIO_IP}:{MINIO_PORT}", -# access_key=MINIO_ACCESS, -# secret_key=MINIO_SECRET, -# secure=MINIO_SECURE) -s3 = boto3.client( - 's3', - aws_access_key_id="AKIAVD3OJIMF6UJFLSHZ", - aws_secret_access_key="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", - region_name="ap-east-1" -) +minio_client = Minio( + MINIO_URL, + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + + +# s3 = boto3.client( +# 's3', +# aws_access_key_id="AKIAVD3OJIMF6UJFLSHZ", +# aws_secret_access_key="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", +# region_name="ap-east-1" +# ) @PIPELINES.register_module() @@ -57,8 +62,8 @@ class Painting(object): @staticmethod def get_gradient(bucket_name, object_name): - # image_data = 
minio_client.get_object(bucket_name, object_name) - image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body'] + image_data = minio_client.get_object(bucket_name, object_name) + # image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body'] # 从数据流中读取图像 image_bytes = image_data.read() @@ -390,8 +395,8 @@ class PrintPainting(object): if not 'IfSingle' in print_dict.keys(): print_dict['IfSingle'] = False - # data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) - data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body'] + data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) + # data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body'] data_bytes = BytesIO(data.read()) image = Image.open(data_bytes) @@ -473,8 +478,8 @@ class PrintPainting(object): @staticmethod def read_image(image_url): - # data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) - data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] + data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) + # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] data_bytes = BytesIO(data.read()) image = Image.open(data_bytes) diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py index e5f5bd2..91505bd 100644 --- a/app/service/design/utils/synthesis_item.py +++ b/app/service/design/utils/synthesis_item.py @@ -156,20 +156,18 @@ def synthesis_single(front_image, back_image): if back_image: result_image.paste(back_image, (0, 0), back_image) - with io.BytesIO() as output: - 
result_image.save(output, format='PNG') - data = output.getvalue() - - # image_data = io.BytesIO() - # result_image.save(image_data, format='PNG') - # image_data.seek(0) - # image_bytes = image_data.read() - # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - - object_name = f'result_{generate_uuid()}.png' - response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png') - object_url = f"aida-results/{object_name}" - if response['ResponseMetadata']['HTTPStatusCode'] == 200: - return object_url - else: - return "" + # with io.BytesIO() as output: + # result_image.save(output, format='PNG') + # data = output.getvalue() + # object_name = f'result_{generate_uuid()}.png' + # response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png') + # object_url = f"aida-results/{object_name}" + # if response['ResponseMetadata']['HTTPStatusCode'] == 200: + # return object_url + # else: + # return "" + image_data = io.BytesIO() + result_image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" From 2121f0ab46b937b66fdc787766860e2bdcac74f0 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Jun 2024 11:38:20 +0800 Subject: [PATCH 009/108] =?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 8 + app/schemas/generate_image.py | 2 +- .../design/items/pipelines/painting.py | 8 +- app/service/design/utils/synthesis_item.py | 6 +- app/service/design/utils/upload_image.py | 218 +++++++++--------- .../service_generate_single_logo.py | 
98 ++++++-- .../generate_image/utils/upload_sd_image.py | 34 +++ 7 files changed, 235 insertions(+), 139 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 08802a5..a5bc957 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -41,6 +41,11 @@ MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB' MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR' MINIO_SECURE = True +# S3 配置 +S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ" +S3_AWS_SECRET_ACCESS_KEY = "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" +S3_REGION_NAME = "ap-east-1" + # redis 配置 REDIS_HOST = "10.1.1.240" REDIS_PORT = "6379" @@ -107,6 +112,9 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") # Generate Single Logo service config +GSL_MODEL_URL = '10.1.1.240:10051' +GSL_MINIO_BUCKET = "aida-users" +GSL_MODEL_NAME = 'stable_diffusion_xl' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Single Logo service config diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index b30e64e..49cf9ce 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -13,7 +13,7 @@ class GenerateImageModel(BaseModel): class GenerateSingleLogoImageModel(BaseModel): tasks_id: str prompt: str - image_url: str + seed: str class GenerateProductImageModel(BaseModel): diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 0b48082..d1f6957 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -15,13 +15,7 @@ minio_client = Minio( secret_key=MINIO_SECRET, secure=MINIO_SECURE) - -# s3 = boto3.client( -# 's3', -# aws_access_key_id="AKIAVD3OJIMF6UJFLSHZ", -# aws_secret_access_key="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", -# region_name="ap-east-1" -# ) +s3 = boto3.client('s3', 
aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @PIPELINES.register_module() diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py index 91505bd..0cf844b 100644 --- a/app/service/design/utils/synthesis_item.py +++ b/app/service/design/utils/synthesis_item.py @@ -29,9 +29,9 @@ minio_client = Minio( s3 = boto3.client( 's3', - aws_access_key_id="AKIAVD3OJIMF6UJFLSHZ", - aws_secret_access_key="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", - region_name="ap-east-1" + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, + region_name=S3_REGION_NAME ) diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py index 7503adc..70b259c 100644 --- a/app/service/design/utils/upload_image.py +++ b/app/service/design/utils/upload_image.py @@ -25,137 +25,131 @@ minio_client = Minio( secure=MINIO_SECURE) """S3 上传""" -s3 = boto3.client( - 's3', - aws_access_key_id="AKIAVD3OJIMF6UJFLSHZ", - aws_secret_access_key="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", - region_name="ap-east-1" -) +s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) -# @RunTime -# def upload_png_mask(front_image, object_name, mask=None): -# start_time = time.time() -# mask_url = None -# if mask is not None: -# # 反转掩模 -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模转换为 RGBA 格式 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# # 将图像数据保存到内存中的 BytesIO 对象中 -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# image_bytes.seek(0) -# try: -# key = f"mask/mask_{object_name}.png" -# mask_url = f"{AIDA_CLOTHING}/{key}" -# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') -# except Exception as e: -# print(f'上传到 S3 失败: 
{e}') -# with io.BytesIO() as output: -# front_image.save(output, format='PNG') -# data = output.getvalue() -# # 创建一个 S3 客户端 -# try: -# key = f"image/image_{object_name}.png" -# image_url = f"{AIDA_CLOTHING}/{key}" -# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') -# return front_image, image_url, mask_url -# except Exception as e: -# print(f'上传到 S3 失败: {e}') -# -# -# @RunTime -# def upload_layer_image(image, object_name): -# with io.BytesIO() as output: -# image.save(output, format='PNG') -# data = output.getvalue() -# # 创建一个 S3 客户端 -# try: -# key = f"image/image_{object_name}.png" -# image_url = f"{AIDA_CLOTHING}/{key}" -# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') -# return image_url -# except Exception as e: -# print(f'上传到 S3 失败: {e}') -# -# -# @RunTime -# def upload_mask_image(mask, object_name): -# # 反转掩模 -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模转换为 RGBA 格式 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# # 将图像数据保存到内存中的 BytesIO 对象中 -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# image_bytes.seek(0) -# try: -# key = f"mask/mask_{object_name}.png" -# mask_url = f"{AIDA_CLOTHING}/{key}" -# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') -# return mask_url -# except Exception as e: -# print(f'上传到 S3 失败: {e}') - - -"""minio 上传""" - @RunTime def upload_png_mask(front_image, object_name, mask=None): - start_time = time.time() + mask_url = None + if mask is not None: + # 反转掩模 + mask_inverted = cv2.bitwise_not(mask) + # 将掩模转换为 RGBA 格式 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + # 将图像数据保存到内存中的 BytesIO 对象中 + image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + image_bytes.seek(0) + try: + key = 
f"mask/mask_{object_name}.png" + mask_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') + except Exception as e: + print(f'上传到 S3 失败: {e}') + with io.BytesIO() as output: + front_image.save(output, format='PNG') + data = output.getvalue() + # 创建一个 S3 客户端 try: - mask_url = None - if mask is not None: - mask_inverted = cv2.bitwise_not(mask) - # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - - image_bytes.seek(0) - mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" - - image_data = io.BytesIO() - front_image.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - # print(f"upload_png_mask {object_name} = {time.time() - start_time}") + key = f"image/image_{object_name}.png" + image_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') return front_image, image_url, mask_url except Exception as e: - logging.warning(f"upload_png_mask runtime exception : {e}") + print(f'上传到 S3 失败: {e}') @RunTime def upload_layer_image(image, object_name): + with io.BytesIO() as output: + image.save(output, format='PNG') + data = output.getvalue() + # 创建一个 S3 客户端 try: - image_data = io.BytesIO() - image.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), 
len(image_bytes), content_type='image/png').object_name}" + key = f"image/image_{object_name}.png" + image_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') return image_url except Exception as e: - logging.warning(f"upload_png_mask runtime exception : {e}") + print(f'上传到 S3 失败: {e}') @RunTime def upload_mask_image(mask, object_name): + # 反转掩模 + mask_inverted = cv2.bitwise_not(mask) + # 将掩模转换为 RGBA 格式 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + # 将图像数据保存到内存中的 BytesIO 对象中 + image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + image_bytes.seek(0) try: - mask_inverted = cv2.bitwise_not(mask) - # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - - image_bytes.seek(0) - mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" + key = f"mask/mask_{object_name}.png" + mask_url = f"{AIDA_CLOTHING}/{key}" + s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') return mask_url except Exception as e: - logging.warning(f"upload_png_mask runtime exception : {e}") + print(f'上传到 S3 失败: {e}') + + +"""minio 上传""" + +# @RunTime +# def upload_png_mask(front_image, object_name, mask=None): +# start_time = time.time() +# try: +# mask_url = None +# if mask is not None: +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# image_bytes = io.BytesIO() +# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# 
+# image_bytes.seek(0) +# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" +# +# image_data = io.BytesIO() +# front_image.save(image_data, format='PNG') +# image_data.seek(0) +# image_bytes = image_data.read() +# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" +# # print(f"upload_png_mask {object_name} = {time.time() - start_time}") +# return front_image, image_url, mask_url +# except Exception as e: +# logging.warning(f"upload_png_mask runtime exception : {e}") +# +# +# @RunTime +# def upload_layer_image(image, object_name): +# try: +# image_data = io.BytesIO() +# image.save(image_data, format='PNG') +# image_data.seek(0) +# image_bytes = image_data.read() +# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" +# return image_url +# except Exception as e: +# logging.warning(f"upload_png_mask runtime exception : {e}") +# +# +# @RunTime +# def upload_mask_image(mask, object_name): +# try: +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# image_bytes = io.BytesIO() +# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# +# image_bytes.seek(0) +# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" +# return mask_url +# except Exception as e: +# logging.warning(f"upload_png_mask runtime exception : {e}") diff --git 
a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index 0bb38a0..ed25d74 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -9,25 +9,39 @@ """ import json import logging +import time + +import cv2 +import numpy as np import redis +from PIL import Image from minio import Minio +from tritonclient.utils import np_to_triton_dtype + from app.core.config import * +import tritonclient.grpc as grpcclient from app.schemas.generate_image import GenerateSingleLogoImageModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_single_logo logger = logging.getLogger() class GenerateSingleLogoImage: def __init__(self, request_data): - # if DEBUG is False: - # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # self.channel = self.connection.channel() - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - # self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.batch_size = 1 + self.category = "single_logo" + self.negative_prompts = "bad, ugly" + self.seed = request_data.seed self.tasks_id = request_data.tasks_id + self.prompt = request_data.prompt self.user_id = 
self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) @@ -37,20 +51,72 @@ class GenerateSingleLogoImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data + def infer(self, inputs): + return self.grpc_client.async_infer( + model_name=GSL_MODEL_NAME, + inputs=inputs, + callback=self.callback + ) + + def callback(self, result, error): + if error: + self.gen_single_logo_data['status'] = "FAILURE" + self.gen_single_logo_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + else: + image = result.as_numpy("generated_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + image_url = upload_single_logo(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + self.gen_single_logo_data['status'] = "SUCCESS" + self.gen_single_logo_data['message'] = "success" + self.gen_single_logo_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + def get_result(self): try: - generate_data, _ = self.read_tasks_status() + # prompt + prompts = [self.prompt] * self.batch_size + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + + # negative_prompts + text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1)) + # print('text obj neg: ', text_obj_neg) + input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)) + input_text_neg.set_data_from_numpy(text_obj_neg) + + # seed + seed = np.array(self.seed, 
dtype="object").reshape((-1, 1)) + print('seed: ', self.seed) + input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype)) + input_seed.set_data_from_numpy(seed) + + inputs = [input_text, input_text_neg, input_seed] + + ctx = self.infer(inputs) + time_out = 600 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + # logger.info(generate_data) + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + # logger.info(time_out, generate_data) return generate_data except Exception as e: - self.gen_single_logo_data['status'] = "FAILURE" - self.gen_single_logo_data['message'] = str(e) - self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) raise Exception(str(e)) finally: dict_generate_data, str_generate_data = self.read_tasks_status() - # if DEBUG is False: - # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) - self.channel.basic_publish(exchange='', routing_key=GEN_SINGLE_LOGO_RABBITMQ_QUEUES, body=str_generate_data) + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GEN_SINGLE_LOGO_RABBITMQ_QUEUES, body=str_generate_data) + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") @@ -64,9 +130,9 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateSingleLogoImageModel( - tasks_id="123-8", - prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', - image_url="", + tasks_id="123-89", + prompt='an apple', + seed="1", ) server = GenerateSingleLogoImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index 0e8e542..7cb7f3e 100644 --- 
a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -10,6 +10,7 @@ import io import logging +import boto3 import cv2 from PIL import Image from minio import Minio @@ -17,6 +18,39 @@ from minio import Minio from app.core.config import * minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) +s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + + +# def upload_single_logo(image, user_id, category, object_name): +# with io.BytesIO() as output: +# image.save(output, format='PNG') +# data = output.getvalue() +# # 创建一个 S3 客户端 +# try: +# key = f'{user_id}/{category}/{object_name}' +# image_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=GSL_MINIO_BUCKET, Key=key, Body=data, ContentType='image/png') +# return image_url +# except Exception as e: +# print(f'上传到 S3 失败: {e}') + +def upload_single_logo(image, user_id, category, object_name): + try: + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + minio_req = minio_client.put_object( + GI_MINIO_BUCKET, + f'{user_id}/{category}/{object_name}', + io.BytesIO(image_bytes), + len(image_bytes), + content_type='image/jpeg' + ) + image_url = f"aida-users/{minio_req.object_name}" + return image_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") def upload_png_sd(image, user_id, category, object_name): From a89bc375633b5c299933741e0671689c659e7488 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Jun 2024 11:42:23 +0800 Subject: [PATCH 010/108] =?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py 
b/app/core/config.py index a5bc957..2da37a3 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 6502f5ed612ac25f637ef7446260b4c6da932901 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Jun 2024 11:45:35 +0800 Subject: [PATCH 011/108] =?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 2da37a3..a5bc957 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 60e0ffca99d545b701a59455388421ae9d0509aa Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Jun 2024 11:48:24 +0800 Subject: [PATCH 012/108] =?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index a5bc957..2da37a3 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 8b2adfbd073a4d557d1bec1048b7d7c21ef928bd Mon Sep 17 00:00:00 2001 From: 
zhouchengrong Date: Mon, 3 Jun 2024 11:54:52 +0800 Subject: [PATCH 013/108] =?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_single_logo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index ed25d74..11348a2 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -24,6 +24,7 @@ from app.schemas.generate_image import GenerateSingleLogoImageModel from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_single_logo logger = logging.getLogger() +logging.getLogger("pika").setLevel(logging.WARNING) class GenerateSingleLogoImage: @@ -89,7 +90,6 @@ class GenerateSingleLogoImage: # seed seed = np.array(self.seed, dtype="object").reshape((-1, 1)) - print('seed: ', self.seed) input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype)) input_seed.set_data_from_numpy(seed) From c15358f23dbebeff3ac13f0612a94876d149d941 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Jun 2024 11:57:13 +0800 Subject: [PATCH 014/108] =?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/main.py | 1 + app/service/generate_image/service_generate_single_logo.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/app/main.py b/app/main.py index 07bd258..941ff79 100644 --- a/app/main.py +++ b/app/main.py @@ -9,6 +9,7 @@ from logging_env import LOGGER_CONFIG_DICT logging.config.dictConfig(LOGGER_CONFIG_DICT) +logging.getLogger("pika").setLevel(logging.WARNING) from starlette.middleware.cors import CORSMiddleware diff --git 
a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index 11348a2..cfc6902 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -24,7 +24,6 @@ from app.schemas.generate_image import GenerateSingleLogoImageModel from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_single_logo logger = logging.getLogger() -logging.getLogger("pika").setLevel(logging.WARNING) class GenerateSingleLogoImage: From a4c25c1977ae90cccbfdd6ce26bc2c10483628e4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Jun 2024 13:54:04 +0800 Subject: [PATCH 015/108] =?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_single_logo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index cfc6902..ea28099 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -114,7 +114,7 @@ class GenerateSingleLogoImage: finally: dict_generate_data, str_generate_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GEN_SINGLE_LOGO_RABBITMQ_QUEUES, body=str_generate_data) + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") From 40a2e158e2877ec892c6b3d2276b68f483385e8f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 4 Jun 2024 09:39:40 +0800 Subject: [PATCH 016/108] 
=?UTF-8?q?feat=20generate=20single=20logo=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_slogan.py | 24 ----------------------- app/schemas/slogan.py | 7 ------- app/service/slogan/service.py | 36 ----------------------------------- 3 files changed, 67 deletions(-) delete mode 100644 app/api/api_slogan.py delete mode 100644 app/schemas/slogan.py delete mode 100644 app/service/slogan/service.py diff --git a/app/api/api_slogan.py b/app/api/api_slogan.py deleted file mode 100644 index 31459ba..0000000 --- a/app/api/api_slogan.py +++ /dev/null @@ -1,24 +0,0 @@ -import logging -import time -from fastapi import APIRouter, BackgroundTasks - -from app.schemas.slogan import SloganModel -from app.service.slogan.service import Slogan - -router = APIRouter() -logger = logging.getLogger() - - -@router.post("/slogan") -def slogan(request_item: SloganModel, background_tasks: BackgroundTasks): - try: - logger.info(f"request data ### : {request_item}") - service = Slogan(request_item) - background_tasks.add_task(service.get_result) - code = 200 - message = "access" - except Exception as e: - code = 400 - message = e - logger.warning(e) - return {"code": code, "message": message} diff --git a/app/schemas/slogan.py b/app/schemas/slogan.py deleted file mode 100644 index e80423d..0000000 --- a/app/schemas/slogan.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - - -class SloganModel(BaseModel): - prompt: str - svg: str - tasks_id: str diff --git a/app/service/slogan/service.py b/app/service/slogan/service.py deleted file mode 100644 index 5a330d6..0000000 --- a/app/service/slogan/service.py +++ /dev/null @@ -1,36 +0,0 @@ -import json -import logging - -import redis - -from app.core.config import * - -logger = logging.getLogger() - - -class Slogan: - def __init__(self, request_data): - self.tasks_id = request_data.tasks_id - self.prompt = request_data.prompt - self.svg = 
request_data.svg - self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) - self.slogan_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} - self.redis_client.set(self.tasks_id, json.dumps(self.slogan_data)) - self.redis_client.expire(self.tasks_id, 600) - - # if DEBUG is False: - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() - self.result_image_url = "test/slogan/init_img.png" - - def read_tasks_status(self): - status_data = self.redis_client.get(self.tasks_id) - return json.loads(status_data), status_data - - def get_result(self): - self.slogan_data['status'] = "SUCCESS" - self.slogan_data['message'] = "success" - self.slogan_data['image_url'] = "test/slogan/init_img.png" - dict_slogan_data, str_slogan_data = self.read_tasks_status() - self.channel.basic_publish(exchange='', routing_key=SLOGAN_RABBITMQ_QUEUES, body=str_slogan_data) - logger.info(f" [x] Sent {json.dumps(dict_slogan_data, indent=4)}") From 1d94d485e993407eea7c75536fbf55891b4a862e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 4 Jun 2024 15:33:34 +0800 Subject: [PATCH 017/108] =?UTF-8?q?feat=20generate=20product=20image=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 5 +- app/schemas/generate_image.py | 3 - .../service_generate_product_image.py | 206 +++++++++--------- .../service_generate_single_logo.py | 6 +- .../generate_image/utils/upload_sd_image.py | 2 +- 5 files changed, 108 insertions(+), 114 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 2da37a3..62b9cdf 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" 
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -119,6 +119,9 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Single Logo service config GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"GenProductImage{RABBITMQ_ENV}") +GPI_MODEL_NAME = 'diffusion_ensemble_all' +GPI_MODEL_URL = '10.1.1.240:10061' + # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 49cf9ce..fee4a92 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -5,9 +5,6 @@ class GenerateImageModel(BaseModel): tasks_id: str prompt: str image_url: str - mode: str - category: str - gender: str class GenerateSingleLogoImageModel(BaseModel): diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index ea875bd..84f7940 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -7,38 +7,42 @@ @Date :2023/7/26 12:01:05 @detail : """ +import io import json import logging import time -from io import BytesIO - import cv2 -import minio import redis import tritonclient.grpc as grpcclient import numpy as np +from PIL import Image, ImageOps from minio import Minio from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.adjust_contrast import adjust_contrast -from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image logger = 
logging.getLogger() class GenerateProductImage: def __init__(self, request_data): - # if DEBUG is False: - # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # self.channel = self.connection.channel() - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.category = "product_image" + self.batch_size = 1 + self.prompt = request_data.prompt + # TODO aida design 结果图背景改为白色 + self.image, self.image_size = self.get_image(request_data.image_url) + # TODO image 填充并resize成512*768 + self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} @@ -46,63 +50,56 @@ class GenerateProductImage: self.redis_client.expire(self.tasks_id, 600) def get_image(self, image_url): - # Get data of an object. - # Read data from response. 
- # read image use cv2 - try: - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_file = BytesIO(response.data) - image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - image = cv2.resize(image_rbg, (1024, 1024)) - except minio.error.S3Error: - image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - return image + response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + image_bytes = io.BytesIO(response.read()) + + # 转换为PIL图像对象 + image = Image.open(image_bytes) + target_height = 768 + target_width = 512 + + aspect_ratio = image.width / image.height + new_width = int(target_height * aspect_ratio) + + resized_image = image.resize((new_width, target_height)) + left = (target_width - resized_image.width) // 2 + top = (target_height - resized_image.height) // 2 + right = target_width - resized_image.width - left + bottom = target_height - resized_image.height - top + image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") + image_size = image.size + if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 创建白色背景 + background = Image.new("RGB", image.size, (255, 255, 255)) + # 将图片粘贴到白色背景上 + background.paste(image, mask=image.split()[3]) + image = np.array(background) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # image_file = BytesIO(response.data) + # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + # image = cv2.resize(image_rbg, (1024, 1024)) + return image, image_size def callback(self, result, error): if error: - self.generate_data['status'] = "FAILURE" - self.generate_data['message'] = str(error) - # 
self.generate_data['data'] = str(error) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(error) + # self.gen_product_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_image") - image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) - is_smudge = True - if self.category == "sketch": - # 色阶调整 - cutoff = 1 - levels_img = autoLevels(image_result, cutoff) - # 亮度调整 - luminance = luminance_adjust(0.3, levels_img) - # 去背景 - remove_bg_image = remove_background(luminance) - # 人脸检测 - if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: - is_smudge = False - else: - # 污点/ - is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) - # 类型识别 - category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) - self.generate_data['category'] = str(category) - image_result = not_smudge_image - if is_smudge: # 无污点 - # image_result = adjust_contrast(image_result) - image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") - self.generate_data['status'] = "SUCCESS" - self.generate_data['message'] = "success" - self.generate_data['image_url'] = str(image_url) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) - else: # 有污点 保存图片到本地 测试用 - self.generate_data['status'] = "SUCCESS" - self.generate_data['message'] = "success" - self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) - # logger.info(f"stain_detection result : {self.generate_data}") + image = result.as_numpy("generated_inpaint_image") + image_result 
= Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.gen_product_data['status'] = "SUCCESS" + self.gen_product_data['message'] = "success" + self.gen_product_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) def read_tasks_status(self): status_data = self.redis_client.get(self.tasks_id) @@ -110,46 +107,43 @@ class GenerateProductImage: def infer(self, inputs): return self.grpc_client.async_infer( - model_name=GI_MODEL_NAME, + model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback ) def get_result(self): try: - # prompts = [self.prompt] * self.batch_size - # modes = [self.mode] * self.batch_size - # images = [self.image.astype(np.float16)] * self.batch_size - # - # text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - # mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) - # image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) - # - # input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - # input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") - # input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) - # - # input_text.set_data_from_numpy(text_obj) - # input_image.set_data_from_numpy(image_obj) - # input_mode.set_data_from_numpy(mode_obj) - # - # inputs = [input_text, input_image, input_mode] - # ctx = self.infer(inputs) - # time_out = 600 - # generate_data = None - # while time_out > 0: - # generate_data, _ = self.read_tasks_status() - # # logger.info(generate_data) - # if generate_data['status'] in ["REVOKED", "FAILURE"]: - # ctx.cancel() - # break - # elif generate_data['status'] == "SUCCESS": - # break - # time_out -= 1 - # 
time.sleep(0.1) - # # logger.info(time_out, generate_data) - generate_data, _ = self.read_tasks_status() - return generate_data + prompts = [self.prompt] * self.batch_size + self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + self.image = cv2.resize(self.image, (512, 768)) + images = [self.image.astype(np.uint8)] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape(1) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + inputs = [input_text, input_image] + + ctx = self.infer(inputs) + time_out = 600 + while time_out > 0: + gen_product_data, _ = self.read_tasks_status() + # logger.info(gen_product_data) + if gen_product_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif gen_product_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + # logger.info(time_out, gen_product_data) + gen_product_data, _ = self.read_tasks_status() + return gen_product_data except Exception as e: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(e) @@ -157,25 +151,25 @@ class GenerateProductImage: raise Exception(str(e)) finally: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() - # if DEBUG is False: - # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) - self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_gen_product_data) + # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent 
{json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} - generate_data = json.dumps(data) - redis_client.set(tasks_id, generate_data) + gen_product_data = json.dumps(data) + redis_client.set(tasks_id, gen_product_data) return data if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", - prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', - image_url="", + prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png", ) - server = GenerateImage(rd) + server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index ea28099..f3d1719 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -21,7 +21,7 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * import tritonclient.grpc as grpcclient from app.schemas.generate_image import GenerateSingleLogoImageModel -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_single_logo +from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image logger = logging.getLogger() @@ -67,7 +67,7 @@ class GenerateSingleLogoImage: else: image = result.as_numpy("generated_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_single_logo(image_result, user_id=self.user_id, 
category=f"{self.category}", object_name=f"{self.tasks_id}.png") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") self.gen_single_logo_data['status'] = "SUCCESS" self.gen_single_logo_data['message'] = "success" self.gen_single_logo_data['image_url'] = str(image_url) @@ -131,7 +131,7 @@ if __name__ == '__main__': rd = GenerateSingleLogoImageModel( tasks_id="123-89", prompt='an apple', - seed="1", + seed="2", ) server = GenerateSingleLogoImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index 7cb7f3e..2773aa2 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -34,7 +34,7 @@ s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S # except Exception as e: # print(f'上传到 S3 失败: {e}') -def upload_single_logo(image, user_id, category, object_name): +def upload_SDXL_image(image, user_id, category, object_name): try: image_data = io.BytesIO() image.save(image_data, format='PNG') From dedd65e01740c4e5c4c4a62bc8a24a00e0b9cc7f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 4 Jun 2024 16:21:17 +0800 Subject: [PATCH 018/108] =?UTF-8?q?feat=20generate=20product=20image=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 62b9cdf..99be6e4 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -118,7 +118,7 @@ GSL_MODEL_NAME = 'stable_diffusion_xl' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Single Logo service config -GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", 
f"GenProductImage{RABBITMQ_ENV}") +GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME = 'diffusion_ensemble_all' GPI_MODEL_URL = '10.1.1.240:10061' From 829fb646d37e5c26b95da6bf4979e75ad3daad27 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 4 Jun 2024 16:23:40 +0800 Subject: [PATCH 019/108] =?UTF-8?q?feat=20generate=20product=20image=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 4 ++-- app/service/generate_image/service_generate_product_image.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 99be6e4..30bdd30 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -118,7 +118,7 @@ GSL_MODEL_NAME = 'stable_diffusion_xl' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Single Logo service config -GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") +GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME = 'diffusion_ensemble_all' GPI_MODEL_URL = '10.1.1.240:10061' diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 84f7940..5d6908f 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -152,7 +152,7 @@ class GenerateProductImage: finally: dict_gen_product_data, str_gen_product_data = 
self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_gen_product_data) + self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent {json.dumps(dict_gen_product_data, indent=4)}") From 019d6a73ee8de7baa7608c25a9ec1f23994a9492 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 4 Jun 2024 16:25:26 +0800 Subject: [PATCH 020/108] =?UTF-8?q?feat=20generate=20product=20image=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_product_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 5d6908f..ce449ea 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -154,7 +154,7 @@ class GenerateProductImage: if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) - logger.info(f" [x] Sent {json.dumps(dict_gen_product_data, indent=4)}") + logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): From 0544fab1356711dd2f87bbed5fff958a748bab34 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 5 Jun 2024 14:44:33 +0800 Subject: [PATCH 021/108] =?UTF-8?q?feat=20generate=20product=20image=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
app/api/api_route.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/app/api/api_route.py b/app/api/api_route.py index 45ce4b3..c2bd2d2 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -8,7 +8,6 @@ from app.api import api_design from app.api import api_chat_robot from app.api import api_prompt_generation from app.api import api_design_pre_processing -from app.api import api_slogan router = APIRouter() @@ -21,4 +20,3 @@ router.include_router(api_design.router, tags=['design'], prefix="/api") router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api") router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api") router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") -router.include_router(api_slogan.router, tags=['slogan'], prefix="/api") From aec03be2c98ecc1f90714e94424f505699ecbb39 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 12 Jun 2024 15:32:25 +0800 Subject: [PATCH 022/108] =?UTF-8?q?feat=20generate=20product=20image=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index fee4a92..49cf9ce 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -5,6 +5,9 @@ class GenerateImageModel(BaseModel): tasks_id: str prompt: str image_url: str + mode: str + category: str + gender: str class GenerateSingleLogoImageModel(BaseModel): From e94ea11ccda7b935207567667757d55d72bc81b6 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 13 Jun 2024 11:17:15 +0800 Subject: [PATCH 023/108] =?UTF-8?q?feat=20fix=20=E7=BF=BB=E8=AF=91?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E4=BF=AE=E6=94=B9template?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
app/service/prompt_generation/chatgpt_for_translation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index b9c2c80..71d6e4f 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -21,7 +21,7 @@ def translate_to_en(text): """You are a translation expert, proficient in various languages. And can translate various languages into English. Please translate to grammatically correct English regardless of the input language. - If the input is in English, check for grammatical errors. If there are no errors, simply output the sentence. + If the input is in English or numbers, check for grammatical errors. If there are no errors, output the input directly. If there are grammatical errors, correct them and then output the sentence.""" ) system_message_prompt = SystemMessagePromptTemplate.from_template(template) From 49807a3bf3aa1ac9fb57d77fe6abf65d99477dc3 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 13 Jun 2024 13:35:49 +0800 Subject: [PATCH 024/108] =?UTF-8?q?feat=20fix=20=E7=BF=BB=E8=AF=91?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E4=BF=AE=E6=94=B9template?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 4 ++-- app/api/api_chat_robot.py | 2 +- app/api/api_design.py | 2 +- app/api/api_design_pre_processing.py | 2 +- app/api/api_generate_image.py | 12 ++++++------ app/api/api_prompt_generation.py | 2 +- app/api/api_super_resolution.py | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index fdecfa8..5b18ec5 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -18,7 +18,7 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): service 
= AttributeRecognition(const=const, request_data=request_item) data = service.get_result() code = 200 - message = "access" + message = "OK!" logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: code = 400 @@ -35,7 +35,7 @@ def category_recognition(request_item: list[CategoryRecognitionModel]): service = CategoryRecognition(request_data=request_item) data = service.get_result() code = 200 - message = "access" + message = "OK!" logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: code = 400 diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py index c394046..d919902 100644 --- a/app/api/api_chat_robot.py +++ b/app/api/api_chat_robot.py @@ -14,7 +14,7 @@ def chat_robot(request_data: ChatRobotModel): try: logger.info(f"chat_robot request item is : @@@@@@:{request_data}") code = 200 - message = "access" + message = "OK!" start_time = time.time() data = chat(post_data=request_data) logger.info(f"chat_robot Run time is @@@@@@:{time.time() - start_time}") diff --git a/app/api/api_design.py b/app/api/api_design.py index 0c48c81..79cf7e1 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -15,7 +15,7 @@ def design(request_data: DesignModel): try: logger.info(f"design request item is : @@@@@@:{request_data}") code = 200 - message = "access" + message = "OK!" start_time = time.time() data = generate(request_data=request_data) logger.info(f"design Run time is @@@@@@:{time.time() - start_time}") diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py index 0c0089d..91bba6e 100644 --- a/app/api/api_design_pre_processing.py +++ b/app/api/api_design_pre_processing.py @@ -15,7 +15,7 @@ def design_pre_processing(request_data: DesignPreProcessingModel): try: logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}") code = 200 - message = "access" + message = "OK!" 
start_time = time.time() server = DesignPreprocessing() data = server.pipeline(image_list=request_data.sketches) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index a74eb1b..9c979f1 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -18,7 +18,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun service = GenerateImage(request_item) background_tasks.add_task(service.get_result) code = 200 - message = "access" + message = "OK!" except Exception as e: code = 400 message = e @@ -29,7 +29,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun @router.get("/generate_cancel/{tasks_id}>") def generate_image(tasks_id): result = generate_image_infer_cancel(tasks_id) - return {"code": 200, "message": result['message'], "data": result['data']} + return {"code": 200, "message": "OK!", "data": result['data']} '''single logo''' @@ -42,7 +42,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_ service = GenerateSingleLogoImage(request_item) background_tasks.add_task(service.get_result) code = 200 - message = "access" + message = "OK!" except Exception as e: code = 400 message = e @@ -53,7 +53,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_ @router.get("/generate_single_logo_cancel/{tasks_id}>") def generate_single_logo_image(tasks_id): result = generate_single_logo_cancel(tasks_id) - return {"code": 200, "message": result['message'], "data": result['data']} + return {"code": 200, "message": "OK!", "data": result['data']} '''product image''' @@ -66,7 +66,7 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t service = GenerateProductImage(request_item) background_tasks.add_task(service.get_result) code = 200 - message = "access" + message = "OK!" 
except Exception as e: code = 400 message = e @@ -77,4 +77,4 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t @router.get("/generate_product_image_cancel_cancel/{tasks_id}>") def generate_single_logo_image(tasks_id): result = generate_product_image_cancel(tasks_id) - return {"code": 200, "message": result['message'], "data": result['data']} + return {"code": 200, "message": "OK!", "data": result['data']} diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index 5e71eec..8b930a1 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -15,7 +15,7 @@ def prompt_generation(request_data: PromptGenerationImageModel): try: logger.info(f"prompt_translate to English request data : @@@@@@:{request_data}") code = 200 - message = "access" + message = "OK!" start_time = time.time() data = translate_to_en(request_data.text) logger.info(f"prompt_generation Run time is @@@@@@:{time.time() - start_time}") diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py index 63f4498..0d14ad7 100644 --- a/app/api/api_super_resolution.py +++ b/app/api/api_super_resolution.py @@ -16,7 +16,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg service = SuperResolution(request_item) background_tasks.add_task(service.sr_result) code = 200 - message = "access" + message = "OK!" 
except Exception as e: code = 400 message = e From 6034a3539b892ca2b552565f5e705895e7e351d8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 13 Jun 2024 13:43:34 +0800 Subject: [PATCH 025/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 4 ++-- app/api/api_chat_robot.py | 2 +- app/api/api_design.py | 2 +- app/api/api_design_pre_processing.py | 2 +- app/api/api_generate_image.py | 12 ++++++------ app/api/api_prompt_generation.py | 2 +- app/api/api_super_resolution.py | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 5b18ec5..7dd4afb 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -25,7 +25,7 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): message = e data = e logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") - return {"code": code, "message": message, "data": data} + return {"code": code, "msg": message, "data": data} # 类别识别 @@ -42,4 +42,4 @@ def category_recognition(request_item: list[CategoryRecognitionModel]): message = e data = e logger.warning(f"category_recognition Run Exception @@@@@@:{e}") - return {"code": code, "message": message, "data": data} + return {"code": code, "msg": message, "data": data} diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py index d919902..a0158da 100644 --- a/app/api/api_chat_robot.py +++ b/app/api/api_chat_robot.py @@ -24,4 +24,4 @@ def chat_robot(request_data: ChatRobotModel): data = str(e) logger.warning(f"chat_robot Run Exception @@@@@@:{e}") logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "message": message, "data": data} + return {"code": code, "msg": message, "data": data} diff --git 
a/app/api/api_design.py b/app/api/api_design.py index 79cf7e1..a224067 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -25,4 +25,4 @@ def design(request_data: DesignModel): data = str(e) logger.warning(f"design Run Exception @@@@@@:{e}") logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "message": message, "data": data} \ No newline at end of file + return {"code": code, "msg": message, "data": data} \ No newline at end of file diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py index 91bba6e..38d051a 100644 --- a/app/api/api_design_pre_processing.py +++ b/app/api/api_design_pre_processing.py @@ -26,4 +26,4 @@ def design_pre_processing(request_data: DesignPreProcessingModel): data = str(e) logger.warning(f"design Run Exception @@@@@@:{e}") logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "message": message, "data": data} + return {"code": code, "msg": message, "data": data} diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 9c979f1..1d2f3a9 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -23,13 +23,13 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun code = 400 message = e logger.warning(e) - return {"code": code, "message": message} + return {"code": code, "msg": message} @router.get("/generate_cancel/{tasks_id}>") def generate_image(tasks_id): result = generate_image_infer_cancel(tasks_id) - return {"code": 200, "message": "OK!", "data": result['data']} + return {"code": 200, "msg": "OK!", "data": result['data']} '''single logo''' @@ -47,13 +47,13 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_ code = 400 message = e logger.warning(e) - return {"code": code, "message": message} + return {"code": code, "msg": message} @router.get("/generate_single_logo_cancel/{tasks_id}>") def 
generate_single_logo_image(tasks_id): result = generate_single_logo_cancel(tasks_id) - return {"code": 200, "message": "OK!", "data": result['data']} + return {"code": 200, "msg": "OK!", "data": result['data']} '''product image''' @@ -71,10 +71,10 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t code = 400 message = e logger.warning(e) - return {"code": code, "message": message} + return {"code": code, "msg": message} @router.get("/generate_product_image_cancel_cancel/{tasks_id}>") def generate_single_logo_image(tasks_id): result = generate_product_image_cancel(tasks_id) - return {"code": 200, "message": "OK!", "data": result['data']} + return {"code": 200, "msg": "OK!", "data": result['data']} diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index 8b930a1..a963bde 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -25,4 +25,4 @@ def prompt_generation(request_data: PromptGenerationImageModel): data = str(e) logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "message": message, "data": data} + return {"code": code, "msg": message, "data": data} diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py index 0d14ad7..03311d1 100644 --- a/app/api/api_super_resolution.py +++ b/app/api/api_super_resolution.py @@ -21,10 +21,10 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg code = 400 message = e logger.warning(e) - return {"code": code, "message": message} + return {"code": code, "msg": message} @router.get("/sr_cancel/{tasks_id}>") def super_resolution(tasks_id): result = infer_cancel(tasks_id) - return {"code": 200, "message": result['message'], "data": result['data']} + return {"code": 200, "msg": result['message'], "data": result['data']} From b012b91613c480193de21b865491c4521cd0980d Mon Sep 17 00:00:00 2001 
From: zhouchengrong Date: Thu, 13 Jun 2024 14:31:14 +0800 Subject: [PATCH 026/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 23 ++++------ app/api/api_chat_robot.py | 16 +++---- app/api/api_design.py | 16 +++---- app/api/api_design_pre_processing.py | 19 +++----- app/api/api_generate_image.py | 67 ++++++++++++++++------------ app/api/api_prompt_generation.py | 18 +++----- app/api/api_super_resolution.py | 23 ++++++---- app/api/api_test.py | 13 ++++-- app/core/config.py | 2 +- app/main.py | 15 ++++++- app/schemas/response_template.py | 8 ++++ 11 files changed, 120 insertions(+), 100 deletions(-) create mode 100644 app/schemas/response_template.py diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 7dd4afb..267b796 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,8 +1,9 @@ import json import logging -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.schemas.attribute_retrieve import * +from app.schemas.response_template import ResponseModel from app.service.attribute.config import const from app.service.attribute.service_att_recognition import AttributeRecognition from app.service.attribute.service_category_recognition import CategoryRecognition @@ -12,34 +13,28 @@ logger = logging.getLogger() # 属性识别 -@router.post("/attribute_recognition") +@router.post("/attribute_recognition", response_model=ResponseModel) def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: + logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() - code = 200 - message = "OK!" 
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: - code = 400 - message = e - data = e logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") - return {"code": code, "msg": message, "data": data} + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) # 类别识别 @router.post("/category_recognition") def category_recognition(request_item: list[CategoryRecognitionModel]): try: + logger.info(f"category_recognition request item is : @@@@@@:{request_item}") service = CategoryRecognition(request_data=request_item) data = service.get_result() - code = 200 - message = "OK!" logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: - code = 400 - message = e - data = e logger.warning(f"category_recognition Run Exception @@@@@@:{e}") - return {"code": code, "msg": message, "data": data} + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py index a0158da..dccba9a 100644 --- a/app/api/api_chat_robot.py +++ b/app/api/api_chat_robot.py @@ -1,8 +1,10 @@ +import json import logging import time -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.schemas.chat_robot import ChatRobotModel +from app.schemas.response_template import ResponseModel from app.service.chat_robot.script.main import chat router = APIRouter() @@ -13,15 +15,9 @@ logger = logging.getLogger() def chat_robot(request_data: ChatRobotModel): try: logger.info(f"chat_robot request item is : @@@@@@:{request_data}") - code = 200 - message = "OK!" 
- start_time = time.time() data = chat(post_data=request_data) - logger.info(f"chat_robot Run time is @@@@@@:{time.time() - start_time}") + logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: - code = 400 - message = str(e) - data = str(e) logger.warning(f"chat_robot Run Exception @@@@@@:{e}") - logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "msg": message, "data": data} + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_design.py b/app/api/api_design.py index a224067..c77d4c2 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -1,9 +1,11 @@ +import json import logging import time -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.schemas.design import DesignModel +from app.schemas.response_template import ResponseModel from app.service.design.service import generate router = APIRouter() @@ -14,15 +16,9 @@ logger = logging.getLogger() def design(request_data: DesignModel): try: logger.info(f"design request item is : @@@@@@:{request_data}") - code = 200 - message = "OK!" 
- start_time = time.time() data = generate(request_data=request_data) - logger.info(f"design Run time is @@@@@@:{time.time() - start_time}") + logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: - code = 400 - message = str(e) - data = str(e) logger.warning(f"design Run Exception @@@@@@:{e}") - logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "msg": message, "data": data} \ No newline at end of file + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py index 38d051a..bd87e00 100644 --- a/app/api/api_design_pre_processing.py +++ b/app/api/api_design_pre_processing.py @@ -1,9 +1,8 @@ +import json import logging -import time - -from fastapi import APIRouter - +from fastapi import APIRouter, HTTPException from app.schemas.pre_processing import DesignPreProcessingModel +from app.schemas.response_template import ResponseModel from app.service.design_pre_processing.service import DesignPreprocessing router = APIRouter() @@ -14,16 +13,10 @@ logger = logging.getLogger() def design_pre_processing(request_data: DesignPreProcessingModel): try: logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}") - code = 200 - message = "OK!" 
- start_time = time.time() server = DesignPreprocessing() data = server.pipeline(image_list=request_data.sketches) - logger.info(f"design_pre_processing Run time is @@@@@@:{time.time() - start_time}") + logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: - code = 400 - message = str(e) - data = str(e) logger.warning(f"design Run Exception @@@@@@:{e}") - logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "msg": message, "data": data} + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 1d2f3a9..8c706a1 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -1,6 +1,8 @@ +import json import logging -from fastapi import APIRouter, BackgroundTasks +from fastapi import APIRouter, BackgroundTasks, HTTPException from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel +from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel @@ -14,22 +16,25 @@ logger = logging.getLogger() @router.post("/generate_image") def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks): try: - logger.info(f"request data ### : {request_item}") + logger.info(f"generate_image request item is : @@@@@@:{request_item}") service = GenerateImage(request_item) background_tasks.add_task(service.get_result) - code = 200 - message = "OK!" 
except Exception as e: - code = 400 - message = e - logger.warning(e) - return {"code": code, "msg": message} + logger.warning(f"generate_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() @router.get("/generate_cancel/{tasks_id}>") def generate_image(tasks_id): - result = generate_image_infer_cancel(tasks_id) - return {"code": 200, "msg": "OK!", "data": result['data']} + try: + logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") + data = generate_image_infer_cancel(tasks_id) + logger.info(f"generate_cancel response @@@@@@:{json.dumps(data, indent=4)}") + except Exception as e: + logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) '''single logo''' @@ -38,22 +43,25 @@ def generate_image(tasks_id): @router.post("/generate_single_logo") def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_tasks: BackgroundTasks): try: - logger.info(f"request data ### : {request_item}") + logger.info(f"generate_single_logo request item is : @@@@@@:{request_item}") service = GenerateSingleLogoImage(request_item) background_tasks.add_task(service.get_result) - code = 200 - message = "OK!" 
except Exception as e: - code = 400 - message = e - logger.warning(e) - return {"code": code, "msg": message} + logger.warning(f"generate_single_logo Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() @router.get("/generate_single_logo_cancel/{tasks_id}>") def generate_single_logo_image(tasks_id): - result = generate_single_logo_cancel(tasks_id) - return {"code": 200, "msg": "OK!", "data": result['data']} + try: + logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}") + data = generate_single_logo_cancel(tasks_id) + logger.info(f"generate_single_logo_cancel response @@@@@@:{json.dumps(data, indent=4)}") + except Exception as e: + logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) '''product image''' @@ -62,19 +70,22 @@ def generate_single_logo_image(tasks_id): @router.post("/generate_product_image") def generate_product_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): try: - logger.info(f"request data ### : {request_item}") + logger.info(f"generate_product_image request item is : @@@@@@:{request_item}") service = GenerateProductImage(request_item) background_tasks.add_task(service.get_result) - code = 200 - message = "OK!" 
except Exception as e: - code = 400 - message = e - logger.warning(e) - return {"code": code, "msg": message} + logger.warning(f"generate_product_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() @router.get("/generate_product_image_cancel_cancel/{tasks_id}>") def generate_single_logo_image(tasks_id): - result = generate_product_image_cancel(tasks_id) - return {"code": 200, "msg": "OK!", "data": result['data']} + try: + logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}") + data = generate_single_logo_cancel(tasks_id) + logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}") + except Exception as e: + logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index a963bde..292ad2e 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -1,9 +1,11 @@ +import json import logging import time -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.schemas.prompt_generation import PromptGenerationImageModel +from app.schemas.response_template import ResponseModel from app.service.prompt_generation.chatgpt_for_translation import translate_to_en router = APIRouter() @@ -13,16 +15,10 @@ logger = logging.getLogger() @router.post("/translateToEN") def prompt_generation(request_data: PromptGenerationImageModel): try: - logger.info(f"prompt_translate to English request data : @@@@@@:{request_data}") - code = 200 - message = "OK!" 
- start_time = time.time() + logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") data = translate_to_en(request_data.text) - logger.info(f"prompt_generation Run time is @@@@@@:{time.time() - start_time}") + logger.info(f"prompt_generation response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: - code = 400 - message = str(e) - data = str(e) logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") - logger.info({"code": code, "message": message, "data": data}) - return {"code": code, "msg": message, "data": data} + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py index 03311d1..e68cbe0 100644 --- a/app/api/api_super_resolution.py +++ b/app/api/api_super_resolution.py @@ -1,8 +1,9 @@ import json import logging -from fastapi import APIRouter, BackgroundTasks +from fastapi import APIRouter, BackgroundTasks, HTTPException +from app.schemas.response_template import ResponseModel from app.schemas.super_resolution import SuperResolutionModel from app.service.super_resolution.service import SuperResolution, infer_cancel @@ -13,18 +14,22 @@ logger = logging.getLogger() @router.post("/super_resolution") def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks): try: + logger.info(f"super_resolution request item is : @@@@@@:{request_item}") service = SuperResolution(request_item) background_tasks.add_task(service.sr_result) - code = 200 - message = "OK!" 
except Exception as e: - code = 400 - message = e - logger.warning(e) - return {"code": code, "msg": message} + logger.warning(f"super_resolution Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() @router.get("/sr_cancel/{tasks_id}>") def super_resolution(tasks_id): - result = infer_cancel(tasks_id) - return {"code": 200, "msg": result['message'], "data": result['data']} + try: + logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}") + data = infer_cancel(tasks_id) + logger.info(f"sr_cancel response @@@@@@:{json.dumps(data, indent=4)}") + except Exception as e: + logger.warning(f"sr_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) diff --git a/app/api/api_test.py b/app/api/api_test.py index 63ef1aa..739c08f 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,13 +1,20 @@ import logging from fastapi import APIRouter from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES +from fastapi import FastAPI, HTTPException + +from app.schemas.response_template import ResponseModel logger = logging.getLogger() router = APIRouter() -@router.get("") -def test(): +@router.get("{id}") +def test(id: int): logger.info(SR_RABBITMQ_QUEUES) logger.info("test") - return {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES} + data = {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES} + if id == 1: + raise HTTPException(status_code=404, detail="Item not found") + + return ResponseModel(data=data) diff --git a/app/core/config.py b/app/core/config.py index 30bdd30..cfcfa0b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = 
"service/attribute/config/descriptor/category/category_dis.csv" diff --git a/app/main.py b/app/main.py index 941ff79..b085d7d 100644 --- a/app/main.py +++ b/app/main.py @@ -1,13 +1,16 @@ import logging.config +from http.client import HTTPException +from fastapi.responses import JSONResponse +from fastapi import FastAPI, HTTPException, Request import uvicorn from fastapi import FastAPI from app.api.api_route import router from app.core.config import settings +from app.schemas.response_template import ResponseModel from logging_env import LOGGER_CONFIG_DICT - logging.config.dictConfig(LOGGER_CONFIG_DICT) logging.getLogger("pika").setLevel(logging.WARNING) @@ -36,5 +39,15 @@ def get_application() -> FastAPI: app = get_application() + + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + return JSONResponse( + status_code=exc.status_code, + content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict() + ) + + if __name__ == '__main__': uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/app/schemas/response_template.py b/app/schemas/response_template.py new file mode 100644 index 0000000..b3b773c --- /dev/null +++ b/app/schemas/response_template.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel +from typing import Any, Optional + + +class ResponseModel(BaseModel): + code: int = 200 + msg: str = "OK!" 
+ data: Optional[Any] = None From e1bc95c38077562fbd19c26488c2d88f4241d7b0 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 13 Jun 2024 17:33:53 +0800 Subject: [PATCH 027/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index cfcfa0b..30bdd30 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 5a153124002b54958b1a27529d61e109a16ac16c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 14 Jun 2024 14:52:25 +0800 Subject: [PATCH 028/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/painting.py | 2 +- app/service/generate_image/utils/image_processing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index d1f6957..aa310a3 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -161,7 +161,7 @@ class PrintPainting(object): print_background = np.array(source_image_pil) mask_background = np.array(source_image_pil_mask) - print(1) + # print(1) else: mask = self.get_mask_inv(image) mask = np.expand_dims(mask, axis=2) diff --git a/app/service/generate_image/utils/image_processing.py b/app/service/generate_image/utils/image_processing.py index 2883129..14b9b9e 100644 --- 
a/app/service/generate_image/utils/image_processing.py +++ b/app/service/generate_image/utils/image_processing.py @@ -352,7 +352,7 @@ if __name__ == '__main__': remove_bg_img = remove_background(luminance) # cv2.imwrite("remove_bg_img.png", remove_bg_img) - print(1) + # print(1) cv2.imshow("source", img) cv2.imshow("levels", equAuto) cv2.imshow("luminance", luminance) From b3081359b7d412a8355226122f0e67708ec0d144 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 14 Jun 2024 14:52:48 +0800 Subject: [PATCH 029/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 36 +++++++++++++++++++++++++++++---- app/api/api_super_resolution.py | 2 +- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 8c706a1..07303ad 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -5,6 +5,7 @@ from app.schemas.generate_image import GenerateImageModel, GenerateProductImageM from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel +from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel router = APIRouter() @@ -26,7 +27,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun @router.get("/generate_cancel/{tasks_id}>") -def generate_image(tasks_id): +def generate_image(tasks_id: str): try: 
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") data = generate_image_infer_cancel(tasks_id) @@ -53,7 +54,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_ @router.get("/generate_single_logo_cancel/{tasks_id}>") -def generate_single_logo_image(tasks_id): +def generate_single_logo_image(tasks_id: str): try: logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}") data = generate_single_logo_cancel(tasks_id) @@ -80,12 +81,39 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t @router.get("/generate_product_image_cancel_cancel/{tasks_id}>") -def generate_single_logo_image(tasks_id): +def generate_product_image(tasks_id: str): try: logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}") - data = generate_single_logo_cancel(tasks_id) + data = generate_product_image_cancel(tasks_id) logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data['data']) + + +'''relight image''' + + +@router.post("/generate_relight_image") +def generate_relight_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): + try: + logger.info(f"generate_relight_image request item is : @@@@@@:{request_item}") + service = GenerateRelightImage(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_relight_image_cancel_cancel/{tasks_id}>") +def generate_relight_image(tasks_id: str): + try: + logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}") + data = 
generate_relight_image_cancel(tasks_id) + logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}") + except Exception as e: + logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py index e68cbe0..7928309 100644 --- a/app/api/api_super_resolution.py +++ b/app/api/api_super_resolution.py @@ -24,7 +24,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg @router.get("/sr_cancel/{tasks_id}>") -def super_resolution(tasks_id): +def super_resolution(tasks_id: str): try: logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}") data = infer_cancel(tasks_id) From 756894baff6dbc9671951b9cfcc0fbe39482d90d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 10:45:45 +0800 Subject: [PATCH 030/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 4 + app/schemas/generate_image.py | 6 + .../design/items/pipelines/painting.py | 121 ++++++++++- app/service/design/items/pipelines/scale.py | 2 +- .../service_generate_relight_image.py | 202 ++++++++++++++++++ 5 files changed, 326 insertions(+), 9 deletions(-) create mode 100644 app/service/generate_image/service_generate_relight_image.py diff --git a/app/core/config.py b/app/core/config.py index 30bdd30..651dd8b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -122,6 +122,10 @@ GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProduct GPI_MODEL_NAME = 'diffusion_ensemble_all' GPI_MODEL_URL = '10.1.1.240:10061' +# Generate Single Logo service config +GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") 
+GRI_MODEL_NAME = 'stable_diffusion_1_5' +GRI_MODEL_URL = '10.1.1.150:8001' # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 49cf9ce..4f85002 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -20,3 +20,9 @@ class GenerateProductImageModel(BaseModel): tasks_id: str prompt: str image_url: str + + +class GenerateRelightImageModel(BaseModel): + tasks_id: str + prompt: str + image_url: str diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index aa310a3..43b42e4 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -152,16 +152,14 @@ class PrintPainting(object): rotated_resized_source = resized_source.rotate(result['print']['print_angle_list'][i]) rotated_resized_source_mask = resized_source_mask.rotate(result['print']['print_angle_list'][i]) - source_image_pil = Image.fromarray(print_background) - source_image_pil_mask = Image.fromarray(mask_background) + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source) source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask) - print_background = np.array(source_image_pil) - mask_background = np.array(source_image_pil_mask) - - # print(1) + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) else: mask = self.get_mask_inv(image) mask = np.expand_dims(mask, axis=2) @@ -241,7 +239,6 @@ class 
PrintPainting(object): temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) result['single_image'] = cv2.add(tmp1, tmp2) - return result else: painting_dict = {} painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2] @@ -260,7 +257,113 @@ class PrintPainting(object): temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) result['single_image'] = cv2.add(tmp1, tmp2) - return result + + if "element" in result.keys(): + print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) + mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) + for i in range(len(result['element']['element_path_list'])): + image, image_mode = self.read_image(result['element']['element_path_list'][i]) + if image_mode == "RGBA": + new_size = (int(image.width * result['element']['element_scale_list'][i]), int(image.height * result['element']['element_scale_list'][i])) + + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) + + rotated_resized_source = resized_source.rotate(result['element']['element_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(result['element']['element_angle_list'][i]) + + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + + source_image_pil.paste(rotated_resized_source, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source_mask) + + 
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + print(1) + else: + mask = self.get_mask_inv(image) + mask = np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) + x, y = int(result['element']['location'][i][0] - rotated_new_size[0]), int(result['element']['location'][i][1] - rotated_new_size[1]) + + image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] + + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, :x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] + + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 + + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 + else: + start_x = x + + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y + + # ------------------ + # 如果print-size大于image-size 则需要裁剪print + + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + 
rotate_mask = rotate_mask[:, :image_x - x] + + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + + # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) + # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + + print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) + img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) + # TODO element 丢失信息 + three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)]) + img_bg = cv2.bitwise_and(result['final_image'], three_channel_image) + # mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + # gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + # img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + result['final_image'] = cv2.add(img_bg, img_fg) + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * 
(temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) + return result @staticmethod def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x): @@ -301,6 +404,7 @@ class PrintPainting(object): return painting_dict def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False): + tile = None if not trigger: tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA) else: @@ -351,6 +455,7 @@ class PrintPainting(object): print_mask = result['mask'] img_fg = result['final_image'] if print_ and not painting_dict['Trigger']: + index_ = None try: index_ = len(painting_dict['location']) except: diff --git a/app/service/design/items/pipelines/scale.py b/app/service/design/items/pipelines/scale.py index 6e0cf87..80009e1 100644 --- a/app/service/design/items/pipelines/scale.py +++ b/app/service/design/items/pipelines/scale.py @@ -25,7 +25,7 @@ class Scaling(object): # # distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1) if distance_clo == 0: - result['scale'] = 10 + result['scale'] = 1 else: result['scale'] = distance_bdy / distance_clo elif result['keypoint'] == 'toe': diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py new file mode 100644 index 0000000..0eacec9 --- /dev/null +++ b/app/service/generate_image/service_generate_relight_image.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import io +import json +import logging +import time +import cv2 +import redis +import tritonclient.grpc as grpcclient +import 
numpy as np +from PIL import Image, ImageOps +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.generate_image import GenerateRelightImageModel +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image + +logger = logging.getLogger() + + +class GenerateRelightImage: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.category = "relight_image" + self.batch_size = 1 + self.prompt = request_data.prompt + self.seed = "12345" + # TODO aida design 结果图背景改为白色 + # self.image, self.image_size = self.get_image(request_data.image_url) + self.image = request_data.image_url + # TODO image 填充并resize成512*768 + + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + image_bytes = io.BytesIO(response.read()) + + # 转换为PIL图像对象 + image = Image.open(image_bytes) + target_height = 768 + target_width = 512 + + aspect_ratio = image.width / image.height + new_width = int(target_height * aspect_ratio) + + resized_image = image.resize((new_width, target_height)) + left = (target_width - resized_image.width) // 2 + top = (target_height - 
resized_image.height) // 2 + right = target_width - resized_image.width - left + bottom = target_height - resized_image.height - top + image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") + image_size = image.size + if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 创建白色背景 + background = Image.new("RGB", image.size, (255, 255, 255)) + # 将图片粘贴到白色背景上 + background.paste(image, mask=image.split()[3]) + image = np.array(background) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # image_file = BytesIO(response.data) + # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + # image = cv2.resize(image_rbg, (1024, 1024)) + return image, image_size + + def callback(self, result, error): + if error: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(error) + # self.gen_product_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + else: + # pil图像转成numpy数组 + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.gen_product_data['status'] = "SUCCESS" + self.gen_product_data['message'] = "success" + self.gen_product_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def infer(self, inputs): + return self.grpc_client.async_infer( + model_name=GRI_MODEL_NAME, + inputs=inputs, + callback=self.callback + ) + + def 
get_result(self): + try: + direction = "Right Light" + negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' + self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere' + prompts = [self.prompt] * self.batch_size + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + input_text = grpcclient.InferInput( + "prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype) + ) + input_text.set_data_from_numpy(text_obj) + + negative_prompts = [negative_prompt] * self.batch_size + text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1)) + input_text_neg = grpcclient.InferInput( + "negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype) + ) + input_text_neg.set_data_from_numpy(text_obj_neg) + + seed = np.array(self.seed, dtype="object").reshape((-1, 1)) + input_seed = grpcclient.InferInput( + "seed", seed.shape, np_to_triton_dtype(seed.dtype) + ) + input_seed.set_data_from_numpy(seed) + + input_images = [self.image] * self.batch_size + text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1)) + input_input_images = grpcclient.InferInput( + "input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype) + ) + input_input_images.set_data_from_numpy(text_obj_images) + + directions = [direction] * self.batch_size + text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1)) + input_directions = grpcclient.InferInput( + "direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype) + ) + input_directions.set_data_from_numpy(text_obj_directions) + + output_img = grpcclient.InferRequestedOutput("generated_image") + request_start = time.time() + + inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions] + + ctx = self.infer(inputs) + time_out = 600 + while time_out > 0: + gen_product_data, _ = self.read_tasks_status() + # logger.info(gen_product_data) + if gen_product_data['status'] in 
["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif gen_product_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + # logger.info(time_out, gen_product_data) + gen_product_data, _ = self.read_tasks_status() + return gen_product_data + except Exception as e: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + raise Exception(str(e)) + finally: + dict_gen_product_data, str_gen_product_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) + # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) + logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + gen_product_data = json.dumps(data) + redis_client.set(tasks_id, gen_product_data) + return data + + +if __name__ == '__main__': + rd = GenerateRelightImageModel( + tasks_id="123-89", + prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", + image_url="/workspace/i3.png", + ) + server = GenerateRelightImage(rd) + print(server.get_result()) From e604ede0391b0353307374b79eecffb685492033 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 11:00:04 +0800 Subject: [PATCH 031/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design/items/pipelines/painting.py | 4 ++-- app/service/design/utils/synthesis_item.py | 14 +++++++------- 
app/service/design/utils/upload_image.py | 4 ++-- .../generate_image/utils/upload_sd_image.py | 4 ++-- requirements.txt | Bin 1246 -> 1266 bytes 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 43b42e4..6d88411 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -1,6 +1,6 @@ import random from io import BytesIO -import boto3 +# import boto3 import cv2 import numpy as np from PIL import Image @@ -15,7 +15,7 @@ minio_client = Minio( secret_key=MINIO_SECRET, secure=MINIO_SECURE) -s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) +# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @PIPELINES.register_module() diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py index 0cf844b..e6a2f25 100644 --- a/app/service/design/utils/synthesis_item.py +++ b/app/service/design/utils/synthesis_item.py @@ -11,7 +11,7 @@ import io import logging import time -import boto3 +# import boto3 import cv2 import numpy as np from PIL import Image @@ -27,12 +27,12 @@ minio_client = Minio( secret_key=MINIO_SECRET, secure=MINIO_SECURE) -s3 = boto3.client( - 's3', - aws_access_key_id=S3_ACCESS_KEY, - aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, - region_name=S3_REGION_NAME -) +# s3 = boto3.client( +# 's3', +# aws_access_key_id=S3_ACCESS_KEY, +# aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, +# region_name=S3_REGION_NAME +# ) def positioning(all_mask_shape, mask_shape, offset): diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py index 70b259c..2142126 100644 --- a/app/service/design/utils/upload_image.py +++ b/app/service/design/utils/upload_image.py @@ -11,7 
+11,7 @@ import io import logging import time -import boto3 +# import boto3 import cv2 from minio import Minio @@ -25,7 +25,7 @@ minio_client = Minio( secure=MINIO_SECURE) """S3 上传""" -s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) +# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index 2773aa2..ec476f9 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -10,7 +10,7 @@ import io import logging -import boto3 +# import boto3 import cv2 from PIL import Image from minio import Minio @@ -18,7 +18,7 @@ from minio import Minio from app.core.config import * minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) -s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) +# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) # def upload_single_logo(image, user_id, category, object_name): diff --git a/requirements.txt b/requirements.txt index a203fe52a4c8468fb97bf5c8c626c4b12f72e398..1e4ee61b3ae308e66ad8ac8b2039262e18aa6547 100644 GIT binary patch delta 32 mcmcb|`H6GG875&{23rP020aF21`{A@$Y8`^zWFqhFcScZ{s$2N delta 12 Tcmeywd5?3$8K%t-n1q-BCl3V* From ee0e07df1c1dbfaddb9cf04158b68c92146fb3a7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 11:00:12 +0800 Subject: [PATCH 032/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1266 -> 1232 
bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 1e4ee61b3ae308e66ad8ac8b2039262e18aa6547..7b3fa73dfc24137723e7ae7bbbff5c0ccbedbba0 100644 GIT binary patch delta 12 Tcmeywd4Y4oA*RhwnE03gCPf7D delta 42 scmcb>`H6GGAtu=*hJ1z+AU0;O1wumxJs`^jNE$L2F_>@O%OuPM0OY&~+yDRo From ce7b1bcd2396a1c7be601b05410ac73e3810034e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 11:07:23 +0800 Subject: [PATCH 033/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/app/api/api_test.py b/app/api/api_test.py index 739c08f..3dbfe56 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,6 +1,6 @@ import logging from fastapi import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES from fastapi import FastAPI, HTTPException from app.schemas.response_template import ResponseModel @@ -13,7 +13,11 @@ router = APIRouter() def test(id: int): logger.info(SR_RABBITMQ_QUEUES) logger.info("test") - data = {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES} + data = { + "SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, + "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, + "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, + } if id == 1: raise HTTPException(status_code=404, detail="Item not found") From 0a868d6817c9bec9cdd6a0877528456fccc3ec7b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 11:10:58 +0800 Subject: [PATCH 034/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
app/api/api_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/api/api_test.py b/app/api/api_test.py index 3dbfe56..0504349 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -11,13 +11,12 @@ router = APIRouter() @router.get("{id}") def test(id: int): - logger.info(SR_RABBITMQ_QUEUES) - logger.info("test") data = { "SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, } + logger.info(data) if id == 1: raise HTTPException(status_code=404, detail="Item not found") From f2bb7a11f90da505ed44d8e55cadffc7418193f8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 13:10:46 +0800 Subject: [PATCH 035/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/service.py | 3 + app/service/design/utils/synthesis_item.py | 2 +- app/service/design/utils/upload_image.py | 213 +++++++++++---------- logging_env.py | 2 +- 4 files changed, 112 insertions(+), 108 deletions(-) diff --git a/app/service/design/service.py b/app/service/design/service.py index 372456f..0ba5e72 100644 --- a/app/service/design/service.py +++ b/app/service/design/service.py @@ -5,6 +5,8 @@ from app.service.design.utils.redis_utils import Redis from app.service.design.utils.synthesis_item import synthesis, synthesis_single import concurrent.futures +from app.service.utils.decorator import RunTime + def process_item(item, layers): # logging.info("process running.........") @@ -38,6 +40,7 @@ def final_progress(process_id): return progress +@RunTime def generate(request_data): return_response = {} request_data = request_data.dict() diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py index e6a2f25..caf3fcb 100644 --- a/app/service/design/utils/synthesis_item.py +++ 
b/app/service/design/utils/synthesis_item.py @@ -75,7 +75,7 @@ def positioning(all_mask_shape, mask_shape, offset): return all_start, all_end, mask_start, mask_end -@RunTime +# @RunTime def synthesis(data, size): # 创建底图 base_image = Image.new('RGBA', size, (0, 0, 0, 0)) diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py index 2142126..a4195f7 100644 --- a/app/service/design/utils/upload_image.py +++ b/app/service/design/utils/upload_image.py @@ -28,128 +28,129 @@ minio_client = Minio( # s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) - -@RunTime -def upload_png_mask(front_image, object_name, mask=None): - mask_url = None - if mask is not None: - # 反转掩模 - mask_inverted = cv2.bitwise_not(mask) - # 将掩模转换为 RGBA 格式 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - # 将图像数据保存到内存中的 BytesIO 对象中 - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - image_bytes.seek(0) - try: - key = f"mask/mask_{object_name}.png" - mask_url = f"{AIDA_CLOTHING}/{key}" - s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') - except Exception as e: - print(f'上传到 S3 失败: {e}') - with io.BytesIO() as output: - front_image.save(output, format='PNG') - data = output.getvalue() - # 创建一个 S3 客户端 - try: - key = f"image/image_{object_name}.png" - image_url = f"{AIDA_CLOTHING}/{key}" - s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') - return front_image, image_url, mask_url - except Exception as e: - print(f'上传到 S3 失败: {e}') - - -@RunTime -def upload_layer_image(image, object_name): - with io.BytesIO() as output: - image.save(output, format='PNG') - data = output.getvalue() - # 创建一个 S3 客户端 - try: - key = f"image/image_{object_name}.png" - image_url = f"{AIDA_CLOTHING}/{key}" - 
s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') - return image_url - except Exception as e: - print(f'上传到 S3 失败: {e}') - - -@RunTime -def upload_mask_image(mask, object_name): - # 反转掩模 - mask_inverted = cv2.bitwise_not(mask) - # 将掩模转换为 RGBA 格式 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - # 将图像数据保存到内存中的 BytesIO 对象中 - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - image_bytes.seek(0) - try: - key = f"mask/mask_{object_name}.png" - mask_url = f"{AIDA_CLOTHING}/{key}" - s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') - return mask_url - except Exception as e: - print(f'上传到 S3 失败: {e}') - - -"""minio 上传""" - +# # @RunTime # def upload_png_mask(front_image, object_name, mask=None): -# start_time = time.time() +# mask_url = None +# if mask is not None: +# # 反转掩模 +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模转换为 RGBA 格式 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# # 将图像数据保存到内存中的 BytesIO 对象中 +# image_bytes = io.BytesIO() +# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# image_bytes.seek(0) +# try: +# key = f"mask/mask_{object_name}.png" +# mask_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') +# except Exception as e: +# print(f'上传到 S3 失败: {e}') +# with io.BytesIO() as output: +# front_image.save(output, format='PNG') +# data = output.getvalue() +# # 创建一个 S3 客户端 # try: -# mask_url = None -# if mask is not None: -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# -# 
image_bytes.seek(0) -# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" -# -# image_data = io.BytesIO() -# front_image.save(image_data, format='PNG') -# image_data.seek(0) -# image_bytes = image_data.read() -# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" -# # print(f"upload_png_mask {object_name} = {time.time() - start_time}") +# key = f"image/image_{object_name}.png" +# image_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') # return front_image, image_url, mask_url # except Exception as e: -# logging.warning(f"upload_png_mask runtime exception : {e}") +# print(f'上传到 S3 失败: {e}') # # # @RunTime # def upload_layer_image(image, object_name): +# with io.BytesIO() as output: +# image.save(output, format='PNG') +# data = output.getvalue() +# # 创建一个 S3 客户端 # try: -# image_data = io.BytesIO() -# image.save(image_data, format='PNG') -# image_data.seek(0) -# image_bytes = image_data.read() -# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" +# key = f"image/image_{object_name}.png" +# image_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') # return image_url # except Exception as e: -# logging.warning(f"upload_png_mask runtime exception : {e}") +# print(f'上传到 S3 失败: {e}') # # # @RunTime # def upload_mask_image(mask, object_name): +# # 反转掩模 +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模转换为 RGBA 格式 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# # 将图像数据保存到内存中的 
BytesIO 对象中 +# image_bytes = io.BytesIO() +# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# image_bytes.seek(0) # try: -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# -# image_bytes.seek(0) -# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" +# key = f"mask/mask_{object_name}.png" +# mask_url = f"{AIDA_CLOTHING}/{key}" +# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') # return mask_url # except Exception as e: -# logging.warning(f"upload_png_mask runtime exception : {e}") +# print(f'上传到 S3 失败: {e}') + + +"""minio 上传""" + + +# @RunTime +def upload_png_mask(front_image, object_name, mask=None): + start_time = time.time() + try: + mask_url = None + if mask is not None: + mask_inverted = cv2.bitwise_not(mask) + # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + + image_bytes.seek(0) + mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" + + image_data = io.BytesIO() + front_image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # print(f"upload_png_mask {object_name} = {time.time() - 
start_time}") + return front_image, image_url, mask_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") + + +@RunTime +def upload_layer_image(image, object_name): + try: + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + return image_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") + + +@RunTime +def upload_mask_image(mask, object_name): + try: + mask_inverted = cv2.bitwise_not(mask) + # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + image_bytes = io.BytesIO() + image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + + image_bytes.seek(0) + mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" + return mask_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") diff --git a/logging_env.py b/logging_env.py index d1ac9bc..d618e37 100644 --- a/logging_env.py +++ b/logging_env.py @@ -9,7 +9,7 @@ LOGGER_CONFIG_DICT = { "handlers": { "console": { "class": "logging.StreamHandler", - "level": "DEBUG", + "level": "INFO", "formatter": "simple", "stream": "ext://sys.stdout", }, From dd0781b9aee1b25c92028e6fbe085545cf0ffd7a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 16:42:33 +0800 Subject: [PATCH 036/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 8 ++-- 
.../design/items/pipelines/keypoints.py | 39 ++++++++----------- 2 files changed, 20 insertions(+), 27 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 651dd8b..b293cef 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -23,11 +23,11 @@ DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" - FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml" + # FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml" else: LOGS_PATH = "app/logs/" CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" - FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml' + # FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml' # RABBITMQ_ENV = "" # 生产环境 # RABBITMQ_ENV = "-dev" # 开发环境 @@ -60,9 +60,9 @@ RABBITMQ_PARAMS = { } # milvus 配置 -MILVUS_DB_HOST = "10.1.1.240" +MILVUS_URL = "http://10.1.1.240:19530http://127.0.0.1:8000/docs#/design/design_api_design_post" +MILVUS_TOKEN = "root:Milvus" MILVUS_ALIAS = "default" -MILVUS_PORT = "19530" MILVUS_TABLE_KEYPOINT = "keypoint_cache" MILVUS_TABLE_SEG = "seg_cache" diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 4d0a081..4a9e4d1 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -14,17 +14,17 @@ class KeypointDetection(object): path here: abstract path """ - def __init__(self): - self.client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) + # def __init__(self): + # self.client = MilvusClient( + # uri="http://10.1.1.240:19530", + # token="root:Milvus", + # db_name=MILVUS_ALIAS + # ) - def __del__(self): - # start_time = time.time() - self.client.close() - # print(f"client close time : {time.time() - start_time}") + # def 
__del__(self): + # start_time = time.time() + # self.client.close() + # print(f"client close time : {time.time() - start_time}") # @ RunTime def __call__(self, result): @@ -69,24 +69,19 @@ class KeypointDetection(object): "keypoint_vector": result.tolist() } ] - client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) start_time = time.time() res = client.upsert( collection_name=MILVUS_TABLE_KEYPOINT, data=data, ) # logging.info(f"save keypoint time : {time.time() - start_time}") + client.close() return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) except Exception as e: logging.info(f"save keypoint cache milvus error : {e}") return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - finally: - client.close() @staticmethod def update_keypoint_cache(keypoint_id, infer_result, search_result, site): @@ -102,12 +97,9 @@ class KeypointDetection(object): "keypoint_vector": result.tolist() } ] - client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) + try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) start_time = time.time() # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. 
@@ -125,8 +117,9 @@ class KeypointDetection(object): # @ RunTime def keypoint_cache(self, result, site): try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) keypoint_id = result['image_id'] - res = self.client.query( + res = client.query( collection_name=MILVUS_TABLE_KEYPOINT, # ids=[keypoint_id], filter=f"keypoint_id == {keypoint_id}", From e29bed20f7983e804998e9da784d15eff088e175 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 16:59:08 +0800 Subject: [PATCH 037/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 267b796..89a5e3f 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,5 +1,7 @@ import json import logging +import os + from fastapi import APIRouter, HTTPException from app.schemas.attribute_retrieve import * @@ -17,6 +19,8 @@ logger = logging.getLogger() def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") + logger.info(const.top_description_list) + logger.info(os.getcwd()) service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") From 557e3cd1007ab9e23513066a43ca1fc571c55a72 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:02:26 +0800 Subject: [PATCH 038/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/attribute/config/const.py | 64 
+++++++++++++-------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/app/service/attribute/config/const.py b/app/service/attribute/config/const.py index 24d9412..738e486 100644 --- a/app/service/attribute/config/const.py +++ b/app/service/attribute/config/const.py @@ -1,13 +1,13 @@ -top_description_list = ['service/attribute/config/descriptor/top/length.csv', - 'service/attribute/config/descriptor/top/type.csv', - 'service/attribute/config/descriptor/top/sleeve_length.csv', - 'service/attribute/config/descriptor/top/sleeve_shape.csv', - 'service/attribute/config/descriptor/top/sleeve_shoulder.csv', - 'service/attribute/config/descriptor/top/neckline.csv', - 'service/attribute/config/descriptor/top/design.csv', - 'service/attribute/config/descriptor/top/opening_type.csv', - 'service/attribute/config/descriptor/top/silhouette.csv', - 'service/attribute/config/descriptor/top/collar.csv'] +top_description_list = ['app/service/attribute/config/descriptor/top/length.csv', + 'app/service/attribute/config/descriptor/top/type.csv', + 'app/service/attribute/config/descriptor/top/sleeve_length.csv', + 'app/service/attribute/config/descriptor/top/sleeve_shape.csv', + 'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/top/neckline.csv', + 'app/service/attribute/config/descriptor/top/design.csv', + 'app/service/attribute/config/descriptor/top/opening_type.csv', + 'app/service/attribute/config/descriptor/top/silhouette.csv', + 'app/service/attribute/config/descriptor/top/collar.csv'] top_model_list = ['attr_retrieve_T_length', 'attr_retrieve_T_type', @@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length', ] bottom_description_list = [ - 'service/attribute/config/descriptor/bottom/subtype.csv', - 'service/attribute/config/descriptor/bottom/length.csv', - 'service/attribute/config/descriptor/bottom/silhouette.csv', - 'service/attribute/config/descriptor/bottom/opening_type.csv', - 
'service/attribute/config/descriptor/bottom/design.csv'] + 'app/service/attribute/config/descriptor/bottom/subtype.csv', + 'app/service/attribute/config/descriptor/bottom/length.csv', + 'app/service/attribute/config/descriptor/bottom/silhouette.csv', + 'app/service/attribute/config/descriptor/bottom/opening_type.csv', + 'app/service/attribute/config/descriptor/bottom/design.csv'] bottom_model_list = [ 'attr_retrieve_B_subtype', @@ -35,14 +35,14 @@ bottom_model_list = [ 'attr_recong_B_optype', 'attr_retrieve_B_design'] -outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv', - 'service/attribute/config/descriptor/outwear/sleeve_length.csv', - 'service/attribute/config/descriptor/outwear/sleeve_shape.csv', - 'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv', - 'service/attribute/config/descriptor/outwear/collar.csv', - 'service/attribute/config/descriptor/outwear/design.csv', - 'service/attribute/config/descriptor/outwear/opening_type.csv', - 'service/attribute/config/descriptor/outwear/silhouette.csv', ] +outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_length.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/outwear/collar.csv', + 'app/service/attribute/config/descriptor/outwear/design.csv', + 'app/service/attribute/config/descriptor/outwear/opening_type.csv', + 'app/service/attribute/config/descriptor/outwear/silhouette.csv', ] outwear_model_list = ['attr_recong_O_length', 'attr_retrieve_O_sleeve_length', @@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length', 'attr_recong_O_optype', 'attr_retrieve_O_silhouette'] -dress_description_list = [ # 'service/attribute/config/descriptor/dress/D_length.csv', - 'service/attribute/config/descriptor/dress/sleeve_length.csv', - 
'service/attribute/config/descriptor/dress/sleeve_shape.csv', - # 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv', - 'service/attribute/config/descriptor/dress/neckline.csv', - 'service/attribute/config/descriptor/dress/collar.csv', - 'service/attribute/config/descriptor/dress/design.csv', - 'service/attribute/config/descriptor/dress/silhouette.csv', - 'service/attribute/config/descriptor/dress/type.csv'] +dress_description_list = [ # 'app/service/attribute/config/descriptor/dress/D_length.csv', + 'app/service/attribute/config/descriptor/dress/sleeve_length.csv', + 'app/service/attribute/config/descriptor/dress/sleeve_shape.csv', + # 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/dress/neckline.csv', + 'app/service/attribute/config/descriptor/dress/collar.csv', + 'app/service/attribute/config/descriptor/dress/design.csv', + 'app/service/attribute/config/descriptor/dress/silhouette.csv', + 'app/service/attribute/config/descriptor/dress/type.csv'] dress_model_list = [ # 'attr_recong_D_length', 'attr_retrieve_D_sleeve_length', From e0d9512b26987d3917ea979c5b0a48277346ea2c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:04:20 +0800 Subject: [PATCH 039/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 89a5e3f..ef3955f 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,9 +1,6 @@ import json import logging -import os - from fastapi import APIRouter, HTTPException - from app.schemas.attribute_retrieve import * from app.schemas.response_template import ResponseModel from app.service.attribute.config import const @@ -19,8 
+16,6 @@ logger = logging.getLogger() def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") - logger.info(const.top_description_list) - logger.info(os.getcwd()) service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") From 88014bdb4c55851ffcaf76f21376073cc01de7b9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:14:10 +0800 Subject: [PATCH 040/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 11 ++++++++--- app/core/config.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index ef3955f..7a14e9d 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,9 +1,11 @@ import json import logging from fastapi import APIRouter, HTTPException + +from app.core.config import DEBUG from app.schemas.attribute_retrieve import * from app.schemas.response_template import ResponseModel -from app.service.attribute.config import const +from app.service.attribute.config import const, local_debug_const from app.service.attribute.service_att_recognition import AttributeRecognition from app.service.attribute.service_category_recognition import CategoryRecognition @@ -16,13 +18,16 @@ logger = logging.getLogger() def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") - service = AttributeRecognition(const=const, request_data=request_item) + if DEBUG: + service = AttributeRecognition(const=local_debug_const, request_data=request_item) + else: + 
service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) - return ResponseModel(data=data) + return ResponseModel(data={"list": data}) # 类别识别 diff --git a/app/core/config.py b/app/core/config.py index b293cef..08c0998 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 6861e89f8d2faf12bd150ceb07dd14611080107c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:14:49 +0800 Subject: [PATCH 041/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 08c0998..b293cef 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From a09476354e76b87c30eab0cb92010a04f72b51af Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:17:36 +0800 Subject: [PATCH 042/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/app/core/config.py b/app/core/config.py index b293cef..0af065b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -60,7 +60,7 @@ RABBITMQ_PARAMS = { } # milvus 配置 -MILVUS_URL = "http://10.1.1.240:19530http://127.0.0.1:8000/docs#/design/design_api_design_post" +MILVUS_URL = "http://10.1.1.240:19530" MILVUS_TOKEN = "root:Milvus" MILVUS_ALIAS = "default" MILVUS_TABLE_KEYPOINT = "keypoint_cache" From a0993d7e3a4dca9541cad1d7c206e8395e13818c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:34:57 +0800 Subject: [PATCH 043/108] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/keypoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 4a9e4d1..6cf1141 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -55,7 +55,7 @@ class KeypointDetection(object): @staticmethod # @ RunTime - def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None): + def save_keypoint_cache(keypoint_id, cache, site): if site == "down": zeros = np.zeros(20, dtype=int) result = np.concatenate([zeros, cache.flatten()]) From 63db4b891798e2f4d9b1c3d6e585607885b3f2ff Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 18 Jun 2024 10:50:15 +0800 Subject: [PATCH 044/108] =?UTF-8?q?feat=20fix=20=20design=20=E8=BF=9B?= =?UTF-8?q?=E5=BA=A6=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 17 ++++++++++++++++- app/schemas/design.py | 4 ++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index c77d4c2..ecac4f5 100644 --- 
a/app/api/api_design.py +++ b/app/api/api_design.py @@ -4,9 +4,10 @@ import time from fastapi import APIRouter, HTTPException -from app.schemas.design import DesignModel +from app.schemas.design import DesignModel, DesignProgressModel from app.schemas.response_template import ResponseModel from app.service.design.service import generate +from app.service.design.utils.redis_utils import Redis router = APIRouter() logger = logging.getLogger() @@ -22,3 +23,17 @@ def design(request_data: DesignModel): logger.warning(f"design Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) + + +@router.post('/get_progress') +def get_progress(request_data: DesignProgressModel): + try: + logger.info(f"get_progress request item is : @@@@@@:{request_data}") + process_id = request_data.process_id + r = Redis() + data = r.read(key=process_id) + logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") + except Exception as e: + logger.warning(f"design Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/schemas/design.py b/app/schemas/design.py index b203970..994deb4 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -48,3 +48,7 @@ from pydantic import BaseModel class DesignModel(BaseModel): objects: list[dict] process_id: str + + +class DesignProgressModel(BaseModel): + process_id: str From 61ae688dd60eef4d0dca52d9d34d0ec5f7566625 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 18 Jun 2024 14:31:11 +0800 Subject: [PATCH 045/108] =?UTF-8?q?feat=20fix=20=20design=20keypoint=20?= =?UTF-8?q?=E5=8F=96=E6=B6=88=E8=AE=B0=E5=BD=95keypoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/keypoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/service/design/items/pipelines/keypoints.py 
b/app/service/design/items/pipelines/keypoints.py index 6cf1141..956e052 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -34,9 +34,9 @@ class KeypointDetection(object): site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - keypoint_cache = self.keypoint_cache(result, site) + # keypoint_cache = self.keypoint_cache(result, site) # 取消向量查询 直接过模型推理 - # keypoint_cache = False + keypoint_cache = False if keypoint_cache is False: keypoint_infer_result, site = self.infer_keypoint_result(result) From 8476bb3727e3ed974dad3f09d2ca3050e8e3da7e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 10:53:11 +0800 Subject: [PATCH 046/108] feat fix --- app/api/api_design.py | 4 +++- app/service/design/items/pipelines/keypoints.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index ecac4f5..cdbd1f5 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -32,8 +32,10 @@ def get_progress(request_data: DesignProgressModel): process_id = request_data.process_id r = Redis() data = r.read(key=process_id) + if data is None: + raise ValueError("The progress must be numbers ") logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") except Exception as e: - logger.warning(f"design Run Exception @@@@@@:{e}") + logger.warning(f"get_progress Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 956e052..6cf1141 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -34,9 +34,9 @@ class KeypointDetection(object): site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 
'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - # keypoint_cache = self.keypoint_cache(result, site) + keypoint_cache = self.keypoint_cache(result, site) # 取消向量查询 直接过模型推理 - keypoint_cache = False + # keypoint_cache = False if keypoint_cache is False: keypoint_infer_result, site = self.infer_keypoint_result(result) From d04c3857fcaa8e672c7b290f288081fe9d895e64 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 16:44:04 +0800 Subject: [PATCH 047/108] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- .../service_generate_product_image.py | 9 +- .../service_generate_relight_image.py | 111 ++++++------------ 3 files changed, 41 insertions(+), 81 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 0af065b..3932bf5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -124,7 +124,7 @@ GPI_MODEL_URL = '10.1.1.240:10061' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") -GRI_MODEL_NAME = 'stable_diffusion_1_5' +GRI_MODEL_NAME = 'diffusion_relight_ensemble' GRI_MODEL_URL = '10.1.1.150:8001' # SEG service config diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index ce449ea..2416d2c 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -20,7 +20,7 @@ from minio import Minio from tritonclient.utils import np_to_triton_dtype from app.core.config import * -from app.schemas.generate_image import GenerateImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image 
import upload_SDXL_image logger = logging.getLogger() @@ -166,10 +166,11 @@ def infer_cancel(tasks_id): if __name__ == '__main__': - rd = GenerateImageModel( + rd = GenerateProductImageModel( tasks_id="123-89", - prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", - image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png", + prompt="", + # prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", ) server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 0eacec9..7c7f4b1 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -38,9 +38,10 @@ class GenerateRelightImage: self.batch_size = 1 self.prompt = request_data.prompt self.seed = "12345" + self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' + self.direction = "Right Light" # TODO aida design 结果图背景改为白色 - # self.image, self.image_size = self.get_image(request_data.image_url) - self.image = request_data.image_url + self.image = self.get_image(request_data.image_url) # TODO image 填充并resize成512*768 self.tasks_id = request_data.tasks_id @@ -51,37 +52,8 @@ class GenerateRelightImage: def get_image(self, image_url): response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_bytes = io.BytesIO(response.read()) - - # 转换为PIL图像对象 - image = Image.open(image_bytes) - target_height = 
768 - target_width = 512 - - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): - # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - # image_file = BytesIO(response.data) - # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - # image = cv2.resize(image_rbg, (1024, 1024)) - return image, image_size + image = cv2.imdecode(np.frombuffer(response.data, np.uint8), 1) + return image def callback(self, result, error): if error: @@ -92,7 +64,7 @@ class GenerateRelightImage: else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") # logger.info(f"upload image SUCCESS : {image_url}") @@ -114,47 +86,33 @@ class GenerateRelightImage: def get_result(self): try: - direction = "Right Light" - negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere' prompts = [self.prompt] * 
self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - input_text = grpcclient.InferInput( - "prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype) - ) + image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * self.batch_size + seeds = [self.seed] * self.batch_size + nagetive_prompts = [self.negative_prompt] * self.batch_size + directions = [self.direction] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) + input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) + input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_natext.set_data_from_numpy(na_text_obj) + input_seed.set_data_from_numpy(seed_obj) + input_direction.set_data_from_numpy(direction_obj) - negative_prompts = [negative_prompt] * self.batch_size - text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1)) - input_text_neg = grpcclient.InferInput( - "negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype) - ) - input_text_neg.set_data_from_numpy(text_obj_neg) - - seed = np.array(self.seed, dtype="object").reshape((-1, 1)) - input_seed = grpcclient.InferInput( - "seed", seed.shape, 
np_to_triton_dtype(seed.dtype) - ) - input_seed.set_data_from_numpy(seed) - - input_images = [self.image] * self.batch_size - text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1)) - input_input_images = grpcclient.InferInput( - "input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype) - ) - input_input_images.set_data_from_numpy(text_obj_images) - - directions = [direction] * self.batch_size - text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1)) - input_directions = grpcclient.InferInput( - "direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype) - ) - input_directions.set_data_from_numpy(text_obj_directions) - - output_img = grpcclient.InferRequestedOutput("generated_image") - request_start = time.time() - - inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions] + inputs = [input_text, input_natext, input_image, input_seed, input_direction] ctx = self.infer(inputs) time_out = 600 @@ -179,9 +137,9 @@ class GenerateRelightImage: finally: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) + self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) - logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): @@ -195,8 +153,9 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateRelightImageModel( tasks_id="123-89", - prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - image_url="/workspace/i3.png", + # prompt="beautiful 
woman, detailed face, sunshine, outdoor, warm atmosphere", + prompt="", + image_url='aida-users/89/product_image/123-89.png' ) server = GenerateRelightImage(rd) print(server.get_result()) From 64e85a9c72c9208a841a7926164f65f4d00272d7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 16:58:21 +0800 Subject: [PATCH 048/108] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/api/api_test.py b/app/api/api_test.py index 0504349..86ed25c 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,6 +1,6 @@ import logging from fastapi import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES from fastapi import FastAPI, HTTPException from app.schemas.response_template import ResponseModel @@ -15,6 +15,7 @@ def test(id: int): "SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, + "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, } logger.info(data) if id == 1: From 20b0f81fce2e13b77e95a53b0dd000983d75044d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 16:59:46 +0800 Subject: [PATCH 049/108] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 3932bf5..b574845 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -123,7 +123,7 @@ GPI_MODEL_NAME = 
'diffusion_ensemble_all' GPI_MODEL_URL = '10.1.1.240:10061' # Generate Single Logo service config -GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") +GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'diffusion_relight_ensemble' GRI_MODEL_URL = '10.1.1.150:8001' From d0597f4b4c6a27bbcf6d907152e869ac7256c27a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 20 Jun 2024 16:23:02 +0800 Subject: [PATCH 050/108] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 6 +- .../generate_image/service_generate_image.py | 17 ++-- .../service_generate_product_image.py | 89 +++++++------------ .../service_generate_relight_image.py | 34 ++----- .../service_generate_single_logo.py | 22 +---- .../generate_image/utils/upload_sd_image.py | 37 ++++---- app/service/utils/oss_client.py | 70 +++++++++++++++ 7 files changed, 146 insertions(+), 129 deletions(-) create mode 100644 app/service/utils/oss_client.py diff --git a/app/core/config.py b/app/core/config.py index b574845..4e74711 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,6 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') +OSS = "minio" DEBUG = False if DEBUG: LOGS_PATH = "logs/" @@ -47,7 +48,7 @@ S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8" S3_REGION_NAME = "ap-east-1" # redis 配置 -REDIS_HOST = "10.1.1.240" +REDIS_HOST = "10.1.1.150" REDIS_PORT = "6379" REDIS_DB = "2" @@ -120,7 +121,8 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Single Logo service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME = 
'diffusion_ensemble_all' -GPI_MODEL_URL = '10.1.1.240:10061' +# GPI_MODEL_URL = '10.1.1.240:10061' +GPI_MODEL_URL = '10.1.1.150:8001' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 6f8d092..889aed7 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -25,6 +25,7 @@ from app.schemas.generate_image import GenerateImageModel from app.service.generate_image.utils.adjust_contrast import adjust_contrast from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd +from app.service.utils.oss_client import get_image logger = logging.getLogger() @@ -36,7 +37,7 @@ class GenerateImage: self.channel = self.connection.channel() # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) if request_data.mode == "img2img": @@ -63,10 +64,13 @@ class GenerateImage: # Read data from response. 
# read image use cv2 try: - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_file = BytesIO(response.data) - image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + # image_file = BytesIO(response.data) + # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + + image_cv2 = get_image(object_name=image_url, data_type="cv2") image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: @@ -189,7 +193,8 @@ if __name__ == '__main__': prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', image_url="", mode='txt2img', - category="test" + category="test", + gender="male" ) server = GenerateImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 2416d2c..dcdf09f 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -18,10 +18,10 @@ import numpy as np from PIL import Image, ImageOps from minio import Minio from tritonclient.utils import np_to_triton_dtype - from app.core.config import * -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel +from app.schemas.generate_image import GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -33,69 +33,29 @@ class GenerateProductImage: self.channel = self.connection.channel() # 
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" self.batch_size = 1 self.prompt = request_data.prompt - # TODO aida design 结果图背景改为白色 - self.image, self.image_size = self.get_image(request_data.image_url) - # TODO image 填充并resize成512*768 - + self.image, self.image_size = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.expire(self.tasks_id, 600) - def get_image(self, image_url): - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_bytes = io.BytesIO(response.read()) - - # 转换为PIL图像对象 - image = Image.open(image_bytes) - target_height = 768 - target_width = 512 - - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): 
- # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - # image_file = BytesIO(response.data) - # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - # image = cv2.resize(image_rbg, (1024, 1024)) - return image, image_size - def callback(self, result, error): if error: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(error) - # self.gen_product_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -105,13 +65,6 @@ class GenerateProductImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GPI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: prompts = [self.prompt] * self.batch_size @@ -129,11 +82,10 @@ class GenerateProductImage: input_image.set_data_from_numpy(image_obj) inputs = [input_text, input_image] - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, 
callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() - # logger.info(gen_product_data) if gen_product_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -141,7 +93,6 @@ class GenerateProductImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, gen_product_data) gen_product_data, _ = self.read_tasks_status() return gen_product_data except Exception as e: @@ -153,7 +104,6 @@ class GenerateProductImage: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) - # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") @@ -165,11 +115,36 @@ def infer_cancel(tasks_id): return data +def pre_processing_image(image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + + # resize 图片内尺寸 并贴到768-512的纯白图像上 + target_height = 768 + target_width = 512 + aspect_ratio = image.width / image.height + new_width = int(target_height * aspect_ratio) + resized_image = image.resize((new_width, target_height)) + left = (target_width - resized_image.width) // 2 + top = (target_height - resized_image.height) // 2 + right = target_width - resized_image.width - left + bottom = target_height - resized_image.height - top + image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") + image_size = image.size + if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 创建白色背景 + background = Image.new("RGB", image.size, (255, 255, 255)) + # 将图片粘贴到白色背景上 + background.paste(image, mask=image.split()[3]) + image = np.array(background) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image, 
image_size + + if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", prompt="", - # prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", ) server = GenerateProductImage(rd) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 7c7f4b1..8793c42 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateRelightImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -31,43 +32,34 @@ class GenerateRelightImage: if DEBUG is False: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "relight_image" self.batch_size = 1 self.prompt = request_data.prompt - 
self.seed = "12345" + self.seed = "1" self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' self.direction = "Right Light" - # TODO aida design 结果图背景改为白色 - self.image = self.get_image(request_data.image_url) - # TODO image 填充并resize成512*768 - + self.image_url = request_data.image_url + self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.expire(self.tasks_id, 600) - def get_image(self, image_url): - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image = cv2.imdecode(np.frombuffer(response.data, np.uint8), 1) - return image - def callback(self, result, error): if error: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(error) - # self.gen_product_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -77,13 +69,6 @@ class GenerateRelightImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - 
def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GRI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: prompts = [self.prompt] * self.batch_size @@ -114,11 +99,10 @@ class GenerateRelightImage: inputs = [input_text, input_natext, input_image, input_seed, input_direction] - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() - # logger.info(gen_product_data) if gen_product_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -126,7 +110,6 @@ class GenerateRelightImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, gen_product_data) gen_product_data, _ = self.read_tasks_status() return gen_product_data except Exception as e: @@ -138,7 +121,6 @@ class GenerateRelightImage: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) - # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") @@ -154,7 +136,7 @@ if __name__ == '__main__': rd = GenerateRelightImageModel( tasks_id="123-89", # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - prompt="", + prompt="Colorful black", image_url='aida-users/89/product_image/123-89.png' ) server = GenerateRelightImage(rd) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index f3d1719..e3def3e 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -31,8 +31,6 @@ class GenerateSingleLogoImage: if DEBUG 
is False: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) @@ -51,23 +49,15 @@ class GenerateSingleLogoImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GSL_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def callback(self, result, error): if error: self.gen_single_logo_data['status'] = "FAILURE" self.gen_single_logo_data['message'] = str(error) - # self.generate_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) else: image = result.as_numpy("generated_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_single_logo_data['status'] = "SUCCESS" self.gen_single_logo_data['message'] = "success" self.gen_single_logo_data['image_url'] = str(image_url) @@ -81,25 +71,19 @@ class GenerateSingleLogoImage: input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_text.set_data_from_numpy(text_obj) - # negative_prompts text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1)) - # print('text obj neg: ', text_obj_neg) input_text_neg = 
grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)) input_text_neg.set_data_from_numpy(text_obj_neg) - # seed seed = np.array(self.seed, dtype="object").reshape((-1, 1)) input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype)) input_seed.set_data_from_numpy(seed) - inputs = [input_text, input_text_neg, input_seed] - - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 generate_data = None while time_out > 0: generate_data, _ = self.read_tasks_status() - # logger.info(generate_data) if generate_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -107,7 +91,6 @@ class GenerateSingleLogoImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, generate_data) return generate_data except Exception as e: raise Exception(str(e)) @@ -115,7 +98,6 @@ class GenerateSingleLogoImage: dict_generate_data, str_generate_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) - # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index ec476f9..a63488c 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -16,8 +16,11 @@ from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.oss_client import oss_upload_image minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + # s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @@ -34,36 
+37,34 @@ minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET # except Exception as e: # print(f'上传到 S3 失败: {e}') -def upload_SDXL_image(image, user_id, category, object_name): +def upload_SDXL_image(image, user_id, category, file_name): try: image_data = io.BytesIO() image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - minio_req = minio_client.put_object( - GI_MINIO_BUCKET, - f'{user_id}/{category}/{object_name}', - io.BytesIO(image_bytes), - len(image_bytes), - content_type='image/jpeg' - ) - image_url = f"aida-users/{minio_req.object_name}" + + # minio_req = minio_client.put_object( + # GI_MINIO_BUCKET, + # f'{user_id}/{category}/{file_name}', + # io.BytesIO(image_bytes), + # len(image_bytes), + # content_type='image/jpeg' + # ) + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"aida-users/{object_name}" return image_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") -def upload_png_sd(image, user_id, category, object_name): +def upload_png_sd(image, user_id, category, file_name): try: _, img_byte_array = cv2.imencode('.jpg', image) - minio_req = minio_client.put_object( - GI_MINIO_BUCKET, - f'{user_id}/{category}/{object_name}', - io.BytesIO(img_byte_array), - len(img_byte_array), - content_type='image/jpeg' - ) - image_url = f"aida-users/{minio_req.object_name}" + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array) + image_url = f"aida-users/{object_name}" return image_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py new file mode 100644 index 0000000..b2d3b7d --- /dev/null +++ b/app/service/utils/oss_client.py @@ -0,0 +1,70 @@ 
+import io +import logging +from io import BytesIO + +import boto3 +import cv2 +import numpy as np +from PIL import Image +from minio import Minio + +from app.core.config import * + +logger = logging.getLogger() + + +# 获取图片 +def oss_get_image(bucket, object_name, data_type): + image_object = None + + try: + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name) + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body'] + + if data_type == "cv2": + image_bytes = image_data.read() + image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型 + image_object = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + else: + data_bytes = BytesIO(image_data.read()) + image_object = Image.open(data_bytes) + except Exception as e: + logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}") + return image_object + + +def oss_upload_image(bucket, object_name, image_bytes): + req = None + try: + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png') + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=image_bytes, ContentType='image/png') + except Exception as e: + logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}") + return req + + +if __name__ == '__main__': + # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png" + # url =
"aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" + # url = "aida-sys-image/images/female/outwear/0628000054.jpg" + # url = "aida-users/89/product_image/string-89.png" + # url = "aida-users/89/single_logo/123-89.png" + # url = 'aida-users/89/relight_image/123-89.png' + # url = 'aida-users/89/relight_image/123-89.png' + url = 'aida-users/89/relight_image/123-89.png' + read_type = "PIL" + if read_type == "cv2": + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + cv2.imshow("", img) + cv2.waitKey(0) + else: + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + img.show() From 2df1518a9957cda75b3df1bad3c30362dc50a9b7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 21 Jun 2024 17:13:39 +0800 Subject: [PATCH 051/108] feat fix minio and s3 --- app/api/api_test.py | 3 +- .../attribute/service_att_recognition.py | 11 +- .../attribute/service_category_recognition.py | 10 +- .../design/items/pipelines/keypoints.py | 8 +- app/service/design/items/pipelines/loading.py | 60 +++++---- .../design/items/pipelines/painting.py | 77 +++++++----- app/service/design/items/pipelines/split.py | 2 +- app/service/design_pre_processing/service.py | 117 ++++++++++++------ .../generate_image/service_generate_image.py | 13 +- app/service/super_resolution/service.py | 33 +++-- app/service/utils/oss_client.py | 12 +- 11 files changed, 200 insertions(+), 146 deletions(-) diff --git a/app/api/api_test.py b/app/api/api_test.py index 86ed25c..0ff521a 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,6 +1,6 @@ import logging from fastapi import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS from fastapi import FastAPI, HTTPException from 
app.schemas.response_template import ResponseModel @@ -16,6 +16,7 @@ def test(id: int): "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, + "local_oss_server": OSS } logger.info(data) if id == 1: diff --git a/app/service/attribute/service_att_recognition.py b/app/service/attribute/service_att_recognition.py index da71c16..ddcfd1c 100644 --- a/app/service/attribute/service_att_recognition.py +++ b/app/service/attribute/service_att_recognition.py @@ -11,12 +11,12 @@ from minio import Minio import tritonclient.http as httpclient from app.core.config import * from app.schemas.attribute_retrieve import AttributeRecognitionModel +from app.service.utils.oss_client import oss_get_image class AttributeRecognition: def __init__(self, const, request_data): - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - logging.info("实例化完成") + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] for i, sketch in enumerate(request_data): self.request_data.append( @@ -97,9 +97,10 @@ class AttributeRecognition: return res def get_image(self, url): - response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) - img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) + # img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR) # + img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img diff --git a/app/service/attribute/service_category_recognition.py b/app/service/attribute/service_category_recognition.py index 18ee043..fb997e9 100644 --- 
a/app/service/attribute/service_category_recognition.py +++ b/app/service/attribute/service_category_recognition.py @@ -18,12 +18,13 @@ import torch from app.core.config import * from app.schemas.attribute_retrieve import CategoryRecognitionModel +from app.service.utils.oss_client import oss_get_image class CategoryRecognition: def __init__(self, request_data): self.attr_type = pd.read_csv(CATEGORY_PATH) - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) for sketch in request_data: @@ -51,9 +52,10 @@ class CategoryRecognition: def get_image(self, url): # Get data of an object. # Read data from response. - response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) - img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) + # img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 6cf1141..1f53ced 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -1,5 +1,6 @@ import logging import time + import numpy as np from pymilvus import MilvusClient @@ -71,11 +72,8 @@ class KeypointDetection(object): ] try: client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) - start_time = time.time() - res = client.upsert( - 
collection_name=MILVUS_TABLE_KEYPOINT, - data=data, - ) + # start_time = time.time() + res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data) # logging.info(f"save keypoint time : {time.time() - start_time}") client.close() return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) diff --git a/app/service/design/items/pipelines/loading.py b/app/service/design/items/pipelines/loading.py index 2697006..a1a49a5 100644 --- a/app/service/design/items/pipelines/loading.py +++ b/app/service/design/items/pipelines/loading.py @@ -1,6 +1,5 @@ import io import logging -import time import cv2 import numpy as np @@ -8,6 +7,7 @@ from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES @@ -70,11 +70,7 @@ class LoadImageFromFile(object): class LoadBodyImageFromFile(object): def __init__(self, body_path): self.body_path = body_path - self.minioClient = Minio( - f"{MINIO_URL}", - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + # self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png") @@ -82,33 +78,33 @@ class LoadBodyImageFromFile(object): def __call__(self, result): result["image_url"] = result['body_path'] = self.body_path result["name"] = "mannequin" - if not result['image_url'].lower().endswith(".png"): - logging.info(1) - bucket = self.body_path.split("/", 1)[0] - object_name = self.body_path.split("/", 1)[1] - new_object_name = f'{object_name[:object_name.rfind(".")]}.png' - image = self.minioClient.get_object(bucket, object_name) - image = Image.open(io.BytesIO(image.data)) - image = image.convert("RGBA") - data = image.getdata() - # - new_data = [] - for item in data: - if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: - 
new_data.append((255, 255, 255, 0)) - else: - new_data.append(item) - image.putdata(new_data) - image_data = io.BytesIO() - image.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - self.body_path = image_path - result["image_url"] = result['body_path'] = self.body_path - response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) + # if not result['image_url'].lower().endswith(".png"): + # bucket = self.body_path.split("/", 1)[0] + # object_name = self.body_path.split("/", 1)[1] + # new_object_name = f'{object_name[:object_name.rfind(".")]}.png' + # image = self.minioClient.get_object(bucket, object_name) + # image = Image.open(io.BytesIO(image.data)) + # image = image.convert("RGBA") + # data = image.getdata() + # # + # new_data = [] + # for item in data: + # if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + # new_data.append((255, 255, 255, 0)) + # else: + # new_data.append(item) + # image.putdata(new_data) + # image_data = io.BytesIO() + # image.save(image_data, format='PNG') + # image_data.seek(0) + # image_bytes = image_data.read() + # image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # self.body_path = image_path + # result["image_url"] = result['body_path'] = self.body_path + # response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) # put_image_time = time.time() - result['body_image'] = Image.open(io.BytesIO(response.read())) + # result['body_image'] = Image.open(io.BytesIO(response.read())) + result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL") # 
logging.info(f"Image.open time is : {time.time() - put_image_time}") return result diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 6d88411..3c9c233 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -1,19 +1,16 @@ import random -from io import BytesIO + # import boto3 import cv2 import numpy as np from PIL import Image -from minio import Minio -from app.core.config import * +from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES -minio_client = Minio( - MINIO_URL, - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + +# minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @@ -56,17 +53,18 @@ class Painting(object): @staticmethod def get_gradient(bucket_name, object_name): - image_data = minio_client.get_object(bucket_name, object_name) + # image_data = minio_client.get_object(bucket_name, object_name) # image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body'] # 从数据流中读取图像 - image_bytes = image_data.read() + # image_bytes = image_data.read() # 将图像数据转换为numpy数组 - image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8) + # image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8) # 使用OpenCV解码图像数组 - image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") return image @staticmethod @@ -494,16 +492,20 @@ class PrintPainting(object): if not 'IfSingle' in print_dict.keys(): print_dict['IfSingle'] = False - data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) - # 
data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body'] + # data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) + # data_bytes = BytesIO(data.read()) + # image = Image.open(data_bytes) + # image_mode = image.mode - data_bytes = BytesIO(data.read()) - image = Image.open(data_bytes) - image_mode = image.mode + bucket_name = print_dict['print_path_list'][0].split("/", 1)[0] + object_name = print_dict['print_path_list'][0].split("/", 1)[1] + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - if image_mode == "RGBA": - new_background = Image.new('RGB', image.size, (255, 255, 255)) - new_background.paste(image, mask=image.split()[3]) + if image.shape[2] == 4: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image_pil = Image.fromarray(image_rgb) + new_background = Image.new('RGB', image_pil.size, (255, 255, 255)) + new_background.paste(image_pil, mask=image_pil.split()[3]) image = new_background print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) @@ -577,21 +579,30 @@ class PrintPainting(object): @staticmethod def read_image(image_url): - data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) - # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] - - data_bytes = BytesIO(data.read()) - image = Image.open(data_bytes) - image_mode = image.mode - # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - if image_mode == "RGBA": - # new_background = Image.new('RGB', image.size, (255, 255, 255)) - # new_background.paste(image, mask=image.split()[3]) - # image = new_background - return image, image_mode - image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1],
data_type="cv2") + if image.shape[2] == 4: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image = Image.fromarray(image_rgb) + image_mode = "RGBA" + else: + image_mode = "RGB" return image, image_mode + # data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) + # # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] + # + # data_bytes = BytesIO(data.read()) + # image = Image.open(data_bytes) + # image_mode = image.mode + # # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 + # if image_mode == "RGBA": + # # new_background = Image.new('RGB', image.size, (255, 255, 255)) + # # new_background.paste(image, mask=image.split()[3]) + # # image = new_background + # return image, image_mode + # image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + # return image, "RGB" + # @staticmethod # def read_image(image_url): # response = requests.get(image_url) diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index e46a3e1..0183352 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -41,7 +41,7 @@ class Split(object): else: back_mask = result['back_mask'] - rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) + rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) result_front_image = np.zeros_like(rgba_image) result_front_image[front_mask != 0] = rgba_image[front_mask != 0] diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index e655087..88ed739 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -13,15 +13,12 @@ import io from app.core.config import * from app.service.design.utils.design_ensemble import get_keypoint_result +from
app.service.utils.oss_client import oss_get_image, oss_upload_image class DesignPreprocessing: - def __init__(self): - self.minio_client = Minio( - MINIO_URL, - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + # def __init__(self): + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # @ RunTime def pipeline(self, image_list): @@ -51,8 +48,9 @@ class DesignPreprocessing: def read_image(self, image_list): for obj in image_list: - file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data - image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + # file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data + # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2") if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: # 如果是四通道 mask @@ -125,7 +123,10 @@ class DesignPreprocessing: try: # 覆盖到minio image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes() - self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + # self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + bucket_name = item['image_url'].split("/", 1)[0] + object_name = item['image_url'].split("/", 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") except ResponseError as err: print(f"Error: {err}") @@ -165,36 +166,76 @@ class DesignPreprocessing: # @ RunTime def composing_image(self, 
image_list): for image in image_list: - if image['site'] == 'down': - image_width = image['obj'].shape[1] - waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] - scale = 0.4 - if waist_width / scale >= image['obj'].shape[1]: - add_width = int((waist_width / scale - image_width) / 2) - ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) - if IF_DEBUG_SHOW: - cv2.imshow("composing_image", ret) - cv2.waitKey(0) - image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" - else: - image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + ''' 比例相同 整合上下装代码''' + image_width = image['obj'].shape[1] + waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + scale = 0.4 + if waist_width / scale >= image_width: + add_width = int((waist_width / scale - image_width) / 2) + ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + if IF_DEBUG_SHOW: + cv2.imshow("composing_image", ret) + cv2.waitKey(0) + image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), 
content_type='image/jpeg').object_name}" + bucket_name = image['image_url'].split('/', 1)[0] + object_name = image['image_url'].split('/', 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + image['show_image_url'] = f"{bucket_name}/{object_name}" else: - scale = 0.4 - image_width = image['obj'].shape[1] - waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1] - if waist_width / scale >= image_width: - add_width = int((waist_width / scale - image_width) / 2) - ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) - if IF_DEBUG_SHOW: - cv2.imshow("composing_image", ret) - cv2.waitKey(0) - image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" - else: - image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + bucket_name = image['image_url'].split('/', 1)[0] + object_name = image['image_url'].split('/', 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + image['show_image_url'] = 
f"{bucket_name}/{object_name}" + + # if image['site'] == 'down': + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = 
int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" return image_list @staticmethod diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 889aed7..d193de7 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -10,22 +10,17 @@ import json import logging import time -from io import BytesIO - import cv2 import minio import redis import tritonclient.grpc as grpcclient import numpy as np -from minio import Minio from tritonclient.utils import 
np_to_triton_dtype - from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.adjust_contrast import adjust_contrast from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd -from app.service.utils.oss_client import get_image +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -70,7 +65,7 @@ class GenerateImage: # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) # image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - image_cv2 = get_image(object_name=image_url, data_type="cv2") + image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url, data_type="cv2") image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: @@ -197,4 +192,4 @@ if __name__ == '__main__': gender="male" ) server = GenerateImage(rd) - print(server.get_result()) + print(server.get_result()) \ No newline at end of file diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index e87f1a7..f864d01 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -1,17 +1,15 @@ -import io +import json import logging import time -import minio.error -import redis -import json import cv2 +import minio.error import numpy as np +import redis import torch import tritonclient.grpc as grpcclient -from minio import Minio from app.core.config import * from app.schemas.super_resolution import SuperResolutionModel -from app.service.utils.decorator import RunTime +from app.service.utils.oss_client import oss_get_image, oss_upload_image logger = 
logging.getLogger() @@ -24,7 +22,7 @@ class SuperResolution: self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.sr_image_url = data.sr_image_url self.sr_xn = data.sr_xn - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})) self.redis_client.expire(self.tasks_id, 600) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) @@ -33,16 +31,25 @@ class SuperResolution: # @RunTime def read_image(self): try: - image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) + img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2") except minio.error.S3Error as e: sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) logger.info(f" [x] Sent {sr_data}") raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") - img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. 
# 解码 return img + # try: + # image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) + # except minio.error.S3Error as e: + # sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) + # self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) + # logger.info(f" [x] Sent {sr_data}") + # raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") + # img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码 + # return img + def read_tasks_status(self): status_data = json.loads(self.redis_client.get(self.tasks_id)) logging.info(f"{self.tasks_id} ===> {status_data}") @@ -101,8 +108,10 @@ class SuperResolution: def upload_img_sr(self, image): try: image_bytes = cv2.imencode('.jpg', image)[1].tobytes() - res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png') - image_url = f"aida-users/{res.object_name}" + # res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png') + object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg' + oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"{SR_MINIO_BUCKET}/{object_name}" return image_url except Exception as e: logger.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index b2d3b7d..e293117 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -15,8 +15,8 @@ logger = logging.getLogger() # 获取图片 def oss_get_image(bucket, object_name, data_type): + # cv2 默认全通道读取 
image_object = None - try: if OSS == "minio": oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) @@ -24,11 +24,10 @@ def oss_get_image(bucket, object_name, data_type): else: oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body'] - if data_type == "cv2": image_bytes = image_data.read() image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型 - image_object = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED) else: data_bytes = BytesIO(image_data.read()) image_object = Image.open(data_bytes) @@ -56,11 +55,12 @@ if __name__ == '__main__': # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" # url = "aida-sys-image/images/female/outwear/0628000054.jpg" # url = "aida-users/89/product_image/string-89.png" - # url = "aida-users/89/single_logo/123-89.png" + url = "aida-users/89/single_logo/123-89.png" # url = 'aida-users/89/relight_image/123-89.png' # url = 'aida-users/89/relight_image/123-89.png' - url = 'aida-users/89/relight_image/123-89.png' - read_type = "PIL" + # url = 'aida-users/89/relight_image/123-89.png' + # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" + read_type = "cv2" if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) cv2.imshow("", img) From b007f5c11b63821c5792413d3633abd681d9a203 Mon Sep 17 00:00:00 2001 From: zchen Date: Sat, 22 Jun 2024 16:57:42 +0800 Subject: [PATCH 052/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 
deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 651dd8b..3d72c1b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -112,7 +112,7 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") # Generate Single Logo service config -GSL_MODEL_URL = '10.1.1.240:10051' +GSL_MODEL_URL = '10.1.1.240:10041' GSL_MINIO_BUCKET = "aida-users" GSL_MODEL_NAME = 'stable_diffusion_xl' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") @@ -120,12 +120,12 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Single Logo service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME = 'diffusion_ensemble_all' -GPI_MODEL_URL = '10.1.1.240:10061' +GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'stable_diffusion_1_5' -GRI_MODEL_URL = '10.1.1.150:8001' +GRI_MODEL_URL = '10.1.1.240:10041' # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' From a1182dab828fdec11c2a4153af68cece9f92064d Mon Sep 17 00:00:00 2001 From: zchen Date: Sat, 22 Jun 2024 17:10:51 +0800 Subject: [PATCH 053/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 3d72c1b..9cc97af 
100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 484659122aeb71b0041dbb05ff139b049e7e5e3c Mon Sep 17 00:00:00 2001 From: zchen Date: Sat, 22 Jun 2024 17:16:52 +0800 Subject: [PATCH 054/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 4e74711..96dbaad 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -113,7 +113,7 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") # Generate Single Logo service config -GSL_MODEL_URL = '10.1.1.240:10051' +GSL_MODEL_URL = '10.1.1.240:10041' GSL_MINIO_BUCKET = "aida-users" GSL_MODEL_NAME = 'stable_diffusion_xl' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") @@ -121,13 +121,12 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Single Logo service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME = 'diffusion_ensemble_all' -# GPI_MODEL_URL = '10.1.1.240:10061' -GPI_MODEL_URL = '10.1.1.150:8001' +GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'diffusion_relight_ensemble' -GRI_MODEL_URL = '10.1.1.150:8001' +GRI_MODEL_URL = '10.1.1.240:10041' # SEG service config SEG_MODEL_URL = 
'10.1.1.240:10000' From dcfe0f71abc5a97551be2a5ddd1e6db6dec24895 Mon Sep 17 00:00:00 2001 From: zchen Date: Sat, 22 Jun 2024 17:27:01 +0800 Subject: [PATCH 055/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1232 -> 1246 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7b3fa73dfc24137723e7ae7bbbff5c0ccbedbba0..68c778ce116e570ae67b442b8196e96fbd5a3079 100644 GIT binary patch delta 21 bcmcb>d5?305DRA#Lq0 Date: Sun, 23 Jun 2024 15:38:33 +0800 Subject: [PATCH 056/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 96dbaad..69fb6c2 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -115,7 +115,7 @@ SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_E # Generate Single Logo service config GSL_MODEL_URL = '10.1.1.240:10041' GSL_MINIO_BUCKET = "aida-users" -GSL_MODEL_NAME = 'stable_diffusion_xl' +GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Single Logo service config From c2f1fb00c77c09e6504cb84360981ca798ba831c Mon Sep 17 00:00:00 2001 From: zchen Date: Sun, 23 Jun 2024 15:45:31 +0800 Subject: [PATCH 057/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 
9cc97af..dd89639 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -114,7 +114,7 @@ SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_E # Generate Single Logo service config GSL_MODEL_URL = '10.1.1.240:10041' GSL_MINIO_BUCKET = "aida-users" -GSL_MODEL_NAME = 'stable_diffusion_xl' +GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Single Logo service config From 9291e350e06ca2307567260b20541f017fa597a1 Mon Sep 17 00:00:00 2001 From: zchen Date: Sun, 23 Jun 2024 16:05:44 +0800 Subject: [PATCH 058/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index dd89639..f35eb9c 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -125,7 +125,7 @@ GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'stable_diffusion_1_5' -GRI_MODEL_URL = '10.1.1.240:10041' +GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' From 7266de9a484ae83f79a0190f80b835235c9b1672 Mon Sep 17 00:00:00 2001 From: zchen Date: Sun, 23 Jun 2024 16:06:09 +0800 Subject: [PATCH 059/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 69fb6c2..cfc04f5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -126,7 +126,7 @@ GPI_MODEL_URL = 
'10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'diffusion_relight_ensemble' -GRI_MODEL_URL = '10.1.1.240:10041' +GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' From 6c7c6b47af87f0d5aedbce28a3b1361fb6865ca1 Mon Sep 17 00:00:00 2001 From: zchen Date: Sun, 23 Jun 2024 16:07:11 +0800 Subject: [PATCH 060/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1246 -> 1232 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 68c778ce116e570ae67b442b8196e96fbd5a3079..7b3fa73dfc24137723e7ae7bbbff5c0ccbedbba0 100644 GIT binary patch delta 11 Scmcb|d4Y3-5X)o*mVE#kNdx=< delta 21 bcmcb>d5?305DRA#Lq0 Date: Sun, 23 Jun 2024 16:30:18 +0800 Subject: [PATCH 061/108] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index d193de7..8d446a0 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -103,7 +103,7 @@ class GenerateImage: image_result = not_smudge_image if is_smudge: # 无污点 # image_result = adjust_contrast(image_result) - image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") # 
logger.info(f"upload image SUCCESS : {image_url}") self.generate_data['status'] = "SUCCESS" self.generate_data['message'] = "success" From b0b4b5cb9115d586ef556e12b65b2173d5ea7ca0 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 10:32:54 +0800 Subject: [PATCH 062/108] feat fix minio and s3 --- app/service/design/items/pipelines/split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index 0183352..e46a3e1 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -41,7 +41,7 @@ class Split(object): else: back_mask = result['back_mask'] - rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), re4sult['final_image'], result['mask']) + rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) result_front_image = np.zeros_like(rgba_image) result_front_image[front_mask != 0] = rgba_image[front_mask != 0] From b36f4d0a887eb8b423d925478710669343b89b61 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 11:58:28 +0800 Subject: [PATCH 063/108] =?UTF-8?q?feat=20fix=20=20=E8=B6=85=E5=88=86?= =?UTF-8?q?=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/super_resolution/service.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index f864d01..557beac 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -1,12 +1,14 @@ import json import logging import time + import cv2 import minio.error import numpy as np import redis import torch import tritonclient.grpc as grpcclient + from app.core.config import * from app.schemas.super_resolution import 
SuperResolutionModel from app.service.utils.oss_client import oss_get_image, oss_upload_image @@ -32,6 +34,7 @@ class SuperResolution: def read_image(self): try: img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2") + img = img.astype(np.float32) / 255. # 解码 except minio.error.S3Error as e: sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) @@ -144,6 +147,6 @@ def infer_cancel(tasks_id): if __name__ == '__main__': - request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123") + request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="12341556") service = SuperResolution(request_data) result_url = service.sr_result() From 5077e05985feb0bd9238414d75099185a86aa38b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 14:23:35 +0800 Subject: [PATCH 064/108] =?UTF-8?q?feat=20fix=20=20=E8=B6=85=E5=88=86?= =?UTF-8?q?=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design/items/pipelines/painting.py | 41 ++----------------- app/service/super_resolution/service.py | 28 +++---------- 2 files changed, 10 insertions(+), 59 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 3c9c233..eaeafa3 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -85,39 +85,6 @@ class Painting(object): pattern[0, 0, 2] = int(R) return pattern - @staticmethod - def gradient(image, angle_degrees, start_color, end_color): - height, width = image.shape[0], image.shape[1] - - # 创建一个空白的图像 - gradient_image = np.zeros((height, width, 3), dtype=np.uint8) - - # 将角度限制在 0 到 360 度之间 - angle_degrees = 
np.clip(angle_degrees, 0, 360) - - # 将角度转换为弧度 - angle_radians = np.radians(angle_degrees) - - # 计算渐变的方向 - dx = np.cos(angle_radians) - dy = np.sin(angle_radians) - - # 创建网格 - x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height)) - - # 计算每个像素在渐变方向上的位置 - distance_along_gradient = (x_grid * dx + y_grid * dy) / np.sqrt(dx ** 2 + dy ** 2) - - # 计算渐变的权重 - weight = np.clip(distance_along_gradient / max(width, height), 0, 1) - - # 计算渐变的颜色 - gradient_image[:, :, 0] = (1 - weight) * start_color[0] + weight * end_color[0] - gradient_image[:, :, 1] = (1 - weight) * start_color[1] + weight * end_color[1] - gradient_image[:, :, 2] = (1 - weight) * start_color[2] + weight * end_color[2] - - return gradient_image - @PIPELINES.register_module() class PrintPainting(object): @@ -147,8 +114,8 @@ class PrintPainting(object): resized_source = image.resize(new_size) resized_source_mask = mask.resize(new_size) - rotated_resized_source = resized_source.rotate(result['print']['print_angle_list'][i]) - rotated_resized_source_mask = resized_source_mask.rotate(result['print']['print_angle_list'][i]) + rotated_resized_source = resized_source.rotate(-result['print']['print_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-result['print']['print_angle_list'][i]) source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) @@ -268,8 +235,8 @@ class PrintPainting(object): resized_source = image.resize(new_size) resized_source_mask = mask.resize(new_size) - rotated_resized_source = resized_source.rotate(result['element']['element_angle_list'][i]) - rotated_resized_source_mask = resized_source_mask.rotate(result['element']['element_angle_list'][i]) + rotated_resized_source = resized_source.rotate(-result['element']['element_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-result['element']['element_angle_list'][i]) 
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index 557beac..1dfe9dd 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -58,14 +58,6 @@ class SuperResolution: logging.info(f"{self.tasks_id} ===> {status_data}") return status_data - # @RunTime - def infer(self, inputs): - return self.triton_client.async_infer( - model_name=SR_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - # @RunTime def sr_result(self): sample = self.read_image() @@ -82,13 +74,16 @@ class SuperResolution: # , binary_data=True ) - ctx = self.infer(inputs) + ctx = self.triton_client.async_infer( + model_name=SR_MODEL_NAME, + inputs=inputs, + callback=self.callback + ) time_out = 60 while time_out > 0: generate_data = self.read_tasks_status() if generate_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() - # noinspection PyTypeChecker self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data)) logger.info(f" [x] Sent {generate_data}") break @@ -98,16 +93,6 @@ class SuperResolution: time.sleep(1) return self.read_tasks_status() - # results = self.triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs) - - # sr_output = torch.from_numpy(results.as_numpy(f"output")) - # output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - # if output.ndim == 3: - # output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR - # output = (output * 255.0).round().astype(np.uint8) - # output_url = self.upload_img_sr(output, generate_uuid()) - # return output_url - def upload_img_sr(self, image): try: image_bytes = cv2.imencode('.jpg', image)[1].tobytes() @@ -121,7 +106,6 @@ class SuperResolution: def callback(self, result, error): if error: - 
print(error) sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"}) self.redis_client.set(self.tasks_id, sr_info_data) else: @@ -147,6 +131,6 @@ def infer_cancel(tasks_id): if __name__ == '__main__': - request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="12341556") + request_data = SuperResolutionModel(sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png", sr_xn=2, sr_tasks_id="12341556") service = SuperResolution(request_data) result_url = service.sr_result() From b4490ebb951b0d3254f30fc3afc9f19918c4f4c4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 15:00:04 +0800 Subject: [PATCH 065/108] =?UTF-8?q?feat=20fix=20=20design=E9=A2=84?= =?UTF-8?q?=E5=A4=84=E7=90=86=20=E8=A1=A5=E5=81=BF=E7=99=BD=E8=BE=B9?= =?UTF-8?q?=E5=9B=BEurl=E9=81=97=E6=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_pre_processing/service.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index 88ed739..f6f239a 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -4,12 +4,8 @@ import time import cv2 import numpy as np import torch -from minio import Minio -from pymilvus import connections, Collection -from urllib3.exceptions import ResponseError -import torch.nn.functional as F import tritonclient.grpc as grpcclient -import io +from urllib3.exceptions import ResponseError from app.core.config import * from app.service.design.utils.design_ensemble import get_keypoint_result @@ -179,14 +175,14 @@ class DesignPreprocessing: image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], 
image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" bucket_name = image['image_url'].split('/', 1)[0] - object_name = image['image_url'].split('/', 1)[1] + object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.') oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) image['show_image_url'] = f"{bucket_name}/{object_name}" else: image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" bucket_name = image['image_url'].split('/', 1)[0] - object_name = image['image_url'].split('/', 1)[1] + object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.') oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) image['show_image_url'] = f"{bucket_name}/{object_name}" From 19c0a35bce2996774586c9f98bdb9a58bc8d2389 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 15:09:19 +0800 Subject: [PATCH 066/108] =?UTF-8?q?feat=20fix=20=20=E8=BF=9B=E5=BA=A6id?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=E8=BF=94=E5=9B=9E=E6=B6=88=E6=81=AF=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index cdbd1f5..690724f 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -1,6 +1,5 @@ import json import logging -import time from fastapi import APIRouter, HTTPException @@ -33,7 +32,7 @@ def get_progress(request_data: DesignProgressModel): r = Redis() data = r.read(key=process_id) if data is None: - raise 
ValueError("The progress must be numbers ") + raise ValueError(f"No progress ID: {process_id}") logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") except Exception as e: logger.warning(f"get_progress Run Exception @@@@@@:{e}") From c409778cb05a27158c13378dd1cfda1d60757ca9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 15:12:01 +0800 Subject: [PATCH 067/108] =?UTF-8?q?feat=20fix=20=20=E8=BF=9B=E5=BA=A6id?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=E8=BF=94=E5=9B=9E=E6=B6=88=E6=81=AF=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index 690724f..18d5f80 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -15,7 +15,7 @@ logger = logging.getLogger() @router.post("/design") def design(request_data: DesignModel): try: - logger.info(f"design request item is : @@@@@@:{request_data}") + logger.info(f"design request item is : @@@@@@:{request_data.dict()}") data = generate(request_data=request_data) logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: @@ -27,7 +27,7 @@ def design(request_data: DesignModel): @router.post('/get_progress') def get_progress(request_data: DesignProgressModel): try: - logger.info(f"get_progress request item is : @@@@@@:{request_data}") + logger.info(f"get_progress request item is : @@@@@@:{request_data.dict()}") process_id = request_data.process_id r = Redis() data = r.read(key=process_id) From 42cc2e1c519840dcf6c541ab0e1576607fd13841 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 15:25:32 +0800 Subject: [PATCH 068/108] =?UTF-8?q?feat=20fix=20=20design=20print=E9=98=B2?= =?UTF-8?q?=E6=AD=A2png=E5=9B=BE=E7=89=87=E9=80=8F=E6=98=8E=E8=BD=AC?= =?UTF-8?q?=E9=BB=91=20=E5=87=BA=E7=8E=B0=E7=9A=84bug,=E9=87=87=E7=94=A8pi?= 
=?UTF-8?q?l=E8=AF=BB=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/painting.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index eaeafa3..dc5bd9b 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -466,13 +466,11 @@ class PrintPainting(object): bucket_name = print_dict['print_path_list'][0].split("/", 1)[0] object_name = print_dict['print_path_list'][0].split("/", 1)[1] - image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL") # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - if image.shape[2] == 4: - image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) - image_pil = Image.fromarray(image_rgb) - new_background = Image.new('RGB', image_pil.size, (255, 255, 255)) - new_background.paste(image_pil, mask=image.split()[3]) + if image.mode == "RGBA": + new_background = Image.new('RGB', image.size, (255, 255, 255)) + new_background.paste(image, mask=image.split()[3]) image = new_background print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) From d827147e8a69d93d2bd1c2f6ebb2ae09e0ba87d9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 15:31:39 +0800 Subject: [PATCH 069/108] =?UTF-8?q?feat=20fix=20=20design=20ifsingle?= =?UTF-8?q?=E4=B8=BAtrue=20=E4=BD=86=20=E6=B2=A1=E6=9C=89print=E7=9A=84=20?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=E6=A3=80=E6=B5=8B=E5=8F=96=E6=B6=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design/items/pipelines/painting.py | 180 +++++++++--------- 1 file changed, 89 insertions(+), 91 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py 
b/app/service/design/items/pipelines/painting.py index dc5bd9b..32e750c 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -99,111 +99,109 @@ class PrintPainting(object): elif result['print']["location"] == [] or result['print']["location"] is None: result['print']["location"] = [[0, 0]] if result['print']['IfSingle']: - if len(result['print']['print_path_list']) == 0: - raise ValueError('When there is no printing, ifsingle must be false') + if len(result['print']['print_path_list']) > 0: + print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + # print_background = np.full((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), 255, dtype=np.uint8) + for i in range(len(result['print']['print_path_list'])): + image, image_mode = self.read_image(result['print']['print_path_list'][i]) + if image_mode == "RGBA": + new_size = (int(image.width * result['print']['print_scale_list'][i]), int(image.height * result['print']['print_scale_list'][i])) - print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) - mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) - # print_background = np.full((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), 255, dtype=np.uint8) - for i in range(len(result['print']['print_path_list'])): - image, image_mode = self.read_image(result['print']['print_path_list'][i]) - if image_mode == "RGBA": - new_size = (int(image.width * result['print']['print_scale_list'][i]), int(image.height * result['print']['print_scale_list'][i])) + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) - mask = 
image.split()[3] - resized_source = image.resize(new_size) - resized_source_mask = mask.resize(new_size) + rotated_resized_source = resized_source.rotate(-result['print']['print_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-result['print']['print_angle_list'][i]) - rotated_resized_source = resized_source.rotate(-result['print']['print_angle_list'][i]) - rotated_resized_source_mask = resized_source_mask.rotate(-result['print']['print_angle_list'][i]) + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) - source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) - source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask) - source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source) - source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask) - - print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) - mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) - else: - mask = self.get_mask_inv(image) - mask = np.expand_dims(mask, axis=2) - mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) - mask = cv2.bitwise_not(mask) - # 旋转后的坐标需要重新算 - rotate_mask, _ = self.img_rotate(mask, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) - rotate_image, rotated_new_size = 
self.img_rotate(image, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) - # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) - x, y = int(result['print']['location'][i][0] - rotated_new_size[0]), int(result['print']['location'][i][1] - rotated_new_size[1]) - - image_x = print_background.shape[1] - image_y = print_background.shape[0] - print_x = rotate_image.shape[1] - print_y = rotate_image.shape[0] - - # 有bug - # if x + print_x > image_x: - # rotate_image = rotate_image[:, :x + print_x - image_x] - # rotate_mask = rotate_mask[:, :x + print_x - image_x] - # # - # if y + print_y > image_y: - # rotate_image = rotate_image[:y + print_y - image_y] - # rotate_mask = rotate_mask[:y + print_y - image_y] - - # 不能是并行 - # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 - # 先挪 再判断 最后裁剪 - - # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 - if x <= 0: - rotate_image = rotate_image[:, -x:] - rotate_mask = rotate_mask[:, -x:] - start_x = x = 0 + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) else: - start_x = x + mask = self.get_mask_inv(image) + mask = np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - 
image.shape[1]) / 2) + x, y = int(result['print']['location'][i][0] - rotated_new_size[0]), int(result['print']['location'][i][1] - rotated_new_size[1]) - if y <= 0: - rotate_image = rotate_image[-y:, :] - rotate_mask = rotate_mask[-y:, :] - start_y = y = 0 - else: - start_y = y + image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] - # ------------------ - # 如果print-size大于image-size 则需要裁剪print + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, :x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] - if x + print_x > image_x: - rotate_image = rotate_image[:, :image_x - x] - rotate_mask = rotate_mask[:, :image_x - x] + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 - if y + print_y > image_y: - rotate_image = rotate_image[:image_y - y, :] - rotate_mask = rotate_mask[:image_y - y, :] + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 + else: + start_x = x - # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) - # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y - # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask - # print_background[start_y:y + 
rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image - mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) - print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + # ------------------ + # 如果print-size大于image-size 则需要裁剪print - # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) - # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + rotate_mask = rotate_mask[:, :image_x - x] - print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) - img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) - img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask)) - mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) - gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) - img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) - result['final_image'] = cv2.add(img_bg, img_fg) - canvas = np.full_like(result['final_image'], 255) - temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) - tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) - temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) - tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) - result['single_image'] = cv2.add(tmp1, tmp2) + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = 
cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + + # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) + # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + + print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) + img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) + img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask)) + mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + result['final_image'] = cv2.add(img_bg, img_fg) + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) else: painting_dict = {} painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2] From 26a56b4d9f5b17cf9e6588a6e4c93eff5ee153e8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 16:35:34 +0800 Subject: [PATCH 070/108] =?UTF-8?q?feat=20=E6=A8=A1=E7=89=B9=E9=A2=84?= 
=?UTF-8?q?=E5=A4=84=E7=90=86=E6=8E=A5=E5=8F=A3=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 16 +++++++++++- app/schemas/design.py | 4 +++ app/service/design/model_process_service.py | 28 +++++++++++++++++++++ app/service/utils/oss_client.py | 4 +-- 4 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 app/service/design/model_process_service.py diff --git a/app/api/api_design.py b/app/api/api_design.py index 18d5f80..ef0e085 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -3,8 +3,9 @@ import logging from fastapi import APIRouter, HTTPException -from app.schemas.design import DesignModel, DesignProgressModel +from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel from app.schemas.response_template import ResponseModel +from app.service.design.model_process_service import model_transpose from app.service.design.service import generate from app.service.design.utils.redis_utils import Redis @@ -38,3 +39,16 @@ def get_progress(request_data: DesignProgressModel): logger.warning(f"get_progress Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) + + +@router.post('/model_process') +def model_process(request_data: ModelProgressModel): + try: + logger.info(f"model_process request item is : @@@@@@:{request_data.dict()}") + + data = model_transpose(image_path=request_data.model_path) + logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}") + except Exception as e: + logger.warning(f"model_process Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/schemas/design.py b/app/schemas/design.py index 994deb4..edcc392 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -52,3 +52,7 @@ class DesignModel(BaseModel): class DesignProgressModel(BaseModel): process_id: str + + 
+class ModelProgressModel(BaseModel): + model_path: str diff --git a/app/service/design/model_process_service.py b/app/service/design/model_process_service.py new file mode 100644 index 0000000..fffbd67 --- /dev/null +++ b/app/service/design/model_process_service.py @@ -0,0 +1,28 @@ +import io + +from app.service.utils.oss_client import oss_get_image, oss_upload_image + + +def model_transpose(image_path): + bucket = image_path.split("/", 1)[0] + object_name = image_path.split("/", 1)[1] + new_object_name = f'{object_name[:object_name.rfind(".")]}.png' + image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL") + image = image.convert("RGBA") + data = image.getdata() + # + new_data = [] + for item in data: + if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + new_data.append((255, 255, 255, 0)) + else: + new_data.append(item) + image.putdata(new_data) + + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes) + image_path = f"{bucket}/{new_object_name}" + return image_path diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index e293117..7ebeb3f 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -55,12 +55,12 @@ if __name__ == '__main__': # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" # url = "aida-sys-image/images/female/outwear/0628000054.jpg" # url = "aida-users/89/product_image/string-89.png" - url = "aida-users/89/single_logo/123-89.png" + url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png" # url = 'aida-users/89/relight_image/123-89.png' # url = 'aida-users/89/relight_image/123-89.png' # url = 'aida-users/89/relight_image/123-89.png' # url = 
"aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" - read_type = "cv2" + read_type = "PIL " if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) cv2.imshow("", img) From 2a83399effcb12eb7ef1c41e2352eae955af785d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 16:52:00 +0800 Subject: [PATCH 071/108] =?UTF-8?q?feat=20fix=20generate=20=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8E=A5=E5=8F=A3=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 07303ad..23fbb7f 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -1,6 +1,8 @@ import json import logging + from fastapi import APIRouter, BackgroundTasks, HTTPException + from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel @@ -26,7 +28,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun return ResponseModel() -@router.get("/generate_cancel/{tasks_id}>") +@router.get("/generate_cancel/{tasks_id}") def generate_image(tasks_id: str): try: logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") @@ -53,7 +55,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_ return ResponseModel() -@router.get("/generate_single_logo_cancel/{tasks_id}>") +@router.get("/generate_single_logo_cancel/{tasks_id}") def generate_single_logo_image(tasks_id: str): try: logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}") @@ -80,7 +82,7 @@ 
def generate_product_image(request_item: GenerateProductImageModel, background_t return ResponseModel() -@router.get("/generate_product_image_cancel_cancel/{tasks_id}>") +@router.get("/generate_product_image_cancel_cancel/{tasks_id}") def generate_product_image(tasks_id: str): try: logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}") @@ -107,7 +109,7 @@ def generate_relight_image(request_item: GenerateProductImageModel, background_t return ResponseModel() -@router.get("/generate_relight_image_cancel_cancel/{tasks_id}>") +@router.get("/generate_relight_image_cancel_cancel/{tasks_id}") def generate_relight_image(tasks_id: str): try: logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}") From 374edce9ef2b9a2692781b64fe9ee1ed39aee464 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 17:39:29 +0800 Subject: [PATCH 072/108] =?UTF-8?q?feat=20fix=20=E7=94=B1=E4=BA=8E?= =?UTF-8?q?=E5=BB=B6=E8=BF=9F=EF=BC=8Cbounding=20box=E5=90=8E=E7=9A=84sket?= =?UTF-8?q?ch=E4=B8=8E=E5=89=8D=E7=AB=AF=E7=BC=93=E5=AD=98=E7=9A=84sketch?= =?UTF-8?q?=E6=9C=89=E8=AF=AF=E5=B7=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_pre_processing/service.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index f6f239a..272dc7e 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -9,6 +9,7 @@ from urllib3.exceptions import ResponseError from app.core.config import * from app.service.design.utils.design_ensemble import get_keypoint_result +from app.service.utils.generate_uuid import generate_uuid from app.service.utils.oss_client import oss_get_image, oss_upload_image @@ -121,8 +122,11 @@ class DesignPreprocessing: image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes() # 
self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) bucket_name = item['image_url'].split("/", 1)[0] + # 由于延迟,bounding box后的sketch与前端缓存的sketch有误差 object_name = item['image_url'].split("/", 1)[1] - oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + new_object = f"{object_name[:object_name.rfind('/') + 1]}{generate_uuid()}.{object_name.split('.', 1)[1]}" + oss_upload_image(bucket=bucket_name, object_name=new_object, image_bytes=image_bytes) + item['new_image_url'] = f"{bucket_name}/{new_object}" print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") except ResponseError as err: print(f"Error: {err}") From b2d205d6218950ddbef2a2fea282b6d04f46e943 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 24 Jun 2024 18:02:35 +0800 Subject: [PATCH 073/108] =?UTF-8?q?feat=20fix=20=E7=94=B1=E4=BA=8E?= =?UTF-8?q?=E5=BB=B6=E8=BF=9F=EF=BC=8Cbounding=20box=E5=90=8E=E7=9A=84sket?= =?UTF-8?q?ch=E4=B8=8E=E5=89=8D=E7=AB=AF=E7=BC=93=E5=AD=98=E7=9A=84sketch?= =?UTF-8?q?=E6=9C=89=E8=AF=AF=E5=B7=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_pre_processing/service.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index 272dc7e..f6f239a 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -9,7 +9,6 @@ from urllib3.exceptions import ResponseError from app.core.config import * from app.service.design.utils.design_ensemble import get_keypoint_result -from app.service.utils.generate_uuid import generate_uuid from app.service.utils.oss_client import oss_get_image, oss_upload_image @@ -122,11 +121,8 @@ class DesignPreprocessing: image_bytes = cv2.imencode(".jpg", 
item['obj'])[1].tobytes() # self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) bucket_name = item['image_url'].split("/", 1)[0] - # 由于延迟,bounding box后的sketch与前端缓存的sketch有误差 object_name = item['image_url'].split("/", 1)[1] - new_object = f"{object_name[:object_name.rfind('/') + 1]}{generate_uuid()}.{object_name.split('.', 1)[1]}" - oss_upload_image(bucket=bucket_name, object_name=new_object, image_bytes=image_bytes) - item['new_image_url'] = f"{bucket_name}/{new_object}" + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") except ResponseError as err: print(f"Error: {err}") From 558d86b31266be2f32acb5b7b2c5fb37500a50ce Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 25 Jun 2024 11:49:54 +0800 Subject: [PATCH 074/108] =?UTF-8?q?feat=20fix=20generate=20image2image=20o?= =?UTF-8?q?bject-name=20=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 8d446a0..f3dff16 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -65,7 +65,7 @@ class GenerateImage: # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) # image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url, data_type="cv2") + image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="cv2") image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) image = 
cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: From db3d86204fa28b232555ee52f4a3d86bd47ef4e3 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 25 Jun 2024 16:58:17 +0800 Subject: [PATCH 075/108] =?UTF-8?q?feat=20=E6=96=B0=E5=A2=9E=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E6=8F=8F=E8=BF=B0=20docs=E9=A1=B5=E9=9D=A2=20?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9ES3=20=E5=9B=BE=E7=89=87get=20upload?= =?UTF-8?q?=20=E6=93=8D=E4=BD=9C=EF=BC=8C=E6=95=B4=E7=90=86=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/api/api_attribute_retrieve.py | 31 + app/api/api_chat_robot.py | 18 +- app/api/api_design.py | 120 +++ app/api/api_design_pre_processing.py | 18 + app/api/api_generate_image.py | 58 ++ app/api/api_prompt_generation.py | 10 + app/api/api_super_resolution.py | 13 + app/service/design/design_request.json | 101 --- app/service/design/design_request_2.json | 684 -------------- app/service/design/fastapi_request.json | 836 ++++++++++++++++-- app/service/design/items/pipelines/split.py | 2 - app/service/design/utils/synthesis_item.py | 12 +- app/service/design/utils/upload_image.py | 133 +-- app/service/design_pre_processing/service.py | 17 + .../generate_image/service_generate_image.py | 19 +- app/service/super_resolution/service.py | 1 - app/service/utils/oss_client.py | 10 +- 17 files changed, 1087 insertions(+), 996 deletions(-) delete mode 100644 app/service/design/design_request.json delete mode 100644 app/service/design/design_request_2.json diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 7a14e9d..d9b210c 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,5 +1,6 @@ import json import logging + from fastapi import APIRouter, HTTPException from app.core.config import DEBUG @@ -16,6 +17,22 @@ logger = logging.getLogger() # 属性识别 @router.post("/attribute_recognition", response_model=ResponseModel) 
def attribute_recognition(request_item: list[AttributeRecognitionModel]): + """ + 获取sketch的属性,collar sleeve_length 等等 + 创建一个具有以下参数的请求体: + - **category**: sketch的类别 ,Dress + - **colony**: 服装类别,男装或女装 + - **sketch_img_url**: 被提取属性的S3或minio url地址 + + 示例参数: + [ + { + "category": "Dress", + "colony": "Female", + "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" + } + ] + """ try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") if DEBUG: @@ -33,6 +50,20 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): # 类别识别 @router.post("/category_recognition") def category_recognition(request_item: list[CategoryRecognitionModel]): + """ + 获取sketch的类别,dress blouse 等等 + 创建一个具有以下参数的请求体: + - **colony**: 服装类别,male或Female + - **sketch_img_url**: 被提取sketch类别的S3或minio url地址 + + 示例参数: + [ + { + "colony": "Female", + "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" + } + ] + """ try: logger.info(f"category_recognition request item is : @@@@@@:{request_item}") service = CategoryRecognition(request_data=request_item) diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py index dccba9a..6f3da16 100644 --- a/app/api/api_chat_robot.py +++ b/app/api/api_chat_robot.py @@ -1,6 +1,6 @@ import json import logging -import time + from fastapi import APIRouter, HTTPException from app.schemas.chat_robot import ChatRobotModel @@ -13,6 +13,22 @@ logger = logging.getLogger() @router.post("/chat_robot") def chat_robot(request_data: ChatRobotModel): + """ + 对话机器人 + 创建一个具有以下参数的请求体: + - **gender**: 服装类别 + - **message**: 消息 + - **session_id**: 会话id + - **user_id**: 用户id + + 示例参数: + { + "gender": "male", + "message": "你好", + "session_id": "string-89", + "user_id": 89 + } + """ try: logger.info(f"chat_robot request item is : @@@@@@:{request_data}") data = chat(post_data=request_data) diff --git a/app/api/api_design.py b/app/api/api_design.py index 
ef0e085..c056c39 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -15,6 +15,106 @@ logger = logging.getLogger() @router.post("/design") def design(request_data: DesignModel): + """ + 创建一个具有以下参数的请求体: + 示例参数: + { + "objects": [ + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "229 214 200", + "icon": "none", + "image_id": 110205, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0916000217.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "businessId": 493825, + "color": "209 125 29", + "elementId": 493825, + "icon": "none", + "image_id": 107101, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + } + ], + "process_id": "6878547032381675" + } + """ try: logger.info(f"design request item is : @@@@@@:{request_data.dict()}") data = generate(request_data=request_data) @@ -27,6 +127,16 @@ def design(request_data: DesignModel): @router.post('/get_progress') def get_progress(request_data: DesignProgressModel): + """ + 获取design 进度 + 创建一个具有以下参数的请求体: + - **process_id**: 进度id 
+ + 示例参数: + { + "process_id": "6878547032381675" + } + """ try: logger.info(f"get_progress request item is : @@@@@@:{request_data.dict()}") process_id = request_data.process_id @@ -43,6 +153,16 @@ def get_progress(request_data: DesignProgressModel): @router.post('/model_process') def model_process(request_data: ModelProgressModel): + """ + 获取模特图片预处理 + 创建一个具有以下参数的请求体: + - **model_path**: 模特图片的minio或s3 url地址 + + 示例参数: + { + "model_path": "aida-users/10/models/female/9c788f5b-b8c7-424c-b149-025aeb0bda51model.jpg" + } + """ try: logger.info(f"model_process request item is : @@@@@@:{request_data.dict()}") diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py index bd87e00..f6946dc 100644 --- a/app/api/api_design_pre_processing.py +++ b/app/api/api_design_pre_processing.py @@ -1,6 +1,8 @@ import json import logging + from fastapi import APIRouter, HTTPException + from app.schemas.pre_processing import DesignPreProcessingModel from app.schemas.response_template import ResponseModel from app.service.design_pre_processing.service import DesignPreprocessing @@ -11,6 +13,22 @@ logger = logging.getLogger() @router.post("/design_pre_processing") def design_pre_processing(request_data: DesignPreProcessingModel): + """ + design 预处理 获取sketch的基本信息 + 创建一个具有以下参数的请求体: + - **sketches**: sketch url等信息 + + 示例参数: + { + "sketches": [ + { + "image_category": "dress", + "image_id": "107903", + "image_url": "aida-sys-image/images/female/dress/0628000000.jpg" + } + ] + } + """ try: logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}") server = DesignPreprocessing() diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 23fbb7f..2da5554 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -18,6 +18,25 @@ logger = logging.getLogger() @router.post("/generate_image") def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - 
**tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **image_url**: 图生图的输入,minio或S3 url 地址 + - **mode**: 生成模式,img2img或者txt2img + - **category**: 生成图片的类别,sketch print 等等 + - **gender**: 生成sketch专用,服装类别 + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic", + "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", + "mode": "img2img", + "category": "sketch", + "gender": "male" + } + """ try: logger.info(f"generate_image request item is : @@@@@@:{request_item}") service = GenerateImage(request_item) @@ -45,6 +64,19 @@ def generate_image(tasks_id: str): @router.post("/generate_single_logo") def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **seed**: 固定的prompt和固定的seed 每次的生成结果都是一样的 + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "an apple", + "seed": "2" + } + """ try: logger.info(f"generate_single_logo request item is : @@@@@@:{request_item}") service = GenerateSingleLogoImage(request_item) @@ -72,6 +104,19 @@ def generate_single_logo_image(tasks_id: str): @router.post("/generate_product_image") def generate_product_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **image_url**: 被生成图片的S3或minio url地址 + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "the best quality, masterpiece. 
detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png" + } + """ try: logger.info(f"generate_product_image request item is : @@@@@@:{request_item}") service = GenerateProductImage(request_item) @@ -99,6 +144,19 @@ def generate_product_image(tasks_id: str): @router.post("/generate_relight_image") def generate_relight_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **image_url**: 被生成图片的S3或minio url地址 + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", + "image_url": "aida-sys-image/images/female/blouse/0628000098.jpg" + } + """ try: logger.info(f"generate_relight_image request item is : @@@@@@:{request_item}") service = GenerateRelightImage(request_item) diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index 292ad2e..c7bcbcd 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -14,6 +14,16 @@ logger = logging.getLogger() @router.post("/translateToEN") def prompt_generation(request_data: PromptGenerationImageModel): + """ + 翻译prompt接口 + 创建一个具有以下参数的请求体: + - **text**: 待翻译语句 + + 示例参数: + { + "text": "你好" + } + """ try: logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") data = translate_to_en(request_data.text) diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py index 7928309..82b58f4 100644 --- a/app/api/api_super_resolution.py +++ b/app/api/api_super_resolution.py @@ -13,6 +13,19 @@ logger = logging.getLogger() @router.post("/super_resolution") def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks): + """ + 
创建一个具有以下参数的请求体: + - **sr_image_url**: 超分图片的minio或s3 url地址 + - **sr_xn**: 超分的倍数,只接受2或4 + - **sr_tasks_id**: 任务id 用于取消超分任务和获取超分结果 + + 示例参数: + { + "sr_image_url": "aida-sys-image/images/female/blouse/0628000098.jpg", + "sr_xn": 2, + "sr_tasks_id": "12341556-89" + } + """ try: logger.info(f"super_resolution request item is : @@@@@@:{request_item}") service = SuperResolution(request_item) diff --git a/app/service/design/design_request.json b/app/service/design/design_request.json deleted file mode 100644 index 5551b82..0000000 --- a/app/service/design/design_request.json +++ /dev/null @@ -1,101 +0,0 @@ -{ - "objects": [ - { - "basic": { - "body_point": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 67315, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/0628000325.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Trousers" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 92912, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0825001943.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 91430, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/outwear/0825000856.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Outwear" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 
69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - ], - "process_id": "7296013643475027" -} \ No newline at end of file diff --git a/app/service/design/design_request_2.json b/app/service/design/design_request_2.json deleted file mode 100644 index 51b607a..0000000 --- a/app/service/design/design_request_2.json +++ /dev/null @@ -1,684 +0,0 @@ -{ - "objects": [ - { - "basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 67315, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/0628000325.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Trousers" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 92912, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0825001943.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 91430, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/outwear/0825000856.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Outwear" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - , - { - "basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - 
"waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 92913, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/dress/826000033.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Dress" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - , - { - "basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 92914, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/skirt/0902001788.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Skirt" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 92915, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0902003817.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - , - { - 
"basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 92916, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/skirt/skirt_p4_838.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Skirt" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 84210, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0916000703.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - , - { - "basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 62041, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/outwear/0902000232.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Outwear" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 67039, - 
"offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0902002591.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 78016, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/trousers_p4_302.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Trousers" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - , - { - "basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 92917, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/0902001403.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Trousers" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 92306, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0902001766.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - , - { - "basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - 
], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 86564, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0916000038.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 92918, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/0628001561.jpeg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Trousers" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 92919, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/outwear/outwear_p3186.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Outwear" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - , - { - "basic": { - "body_point_test": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "151 78 78", - "icon": "none", - "image_id": 67009, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0902002051.jpg", - "print": { - "IfSingle": 
false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Blouse" - }, - { - "color": "151 78 78", - "icon": "none", - "image_id": 85028, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/skirt/903000142.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Skirt" - }, - { - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "resize_scale": 1.0, - "type": "Body" - } - ] - } - ], - "process_id": "7296013643475027" -} \ No newline at end of file diff --git a/app/service/design/fastapi_request.json b/app/service/design/fastapi_request.json index f578079..8c27a56 100644 --- a/app/service/design/fastapi_request.json +++ b/app/service/design/fastapi_request.json @@ -1,69 +1,771 @@ { - "basic": { - "body_point": { - "waistband_right": [ - 1081, - 1318 - ], - "hand_point_right": [ - 1200, - 1857 - ], - "waistband_left": [ - 639, - 1315 - ], - "hand_point_left": [ - 493, - 1808 - ], - "shoulder_left": [ - 556, - 582 - ], - "shoulder_right": [ - 1130, - 576 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "single", - "switch_category": "Trousers", - "body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png" - }, - "item": [ - { - "color": "151 78 78", - "image_id": "67315", - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/0628000325.jpg", - "print": { - "if_single": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Trousers" - }, - { - "color": "151 78 78", - "path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png", - "image_id": 69331, - "offset": [ - 1, - 1 - ], - "print": { - "if_single": false, - "print_path_list": [] - }, - "resize_scale": 1.0, - "type": "Body" - } - ] + "objects": [ + { + "basic": { + "body_point_test": { + 
"waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493827, + "color": "127 61 21", + "elementId": 493827, + "icon": "none", + "image_id": 110201, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketch/62302527-2910-4740-808d-2cb8221daa34-3-31.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Dress" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "27 25 23", + "icon": "none", + "image_id": 110202, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/0916000602.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "businessId": 493825, + "color": "229 214 200", + "elementId": 493825, + "icon": "none", + "image_id": 107101, + "offset": [ + 1, + 1 + ], + "path": 
"aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "businessId": 493824, + "color": "76 124 124", + "elementId": 493824, + "icon": "none", + "image_id": 104522, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "229 214 200", + "icon": "none", + "image_id": 110203, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0825001576.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "76 124 124", + "icon": "none", + "image_id": 96071, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/903000097.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "color": "209 125 29", + "icon": "none", + "image_id": 93798, + "offset": [ + 1, + 1 + ], + "path": 
"aida-sys-image/images/female/outwear/outwear_p4_561.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493824, + "color": "209 125 29", + "elementId": 493824, + "icon": "none", + "image_id": 104522, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "color": "118 123 115", + "icon": "none", + "image_id": 110204, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902000457.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "118 123 115", + "icon": "none", + "image_id": 79259, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/826000094.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + 
"image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "127 61 21", + "icon": "none", + "image_id": 96038, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/dress/0902003549.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Dress" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493822, + "color": "127 61 21", + "elementId": 493822, + "icon": "none", + "image_id": 62309, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/trousers/c37c2ea6-8955-4b40-8339-c737e672ca3d.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "businessId": 493825, + "color": "118 123 
115", + "elementId": 493825, + "icon": "none", + "image_id": 107101, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 493826, + "color": "127 61 21", + "elementId": 493826, + "icon": "none", + "image_id": 107105, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Skirt/58710352-6301-450d-b69a-fb2922b5429a.png", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "color": "118 123 115", + "icon": "none", + "image_id": 79114, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/903000169.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "229 214 200", + "icon": "none", + "image_id": 90573, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0628000541.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": 
"aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 336, + 264 + ], + "hand_point_right": [ + 350, + 303 + ], + "waistband_left": [ + 245, + 274 + ], + "hand_point_left": [ + 219, + 315 + ], + "shoulder_left": [ + 227, + 155 + ], + "shoulder_right": [ + 338, + 149 + ] + }, + "layer_order": false, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "229 214 200", + "icon": "none", + "image_id": 110205, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0916000217.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "businessId": 493825, + "color": "209 125 29", + "elementId": 493825, + "icon": "none", + "image_id": 107101, + "offset": [ + 1, + 1 + ], + "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", + "print": { + "IfSingle": false, + "print_path_list": [] + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", + "image_id": 82966, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + } + ], + "process_id": "6878547032381675" } \ No newline at end of file diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index e46a3e1..897e49a 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -48,7 +48,6 @@ class Split(object): 
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"][1])) result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS) - # TODO 多线程外部上传图片到minio # result['front_mask_image'] = cv2.resize(front_mask, front_new_size) # result['front_image'] = result_front_image_pil front_mask = cv2.resize(front_mask, front_new_size) @@ -61,7 +60,6 @@ class Split(object): result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA)) back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"][1])) result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS) - # TODO 多线程外部上传图片到minio # result['back_mask_image'] = cv2.resize(back_mask, back_new_size) # result['back_image'] = result_back_image_pil diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py index caf3fcb..e447018 100644 --- a/app/service/design/utils/synthesis_item.py +++ b/app/service/design/utils/synthesis_item.py @@ -9,7 +9,6 @@ """ import io import logging -import time # import boto3 import cv2 @@ -18,8 +17,8 @@ from PIL import Image from minio import Minio from app.core.config import * -from app.service.utils.decorator import RunTime from app.service.utils.generate_uuid import generate_uuid +from app.service.utils.oss_client import oss_upload_image minio_client = Minio( MINIO_URL, @@ -27,6 +26,7 @@ minio_client = Minio( secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # s3 = boto3.client( # 's3', # aws_access_key_id=S3_ACCESS_KEY, @@ -134,8 +134,14 @@ def synthesis(data, size): image_data = io.BytesIO() result_image.save(image_data, format='PNG') image_data.seek(0) + + # oss upload image_bytes = 
image_data.read() - return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + bucket_name = 'aida-results' + object_name = f'result_{generate_uuid()}.png' + req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + return f"{bucket_name}/{object_name}" + # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" # object_name = f'result_{generate_uuid()}.png' # response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png') diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py index a4195f7..3571816 100644 --- a/app/service/design/utils/upload_image.py +++ b/app/service/design/utils/upload_image.py @@ -9,99 +9,15 @@ """ import io import logging -import time -# import boto3 import cv2 -from minio import Minio from app.core.config import * -from app.service.utils.decorator import RunTime - -minio_client = Minio( - f"{MINIO_URL}", - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) - -"""S3 上传""" -# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) - - -# -# @RunTime -# def upload_png_mask(front_image, object_name, mask=None): -# mask_url = None -# if mask is not None: -# # 反转掩模 -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模转换为 RGBA 格式 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# # 将图像数据保存到内存中的 BytesIO 对象中 -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# image_bytes.seek(0) -# try: -# key = f"mask/mask_{object_name}.png" -# mask_url = f"{AIDA_CLOTHING}/{key}" -# 
s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') -# except Exception as e: -# print(f'上传到 S3 失败: {e}') -# with io.BytesIO() as output: -# front_image.save(output, format='PNG') -# data = output.getvalue() -# # 创建一个 S3 客户端 -# try: -# key = f"image/image_{object_name}.png" -# image_url = f"{AIDA_CLOTHING}/{key}" -# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') -# return front_image, image_url, mask_url -# except Exception as e: -# print(f'上传到 S3 失败: {e}') -# -# -# @RunTime -# def upload_layer_image(image, object_name): -# with io.BytesIO() as output: -# image.save(output, format='PNG') -# data = output.getvalue() -# # 创建一个 S3 客户端 -# try: -# key = f"image/image_{object_name}.png" -# image_url = f"{AIDA_CLOTHING}/{key}" -# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png') -# return image_url -# except Exception as e: -# print(f'上传到 S3 失败: {e}') -# -# -# @RunTime -# def upload_mask_image(mask, object_name): -# # 反转掩模 -# mask_inverted = cv2.bitwise_not(mask) -# # 将掩模转换为 RGBA 格式 -# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) -# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] -# # 将图像数据保存到内存中的 BytesIO 对象中 -# image_bytes = io.BytesIO() -# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) -# image_bytes.seek(0) -# try: -# key = f"mask/mask_{object_name}.png" -# mask_url = f"{AIDA_CLOTHING}/{key}" -# s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png') -# return mask_url -# except Exception as e: -# print(f'上传到 S3 失败: {e}') - - -"""minio 上传""" +from app.service.utils.oss_client import oss_upload_image # @RunTime def upload_png_mask(front_image, object_name, mask=None): - start_time = time.time() try: mask_url = None if mask is not None: @@ -109,48 +25,21 @@ def upload_png_mask(front_image, object_name, mask=None): # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 rgba_image = cv2.cvtColor(mask_inverted, 
cv2.COLOR_BGR2BGRA) rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - - image_bytes.seek(0) - mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" + # image_bytes = io.BytesIO() + # image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) + # image_bytes.seek(0) + # mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" + # oss upload #################### + req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1]) + mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" image_data = io.BytesIO() front_image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - # print(f"upload_png_mask {object_name} = {time.time() - start_time}") + # image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes) + image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" return front_image, image_url, mask_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") - - -@RunTime -def upload_layer_image(image, object_name): - try: - image_data = io.BytesIO() - image.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - 
image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - return image_url - except Exception as e: - logging.warning(f"upload_png_mask runtime exception : {e}") - - -@RunTime -def upload_mask_image(mask, object_name): - try: - mask_inverted = cv2.bitwise_not(mask) - # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - image_bytes = io.BytesIO() - image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - - image_bytes.seek(0) - mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" - return mask_url - except Exception as e: - logging.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index f6f239a..f69c3ee 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -8,6 +8,7 @@ import tritonclient.grpc as grpcclient from urllib3.exceptions import ResponseError from app.core.config import * +from app.schemas.pre_processing import DesignPreProcessingModel from app.service.design.utils.design_ensemble import get_keypoint_result from app.service.utils.oss_client import oss_get_image, oss_upload_image @@ -355,3 +356,19 @@ class DesignPreprocessing: except Exception as e: logging.info(f"save keypoint cache milvus error : {e}") return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + + +if __name__ == '__main__': + data = { + "sketches": [ + { + "image_category": "dress", + "image_id": "107903", + "image_url": "aida-sys-image/images/female/dress/0628000000.jpg" + } + ] + } + request_data = 
DesignPreProcessingModel(sketches=data["sketches"]) + server = DesignPreprocessing() + data = server.pipeline(image_list=request_data.sketches) + print(data) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index f3dff16..1bc1c91 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -10,15 +10,17 @@ import json import logging import time + import cv2 import minio +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np from tritonclient.utils import np_to_triton_dtype + from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic +from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.utils.oss_client import oss_get_image @@ -120,13 +122,6 @@ class GenerateImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: prompts = [self.prompt] * self.batch_size @@ -146,7 +141,7 @@ class GenerateImage: input_mode.set_data_from_numpy(mode_obj) inputs = [input_text, input_image, input_mode] - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 generate_data = None while time_out > 0: @@ -186,10 +181,10 @@ if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", prompt='skeleton sitting by the 
side of a river looking soulful, concert poster, 4k, artistic', - image_url="", + image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", mode='txt2img', category="test", gender="male" ) server = GenerateImage(rd) - print(server.get_result()) \ No newline at end of file + print(server.get_result()) diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index 1dfe9dd..c2cf39d 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -64,7 +64,6 @@ class SuperResolution: if self.sr_xn == 2: new_shape = (sample.shape[0] // self.sr_xn, sample.shape[1] // self.sr_xn) sample = cv2.resize(sample, new_shape) - print(new_shape) sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1)) sample = torch.from_numpy(sample).float().unsqueeze(0).numpy() inputs = [ diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index 7ebeb3f..11e7911 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -44,7 +44,7 @@ def oss_upload_image(bucket, object_name, image_bytes): req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png') else: oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) - req = oss_client.put_object(Bucket=AIDA_CLOTHING, Key=object_name, Body=image_bytes, ContentType='image/png') + req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=io.BytesIO(image_bytes), ContentType='image/png') except Exception as e: logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}") return req @@ -55,12 +55,16 @@ if __name__ == '__main__': # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" # url = "aida-sys-image/images/female/outwear/0628000054.jpg" # url = 
"aida-users/89/product_image/string-89.png" - url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png" + # url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png" # url = 'aida-users/89/relight_image/123-89.png' # url = 'aida-users/89/relight_image/123-89.png' # url = 'aida-users/89/relight_image/123-89.png' # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" - read_type = "PIL " + # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" + # url = "aida-users/89/single_logo/123-89.png" + # url = "aida-users/89/product_image/string-89.png" + url = "aida-users/10/models/female/9c788f5b-b8c7-424c-b149-025aeb0bda51model.png" + read_type = "PIL" if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) cv2.imshow("", img) From d281d1e5c378865d4453a490b6a29fa682ffcb9f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 26 Jun 2024 11:15:24 +0800 Subject: [PATCH 076/108] =?UTF-8?q?feat=20=E6=96=B0=E5=A2=9E=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E6=8F=8F=E8=BF=B0=20docs=E9=A1=B5=E9=9D=A2=20?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9ES3=20=E5=9B=BE=E7=89=87get=20upload?= =?UTF-8?q?=20=E6=93=8D=E4=BD=9C=EF=BC=8C=E6=95=B4=E7=90=86=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/service/design/items/bag.py | 3 ++- app/service/design/items/body.py | 1 + .../items/pipelines/contour_detection.py | 7 +++--- app/service/design/items/pipelines/loading.py | 21 +++++----------- .../design/items/pipelines/painting.py | 7 ------ app/service/design/items/pipelines/scale.py | 4 +++- app/service/design/items/pipelines/split.py | 3 ++- app/service/design/items/shoes.py | 7 +----- app/service/design/utils/conversion_image.py | 3 ++- 
app/service/design/utils/design_ensemble.py | 4 +++- app/service/design/utils/synthesis_item.py | 24 +++++-------------- app/service/utils/oss_client.py | 2 +- 12 files changed, 30 insertions(+), 56 deletions(-) diff --git a/app/service/design/items/bag.py b/app/service/design/items/bag.py index c171e75..12b4c68 100644 --- a/app/service/design/items/bag.py +++ b/app/service/design/items/bag.py @@ -1,6 +1,7 @@ +import random + from .builder import ITEMS from .clothing import Clothing -import random @ITEMS.register_module() diff --git a/app/service/design/items/body.py b/app/service/design/items/body.py index 69e8b36..c336ae9 100644 --- a/app/service/design/items/body.py +++ b/app/service/design/items/body.py @@ -1,4 +1,5 @@ import cv2 + from .builder import ITEMS from .pipelines import Compose diff --git a/app/service/design/items/pipelines/contour_detection.py b/app/service/design/items/pipelines/contour_detection.py index df6c7b2..018dbca 100644 --- a/app/service/design/items/pipelines/contour_detection.py +++ b/app/service/design/items/pipelines/contour_detection.py @@ -1,9 +1,8 @@ -import logging - -from ..builder import PIPELINES import cv2 import numpy as np +from ..builder import PIPELINES + @PIPELINES.register_module() class ContourDetection(object): @@ -11,7 +10,7 @@ class ContourDetection(object): # logging.info("ContourDetection run ") pass - #@ RunTime + # @ RunTime def __call__(self, result): # shoe diff if result['name'] == 'shoes': diff --git a/app/service/design/items/pipelines/loading.py b/app/service/design/items/pipelines/loading.py index a1a49a5..d792646 100644 --- a/app/service/design/items/pipelines/loading.py +++ b/app/service/design/items/pipelines/loading.py @@ -1,12 +1,5 @@ -import io -import logging - import cv2 -import numpy as np -from PIL import Image -from minio import Minio -from app.core.config import * from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES @@ -17,11 +10,7 @@ class 
LoadImageFromFile(object): self.path = path self.color = color self.print_dict = print_dict - self.minio_client = Minio( - f"{MINIO_URL}", - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + # self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) def __call__(self, result): result['image'], result['pre_mask'] = self.read_image(self.path) @@ -53,11 +42,13 @@ class LoadImageFromFile(object): f"bag, shoes, hairstyle, earring.") return keypoint - def read_image(self, image_path): + @staticmethod + def read_image(image_path): image_mask = None - file = self.minio_client.get_object(image_path.split("/", 1)[0], image_path.split("/", 1)[1]).data - image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + # file = self.minio_client.get_object(image_path.split("/", 1)[0], image_path.split("/", 1)[1]).data + # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2") if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) if image.shape[2] == 4: # 如果是四通道 mask diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 32e750c..21b567f 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -1,6 +1,5 @@ import random -# import boto3 import cv2 import numpy as np from PIL import Image @@ -9,12 +8,6 @@ from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES -# minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - - -# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) - - @PIPELINES.register_module() class Painting(object): def __init__(self, painting_flag=True): diff --git 
a/app/service/design/items/pipelines/scale.py b/app/service/design/items/pipelines/scale.py index 80009e1..d101530 100644 --- a/app/service/design/items/pipelines/scale.py +++ b/app/service/design/items/pipelines/scale.py @@ -1,7 +1,9 @@ -from ..builder import PIPELINES import math + import cv2 +from ..builder import PIPELINES + @PIPELINES.register_module() class Scaling(object): diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index 897e49a..155347a 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -1,11 +1,12 @@ import logging + import cv2 import numpy as np +from PIL import Image from cv2 import cvtColor, COLOR_BGR2RGBA from app.service.utils.generate_uuid import generate_uuid from ..builder import PIPELINES -from PIL import Image from ...utils.conversion_image import rgb_to_rgba from ...utils.upload_image import upload_png_mask diff --git a/app/service/design/items/shoes.py b/app/service/design/items/shoes.py index f4e17f2..aa20d3c 100644 --- a/app/service/design/items/shoes.py +++ b/app/service/design/items/shoes.py @@ -1,14 +1,9 @@ -import io -import logging -import time - import cv2 import numpy as np +from PIL import Image from .builder import ITEMS from .clothing import Clothing -from PIL import Image - from ..utils.conversion_image import rgb_to_rgba from ..utils.upload_image import upload_png_mask from ...utils.generate_uuid import generate_uuid diff --git a/app/service/design/utils/conversion_image.py b/app/service/design/utils/conversion_image.py index 0915070..77848cc 100644 --- a/app/service/design/utils/conversion_image.py +++ b/app/service/design/utils/conversion_image.py @@ -19,5 +19,6 @@ def rgb_to_rgba(rgb_size, rgb_image, mask): rgba_image[:, :, 3] = alpha_channel return rgba_image + if __name__ == '__main__': - image = open("") \ No newline at end of file + image = open("") diff --git a/app/service/design/utils/design_ensemble.py 
b/app/service/design/utils/design_ensemble.py index a1021e9..00d391f 100644 --- a/app/service/design/utils/design_ensemble.py +++ b/app/service/design/utils/design_ensemble.py @@ -8,12 +8,14 @@ @detail :发起请求 获取推理结果 """ import logging + import cv2 import mmcv import numpy as np -import tritonclient.http as httpclient import torch import torch.nn.functional as F +import tritonclient.http as httpclient + from app.core.config import * """ diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py index e447018..dc8e427 100644 --- a/app/service/design/utils/synthesis_item.py +++ b/app/service/design/utils/synthesis_item.py @@ -10,30 +10,13 @@ import io import logging -# import boto3 import cv2 import numpy as np from PIL import Image -from minio import Minio -from app.core.config import * from app.service.utils.generate_uuid import generate_uuid from app.service.utils.oss_client import oss_upload_image -minio_client = Minio( - MINIO_URL, - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) - - -# s3 = boto3.client( -# 's3', -# aws_access_key_id=S3_ACCESS_KEY, -# aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, -# region_name=S3_REGION_NAME -# ) - def positioning(all_mask_shape, mask_shape, offset): all_start = 0 @@ -176,4 +159,9 @@ def synthesis_single(front_image, back_image): result_image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # oss upload + bucket_name = 'aida-results' + object_name = f'result_{generate_uuid()}.png' + req = oss_upload_image(bucket=bucket_name, object_name=object_name, 
image_bytes=image_bytes) + return f"{bucket_name}/{object_name}" diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index 11e7911..653d49d 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -63,7 +63,7 @@ if __name__ == '__main__': # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" # url = "aida-users/89/product_image/string-89.png" - url = "aida-users/10/models/female/9c788f5b-b8c7-424c-b149-025aeb0bda51model.png" + url = "aida-clothing/mask/mask_773e270b-3369-11ef-abe4-b0dcefbff887.png" read_type = "PIL" if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) From 848e8bc5cbb0e20306e7ab4e01e64697faf66914 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 26 Jun 2024 11:15:56 +0800 Subject: [PATCH 077/108] =?UTF-8?q?feat=20=E6=96=B0=E5=A2=9E=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E6=8F=8F=E8=BF=B0=20docs=E9=A1=B5=E9=9D=A2=20?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9ES3=20=E5=9B=BE=E7=89=87get=20upload?= =?UTF-8?q?=20=E6=93=8D=E4=BD=9C=EF=BC=8C=E6=95=B4=E7=90=86=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/api/api_design.py | 92 ++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 53 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index c056c39..ec618cc 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -24,31 +24,31 @@ def design(request_data: DesignModel): "basic": { "body_point_test": { "waistband_right": [ - 336, - 264 + 203, + 249 ], "hand_point_right": [ - 350, - 303 + 229, + 343 ], "waistband_left": [ - 245, - 274 + 119, + 248 ], "hand_point_left": [ - 219, - 315 + 97, + 343 ], "shoulder_left": [ - 227, - 155 + 108, + 107 ], "shoulder_right": [ - 338, - 149 + 212, + 107 ] }, - "layer_order": false, + 
"layer_order": true, "scale_bag": 0.7, "scale_earrings": 0.16, "self_template": true, @@ -57,62 +57,48 @@ def design(request_data: DesignModel): }, "items": [ { - "color": "229 214 200", - "icon": "none", - "image_id": 110205, + "businessId": 255303, + "color": "139 148 156", + "image_id": 95159, "offset": [ - 1, - 1 + 0, + 0 ], - "path": "aida-sys-image/images/female/trousers/0916000217.jpg", + "path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png", "print": { "IfSingle": false, - "print_path_list": [] + "location": [ + [ + 512.0, + 512.0 + ] + ], + "print_angle_list": [ + 0.0 + ], + "print_path_list": [ + "aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png" + ], + "print_scale_list": [ + 1.0 + ] }, + "priority": 10, "resize_scale": [ 1.0, 1.0 ], - "type": "Trousers" + "type": "Dress" }, { - "businessId": 493825, - "color": "209 125 29", - "elementId": 493825, - "icon": "none", - "image_id": 107101, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Blouse" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, "type": "Body" } ] } ], - "process_id": "6878547032381675" + "process_id": "89" } """ try: From 45a3597ffde15d34940ba162eded1a340a5b4b27 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 26 Jun 2024 16:42:10 +0800 Subject: [PATCH 078/108] =?UTF-8?q?feat=20=E6=96=B0=E5=A2=9E=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E6=8F=8F=E8=BF=B0=20docs=E9=A1=B5=E9=9D=A2=20?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9ES3=20=E5=9B=BE=E7=89=87get=20upload?= 
=?UTF-8?q?=20=E6=93=8D=E4=BD=9C=EF=BC=8C=E6=95=B4=E7=90=86=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/api/api_generate_image.py | 2 +- app/service/generate_image/service_generate_relight_image.py | 2 +- app/service/utils/oss_client.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 2da5554..3f3646f 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -154,7 +154,7 @@ def generate_relight_image(request_item: GenerateProductImageModel, background_t { "tasks_id": "123-89", "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - "image_url": "aida-sys-image/images/female/blouse/0628000098.jpg" + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png" } """ try: diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 8793c42..ca32c73 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -137,7 +137,7 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", prompt="Colorful black", - image_url='aida-users/89/product_image/123-89.png' + image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png' ) server = GenerateRelightImage(rd) print(server.get_result()) diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index 653d49d..6b8a8bd 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -63,7 +63,7 @@ if __name__ == '__main__': # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" # url = "aida-users/89/product_image/string-89.png" - url = 
"aida-clothing/mask/mask_773e270b-3369-11ef-abe4-b0dcefbff887.png" + url = 'aida-users/89/relight_image/123-89.png' read_type = "PIL" if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) From acb4678251fcb3911423940eced94cb937e2559d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 26 Jun 2024 17:52:53 +0800 Subject: [PATCH 079/108] =?UTF-8?q?feat=20print=20overall=20=E6=97=8B?= =?UTF-8?q?=E8=BD=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- .../design/items/pipelines/painting.py | 115 ++++++++++++++---- app/service/utils/oss_client.py | 2 +- 2 files changed, 91 insertions(+), 26 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 21b567f..49bbf01 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -204,7 +204,14 @@ class PrintPainting(object): result['print_image'] = result['pattern_image'] # print else: - painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) + if result['print']['print_angle_list'][0] != 0: + painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) + painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=result['print']['print_angle_list'][0], crop=True) + # resize 到sketch大小 + painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + else: + painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) result['print_image'] = self.printpaint(result, painting_dict, 
print_=True) result['final_image'] = result['print_image'] canvas = np.full_like(result['final_image'], 255) @@ -351,8 +358,13 @@ class PrintPainting(object): dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) if not print_['IfSingle']: self.random_seed = random.randint(0, 1000) - painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) - painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) + # 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪 + if result['print']['print_angle_list'][0] != 0: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) + else: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) else: painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location']) painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location']) @@ -533,6 +545,52 @@ class PrintPainting(object): return rotated_img, 
((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2) # return rotated_img, (0, 0) + @staticmethod + def rotate_crop_image(img, angle, crop): + """ + angle: 旋转的角度 + crop: 是否需要进行裁剪,布尔向量 + """ + crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w] + w, h = img.shape[:2] + # 旋转角度的周期是360° + angle %= 360 + # 计算仿射变换矩阵 + M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + # 得到旋转后的图像 + img_rotated = cv2.warpAffine(img, M_rotation, (w, h)) + + # 如果需要去除黑边 + if crop: + # 裁剪角度的等效周期是180° + angle_crop = angle % 180 + if angle > 90: + angle_crop = 180 - angle_crop + # 转化角度为弧度 + theta = angle_crop * np.pi / 180 + # 计算高宽比 + hw_ratio = float(h) / float(w) + # 计算裁剪边长系数的分子项 + tan_theta = np.tan(theta) + numerator = np.cos(theta) + np.sin(theta) * np.tan(theta) + + # 计算分母中和高宽比相关的项 + r = hw_ratio if h > w else 1 / hw_ratio + # 计算分母项 + denominator = r * tan_theta + 1 + # 最终的边长系数 + crop_mult = numerator / denominator + + # 得到裁剪区域 + w_crop = int(crop_mult * w) + h_crop = int(crop_mult * h) + x0 = int((w - w_crop) / 2) + y0 = int((h - h_crop) / 2) + + img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop) + + return img_rotated + @staticmethod def read_image(image_url): image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2") @@ -544,26 +602,33 @@ class PrintPainting(object): image_mode = "RGB" return image, image_mode - # data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) - # # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] - # - # data_bytes = BytesIO(data.read()) - # image = Image.open(data_bytes) - # image_mode = image.mode - # # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - # if image_mode == "RGBA": - # # new_background = Image.new('RGB', image.size, (255, 255, 255)) - # # new_background.paste(image, mask=image.split()[3]) - # # image = new_background - # return image, 
image_mode - # image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) - # return image, "RGB" + @staticmethod + def resize_and_crop(img, target_width, target_height): + # 获取原始图像的尺寸 + original_height, original_width = img.shape[:2] - # @staticmethod - # def read_image(image_url): - # response = requests.get(image_url) - # image_data = np.frombuffer(response.content, np.uint8) - # - # # 解码图像 - # image = cv2.imdecode(image_data, 3) - # return image + # 计算目标尺寸的宽高比 + target_ratio = target_width / target_height + + # 计算原始图像的宽高比 + original_ratio = original_width / original_height + + # 调整尺寸 + if original_ratio > target_ratio: + # 原始图像更宽,按高度resize,然后裁剪宽度 + new_height = target_height + new_width = int(original_width * (target_height / original_height)) + resized_img = cv2.resize(img, (new_width, new_height)) + # 裁剪宽度 + start_x = (new_width - target_width) // 2 + cropped_img = resized_img[:, start_x:start_x + target_width] + else: + # 原始图像更高,按宽度resize,然后裁剪高度 + new_width = target_width + new_height = int(original_height * (target_width / original_width)) + resized_img = cv2.resize(img, (new_width, new_height)) + # 裁剪高度 + start_y = (new_height - target_height) // 2 + cropped_img = resized_img[start_y:start_y + target_height, :] + + return cropped_img diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index 6b8a8bd..c2bb82c 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -63,7 +63,7 @@ if __name__ == '__main__': # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" # url = "aida-users/89/product_image/string-89.png" - url = 'aida-users/89/relight_image/123-89.png' + url = "aida-results/result_c6520ce7-33a1-11ef-a8d3-b0dcefbff887.png" read_type = "PIL" if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) From 8be6a64fbc6f1993d6cf8e6deb95fe0046ad360b Mon Sep 17 00:00:00 
2001 From: zhouchengrong Date: Thu, 27 Jun 2024 10:36:14 +0800 Subject: [PATCH 080/108] =?UTF-8?q?feat=20oss=E6=9B=BF=E6=8D=A2=E4=B8=BAS3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index cfc04f5..9e89e47 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -OSS = "minio" +OSS = "S3" DEBUG = False if DEBUG: LOGS_PATH = "logs/" From d30802e266e74011c59ce2d781ea6bc391e366cb Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 27 Jun 2024 10:43:42 +0800 Subject: [PATCH 081/108] =?UTF-8?q?feat=20oss=E6=9B=BF=E6=8D=A2=E4=B8=BAS3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/service/design/items/pipelines/painting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 49bbf01..2b8d50f 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -207,6 +207,8 @@ class PrintPainting(object): if result['print']['print_angle_list'][0] != 0: painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=result['print']['print_angle_list'][0], crop=True) + painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=result['print']['print_angle_list'][0], crop=True) + # resize 到sketch大小 painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) painting_dict['mask_inv_print'] = 
self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) From 7b751096467b877935d8c5fb114fba08279b4bed Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 27 Jun 2024 10:46:27 +0800 Subject: [PATCH 082/108] =?UTF-8?q?feat=20oss=E6=9B=BF=E6=8D=A2=E4=B8=BAS3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/service/design/items/pipelines/painting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 2b8d50f..cdccdc9 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -206,8 +206,8 @@ class PrintPainting(object): else: if result['print']['print_angle_list'][0] != 0: painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) - painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=result['print']['print_angle_list'][0], crop=True) - painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=result['print']['print_angle_list'][0], crop=True) + painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True) + painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True) # resize 到sketch大小 painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) From d09d8b740d6a58b43701a2d2e7aaf090508e7aee Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 27 Jun 2024 10:55:27 +0800 Subject: [PATCH 083/108] =?UTF-8?q?feat=20oss=E6=9B=BF=E6=8D=A2=E4=B8=BAS3?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/service/design/items/pipelines/painting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index cdccdc9..1bbcb51 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -204,7 +204,7 @@ class PrintPainting(object): result['print_image'] = result['pattern_image'] # print else: - if result['print']['print_angle_list'][0] != 0: + if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0: painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True) painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True) @@ -361,7 +361,7 @@ class PrintPainting(object): if not print_['IfSingle']: self.random_seed = random.randint(0, 1000) # 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪 - if result['print']['print_angle_list'][0] != 0: + if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0: painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) else: From deea2bc5e93186df3b3c007ef4ae05b0b167c3e2 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 27 Jun 2024 11:03:35 +0800 Subject: [PATCH 084/108] =?UTF-8?q?feat=20oss=E6=9B=BF=E6=8D=A2=E4=B8=BAS3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 
8bit fix --- app/service/design/model_process_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design/model_process_service.py b/app/service/design/model_process_service.py index fffbd67..076e04d 100644 --- a/app/service/design/model_process_service.py +++ b/app/service/design/model_process_service.py @@ -13,7 +13,7 @@ def model_transpose(image_path): # new_data = [] for item in data: - if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + if item[0] >= 256 and item[1] >= 256 and item[2] >= 256: new_data.append((255, 255, 255, 0)) else: new_data.append(item) From cac569f7766e14aa483ed14911c178cb28ef880d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 27 Jun 2024 11:13:34 +0800 Subject: [PATCH 085/108] =?UTF-8?q?feat=20oss=E6=9B=BF=E6=8D=A2=E4=B8=BAS3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/service/design/model_process_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design/model_process_service.py b/app/service/design/model_process_service.py index fffbd67..076e04d 100644 --- a/app/service/design/model_process_service.py +++ b/app/service/design/model_process_service.py @@ -13,7 +13,7 @@ def model_transpose(image_path): # new_data = [] for item in data: - if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + if item[0] >= 256 and item[1] >= 256 and item[2] >= 256: new_data.append((255, 255, 255, 0)) else: new_data.append(item) From ae2be271732dd0cd0daab7d9c7b44a67bb36cc82 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 27 Jun 2024 15:04:38 +0800 Subject: [PATCH 086/108] =?UTF-8?q?feat=20oss=E6=9B=BF=E6=8D=A2=E4=B8=BAS3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/service/design/items/pipelines/painting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/service/design/items/pipelines/painting.py 
b/app/service/design/items/pipelines/painting.py index 1bbcb51..a738455 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -58,6 +58,8 @@ class Painting(object): # 使用OpenCV解码图像数组 # image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") + if image.shape[2] == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR) return image @staticmethod From 1716dff372640f79e08ab20a1544503c5d815b0c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 27 Jun 2024 15:05:36 +0800 Subject: [PATCH 087/108] =?UTF-8?q?feat=20=E4=BF=AE=E5=A4=8D=E6=B8=90?= =?UTF-8?q?=E5=8F=98=E8=89=B2=E9=80=9A=E9=81=93=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 9e89e47..cfc04f5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -OSS = "S3" +OSS = "minio" DEBUG = False if DEBUG: LOGS_PATH = "logs/" From 651a526254343a95294c704b12729b460efdb6b0 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 11:11:28 +0800 Subject: [PATCH 088/108] =?UTF-8?q?feat=20fix=20redis=20host=20=E6=9B=B4?= =?UTF-8?q?=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 9e89e47..ae91c23 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -48,7 +48,7 @@ S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8" S3_REGION_NAME = "ap-east-1" # redis 配置 -REDIS_HOST = "10.1.1.150" +REDIS_HOST = "10.1.1.240" REDIS_PORT = "6379" REDIS_DB = "2" From 
a18d423d062ba5f7943ed834436f96ffc88140fc Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 11:11:55 +0800 Subject: [PATCH 089/108] =?UTF-8?q?feat=20fix=20redis=20host=20=E6=9B=B4?= =?UTF-8?q?=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index cfc04f5..09e48fc 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -48,7 +48,7 @@ S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8" S3_REGION_NAME = "ap-east-1" # redis 配置 -REDIS_HOST = "10.1.1.150" +REDIS_HOST = "10.1.1.240" REDIS_PORT = "6379" REDIS_DB = "2" From d772adcd7a13da83b7de8f04d7197f612736fb5f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 13:59:53 +0800 Subject: [PATCH 090/108] =?UTF-8?q?feat=20fix=20redis=20host=20=E6=9B=B4?= =?UTF-8?q?=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatgpt_for_translation.py | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index 71d6e4f..4ade635 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,9 +1,8 @@ -import os +import logging from langchain.chains import LLMChain from langchain.chat_models import ChatOpenAI -from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, \ - PromptTemplate +from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from app.core.config import OPENAI_MODEL, OPENAI_API_KEY @@ -21,7 +20,8 @@ def translate_to_en(text): """You are a translation expert, proficient in various languages. 
And can translate various languages into English. Please translate to grammatically correct English regardless of the input language. - If the input is in English or numbers, check for grammatical errors. If there are no errors, output the input directly. + If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", + output the input text exactly as it is without any modifications or additions. If there are grammatical errors, correct them and then output the sentence.""" ) system_message_prompt = SystemMessagePromptTemplate.from_template(template) @@ -36,34 +36,41 @@ def translate_to_en(text): ) translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template) - template = ( - """ - Input sentence: - {translate} - 1. Based on the input,adjust the input sentence to make it more suitable for prompts for generating images, - ensuring all key nouns or adjectives related to the image are retained. - 2. Simplify complex sentence structures and clarify ambiguous expressions. - 3. Only Output the adjusted English sentence. + result = translate_chain.invoke(text) - Output : - """ - ) - # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words." - # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding. + logging.info("translate result : " + result.get('text')) + # print("translate result : " + result.get('text')) + return result.get('text') - prompt_template = PromptTemplate(input_variables=["translate"], template=template) - prompt_chain = LLMChain(llm=llm, prompt=prompt_template) - - from langchain.chains import SimpleSequentialChain - overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True) - - response = overall_chain.run(text) - return response + # template = ( + # """ + # Input sentence: + # {translate} + # 1. 
Based on the input,adjust the input sentence to make it more suitable for prompts for generating images, + # ensuring all key nouns or adjectives related to the image are retained. + # 2. Simplify complex sentence structures and clarify ambiguous expressions. + # 3. Only Output the adjusted English sentence. + # + # Output : + # """ + # ) + # # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words." + # # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding. + # + # prompt_template = PromptTemplate(input_variables=["translate"], template=template) + # prompt_chain = LLMChain(llm=llm, prompt=prompt_template) + # + # from langchain.chains import SimpleSequentialChain + # overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True) + # + # response = overall_chain.run(text) + # return response def main(): """Main function""" - translate_to_en("生成一件运动风格的夹克,带有拉链和口袋,适合休闲穿着") + text = translate_to_en("fire") + print(text) if __name__ == "__main__": From f44e929b332fe07398269dc7ebd5fdbd0d0fe3e0 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 14:03:09 +0800 Subject: [PATCH 091/108] =?UTF-8?q?feat=20fix=20redis=20host=20=E6=9B=B4?= =?UTF-8?q?=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 09e48fc..ae91c23 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -OSS = "minio" +OSS = "S3" DEBUG = False if DEBUG: LOGS_PATH = "logs/" From 1b70786784ae5bdd6a8a8cf9f978afd7ec5cd56e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 15:47:47 +0800 Subject: [PATCH 092/108] 
=?UTF-8?q?feat=20fix=20generate=20image=20sketch?= =?UTF-8?q?=20=E5=92=8Cprompt=20=E7=BB=84=E5=90=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_image.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 1bc1c91..dac211c 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -40,16 +40,17 @@ class GenerateImage: if request_data.mode == "img2img": # cv2 读图片是BGR PIL读图片是RGB self.image = self.get_image(request_data.image_url) - self.prompt = request_data.prompt else: self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - self.prompt = request_data.prompt + self.prompt = request_data.prompt self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.mode = request_data.mode self.batch_size = 1 self.category = request_data.category + if self.category == "sketch": + self.prompt = f"{self.category},{self.prompt}" self.index = 0 self.gender = request_data.gender self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''} From c0420cd6afb6008a4635e5d3431ad8431f1452b5 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 15:48:10 +0800 Subject: [PATCH 093/108] =?UTF-8?q?feat=20fix=20generate=20image=20sketch?= =?UTF-8?q?=20=E5=92=8Cprompt=20=E7=BB=84=E5=90=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index ae91c23..09e48fc 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = 
os.path.join(BASE_DIR, 'logging_env.py') -OSS = "S3" +OSS = "minio" DEBUG = False if DEBUG: LOGS_PATH = "logs/" From 638447f31305a4d660d4727c71ad0752522e53d4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 16:58:52 +0800 Subject: [PATCH 094/108] =?UTF-8?q?feat=20generate=20to=20product=20image?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=20image=5Fstrength=E5=8F=82=E6=95=B0=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 1 + .../service_generate_product_image.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 4f85002..29f34d6 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -20,6 +20,7 @@ class GenerateProductImageModel(BaseModel): tasks_id: str prompt: str image_url: str + image_strength: float class GenerateRelightImageModel(BaseModel): diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index dcdf09f..6ee1bc6 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -7,17 +7,17 @@ @Date :2023/7/26 12:01:05 @detail : """ -import io import json import logging import time + import cv2 +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np from PIL import Image, ImageOps -from minio import Minio from tritonclient.utils import np_to_triton_dtype + from app.core.config import * from app.schemas.generate_image import GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image @@ -37,6 +37,7 @@ class GenerateProductImage: self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, 
decode_responses=True) self.category = "product_image" + self.image_strength = request_data.image_strength self.batch_size = 1 self.prompt = request_data.prompt self.image, self.image_size = pre_processing_image(request_data.image_url) @@ -74,13 +75,16 @@ class GenerateProductImage: text_obj = np.array(prompts, dtype="object").reshape(1) image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) - inputs = [input_text, input_image] + inputs = [input_text, input_image, input_image_strength] + input_image_strength.set_data_from_numpy(image_strength_obj) ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 @@ -144,6 +148,7 @@ if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", prompt="", + image_strength=0.9, # prompt=" the best quality, masterpiece. 
detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", ) From b0c5f21957d72d6c29eb0dca8dc93cdca798ba8a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 2 Jul 2024 10:06:50 +0800 Subject: [PATCH 095/108] =?UTF-8?q?feat=20generate=20to=20product=20image?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=20image=5Fstrength=E5=8F=82=E6=95=B0=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 3f3646f..41cf989 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -109,12 +109,14 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 + - **image_strength**: 生成强度,越低越接近原图 示例参数: { "tasks_id": "123-89", "prompt": "the best quality, masterpiece. 
detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", - "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png" + "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + "image_strength": 0.8 } """ try: From 4888935ef7d6fe41d5b6e3f694cb59b2c8f9e61f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 2 Jul 2024 10:07:26 +0800 Subject: [PATCH 096/108] =?UTF-8?q?feat=20generate=20to=20product=20image?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=20image=5Fstrength=E5=8F=82=E6=95=B0=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 09e48fc..ae91c23 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -OSS = "minio" +OSS = "S3" DEBUG = False if DEBUG: LOGS_PATH = "logs/" From 35f5c6d4e9eb7d2c1ad5fb4f5e3c520d1f0e8be4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 2 Jul 2024 11:45:40 +0800 Subject: [PATCH 097/108] =?UTF-8?q?feat=20generate=20to=20product=20image?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=20image=5Fstrength=E5=8F=82=E6=95=B0=20fi?= =?UTF-8?q?x=20print=20=E9=80=8F=E6=98=8E=E5=9B=BE=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 4 ++-- app/service/design/items/pipelines/painting.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 41cf989..cce1300 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,7 +3,7 @@ import logging from fastapi 
import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel @@ -145,7 +145,7 @@ def generate_product_image(tasks_id: str): @router.post("/generate_relight_image") -def generate_relight_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks): +def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks): """ 创建一个具有以下参数的请求体: - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index a738455..424a395 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -118,6 +118,7 @@ class PrintPainting(object): print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY) else: mask = self.get_mask_inv(image) mask = np.expand_dims(mask, axis=2) From 589b3d0b7408050b424664c1f3f275b304ad69f3 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 2 Jul 2024 11:46:04 +0800 Subject: [PATCH 098/108] =?UTF-8?q?feat=20generate=20to=20product=20image?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=20image=5Fstrength=E5=8F=82=E6=95=B0=20fi?= =?UTF-8?q?x=20print=20=E9=80=8F=E6=98=8E=E5=9B=BE=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index ae91c23..09e48fc 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -OSS = "S3" +OSS = "minio" DEBUG = False if DEBUG: LOGS_PATH = "logs/" From 72428a73ab871e5fd4ab09002e86e546655ff1c1 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 3 Jul 2024 09:58:04 +0800 Subject: [PATCH 099/108] =?UTF-8?q?feat=20fix=20oss=20=E5=9B=9E=E8=B0=83?= =?UTF-8?q?=E8=87=B3S3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index ae91c23..09e48fc 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -OSS = "S3" +OSS = "minio" DEBUG = False if DEBUG: LOGS_PATH = "logs/" From 48eaa9c2e3d1180f10112bdb2855484814baf487 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 3 Jul 2024 10:26:40 +0800 Subject: [PATCH 100/108] =?UTF-8?q?feat=20relight=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=96=B9=E5=90=91=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 4 +++- app/schemas/generate_image.py | 1 + .../generate_image/service_generate_relight_image.py | 9 ++++----- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index cce1300..82ad571 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -151,12 +151,14 @@ def generate_relight_image(request_item: GenerateRelightImageModel, 
background_t - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 + - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light 示例参数: { "tasks_id": "123-89", "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png" + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", + "direction": "Right Light" } """ try: diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 29f34d6..2a16442 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -27,3 +27,4 @@ class GenerateRelightImageModel(BaseModel): tasks_id: str prompt: str image_url: str + direction: str diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index ca32c73..6f51435 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -7,16 +7,15 @@ @Date :2023/7/26 12:01:05 @detail : """ -import io import json import logging import time + import cv2 +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np -from PIL import Image, ImageOps -from minio import Minio +from PIL import Image from tritonclient.utils import np_to_triton_dtype from app.core.config import * @@ -40,7 +39,7 @@ class GenerateRelightImage: self.prompt = request_data.prompt self.seed = "1" self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - self.direction = "Right Light" + self.direction = request_data.direction self.image_url = request_data.image_url self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.tasks_id = request_data.tasks_id From 24142a01cc0f185099ae4f6f4271de8e7e90c5b5 Mon Sep 17 00:00:00 
2001 From: zhouchengrong Date: Thu, 4 Jul 2024 10:15:42 +0800 Subject: [PATCH 101/108] =?UTF-8?q?feat=20=E4=BF=AE=E6=94=B9design?= =?UTF-8?q?=E7=9A=84print=E9=80=BB=E8=BE=91=20=E4=BD=BF=20overall=20?= =?UTF-8?q?=E5=92=8C=20single=20=E5=90=8C=E6=97=B6=E5=AD=98=E5=9C=A8=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design/items/pipelines/painting.py | 251 ++++++++---------- 1 file changed, 111 insertions(+), 140 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 424a395..5936ccc 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -88,99 +88,112 @@ class PrintPainting(object): # @ RunTime def __call__(self, result): + single_print = result['print']['single'] + overall_print = result['print']['overall'] + element_print = result['print']['element'] - if "location" not in result['print'].keys(): - result['print']["location"] = [[0, 0]] - elif result['print']["location"] == [] or result['print']["location"] is None: - result['print']["location"] = [[0, 0]] - if result['print']['IfSingle']: - if len(result['print']['print_path_list']) > 0: - print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) - mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) - # print_background = np.full((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), 255, dtype=np.uint8) - for i in range(len(result['print']['print_path_list'])): - image, image_mode = self.read_image(result['print']['print_path_list'][i]) - if image_mode == "RGBA": - new_size = (int(image.width * result['print']['print_scale_list'][i]), int(image.height * result['print']['print_scale_list'][i])) + if overall_print['print_path_list']: + painting_dict = {'dim_image_h': 
result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} + result['print_image'] = result['pattern_image'] + if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True) + painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True) + painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True) - mask = image.split()[3] - resized_source = image.resize(new_size) - resized_source_mask = mask.resize(new_size) + # resize 到sketch大小 + painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + else: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False) + result['print_image'] = self.printpaint(result, painting_dict, print_=True) + result['single_image'] = result['pattern_image'] = result['print_image'] - rotated_resized_source = resized_source.rotate(-result['print']['print_angle_list'][i]) - rotated_resized_source_mask = resized_source_mask.rotate(-result['print']['print_angle_list'][i]) + if single_print['print_path_list']: + print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) + for i in range(len(single_print['print_path_list'])): + image, image_mode = self.read_image(single_print['print_path_list'][i]) + if image_mode == 
"RGBA": + new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i])) - source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) - source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) - source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source) - source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask) + rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i]) - print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) - mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) - ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY) + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + + source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask) + + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY) + else: + mask = self.get_mask_inv(image) + mask = 
np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) + x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1]) + + image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] + + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, :x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] + + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 + + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 else: - mask = self.get_mask_inv(image) - mask = np.expand_dims(mask, axis=2) - mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) - mask = cv2.bitwise_not(mask) - # 旋转后的坐标需要重新算 - rotate_mask, _ = self.img_rotate(mask, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) - rotate_image, rotated_new_size = self.img_rotate(image, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i]) - # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), 
int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) - x, y = int(result['print']['location'][i][0] - rotated_new_size[0]), int(result['print']['location'][i][1] - rotated_new_size[1]) + start_x = x - image_x = print_background.shape[1] - image_y = print_background.shape[0] - print_x = rotate_image.shape[1] - print_y = rotate_image.shape[0] + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y - # 有bug - # if x + print_x > image_x: - # rotate_image = rotate_image[:, :x + print_x - image_x] - # rotate_mask = rotate_mask[:, :x + print_x - image_x] - # # - # if y + print_y > image_y: - # rotate_image = rotate_image[:y + print_y - image_y] - # rotate_mask = rotate_mask[:y + print_y - image_y] + # ------------------ + # 如果print-size大于image-size 则需要裁剪print - # 不能是并行 - # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 - # 先挪 再判断 最后裁剪 + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + rotate_mask = rotate_mask[:, :image_x - x] - # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 - if x <= 0: - rotate_image = rotate_image[:, -x:] - rotate_mask = rotate_mask[:, -x:] - start_x = x = 0 - else: - start_x = x + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] - if y <= 0: - rotate_image = rotate_image[-y:, :] - rotate_mask = rotate_mask[-y:, :] - start_y = y = 0 - else: - start_y = y + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) - # ------------------ - # 如果print-size大于image-size 
则需要裁剪print - - if x + print_x > image_x: - rotate_image = rotate_image[:, :image_x - x] - rotate_mask = rotate_mask[:, :image_x - x] - - if y + print_y > image_y: - rotate_image = rotate_image[:image_y - y, :] - rotate_mask = rotate_mask[:image_y - y, :] - - # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) - # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) - - # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask - # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image - mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) - print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) @@ -198,54 +211,27 @@ class PrintPainting(object): temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) result['single_image'] = cv2.add(tmp1, tmp2) - else: - painting_dict = {} - painting_dict['dim_image_h'], 
painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2] - # no print - if len(result['print_dict']['print_path_list']) == 0 or not self.print_flag: - result['print_image'] = result['pattern_image'] - # print - else: - if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0: - painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) - painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True) - painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True) - - # resize 到sketch大小 - painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) - painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) - else: - painting_dict = self.painting_collection(painting_dict, result, print_trigger=True) - result['print_image'] = self.printpaint(result, painting_dict, print_=True) - result['final_image'] = result['print_image'] - canvas = np.full_like(result['final_image'], 255) - temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) - tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) - temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) - tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) - result['single_image'] = cv2.add(tmp1, tmp2) - - if "element" in result.keys(): + if element_print['element_path_list']: print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) - for i in 
range(len(result['element']['element_path_list'])): - image, image_mode = self.read_image(result['element']['element_path_list'][i]) + for i in range(len(element_print['element_path_list'])): + image, image_mode = self.read_image(element_print['element_path_list'][i]) if image_mode == "RGBA": - new_size = (int(image.width * result['element']['element_scale_list'][i]), int(image.height * result['element']['element_scale_list'][i])) + new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i])) mask = image.split()[3] resized_source = image.resize(new_size) resized_source_mask = mask.resize(new_size) - rotated_resized_source = resized_source.rotate(-result['element']['element_angle_list'][i]) - rotated_resized_source_mask = resized_source_mask.rotate(-result['element']['element_angle_list'][i]) + rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i]) source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) - source_image_pil.paste(rotated_resized_source, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source) - source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source_mask) + source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask) print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) mask_background = 
cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) @@ -256,10 +242,10 @@ class PrintPainting(object): mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) mask = cv2.bitwise_not(mask) # 旋转后的坐标需要重新算 - rotate_mask, _ = self.img_rotate(mask, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) - rotate_image, rotated_new_size = self.img_rotate(image, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) + rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) - x, y = int(result['element']['location'][i][0] - rotated_new_size[0]), int(result['element']['location'][i][1] - rotated_new_size[1]) + x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1]) image_x = print_background.shape[1] image_y = print_background.shape[0] @@ -353,18 +339,18 @@ class PrintPainting(object): return print_background - def painting_collection(self, painting_dict, result, print_trigger=False): + def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False): if print_trigger: - print_ = self.get_print(result['print_dict']) - painting_dict['Trigger'] = not print_['IfSingle'] - painting_dict['location'] = print_['location'] if 'location' in print_.keys() else None + print_ = self.get_print(print_dict) + painting_dict['Trigger'] = not is_single + painting_dict['location'] = print_['location'] single_mask_inv_print = self.get_mask_inv(print_['image']) dim_max = max(painting_dict['dim_image_h'], 
painting_dict['dim_image_w']) dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) - if not print_['IfSingle']: + if not is_single: self.random_seed = random.randint(0, 1000) # 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪 - if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0: + if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0: painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) else: @@ -459,19 +445,11 @@ class PrintPainting(object): @staticmethod def get_print(print_dict): - if not 'print_scale_list' in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3: + if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3: print_dict['scale'] = 0.3 else: print_dict['scale'] = print_dict['print_scale_list'][0] - if not 'IfSingle' in print_dict.keys(): - print_dict['IfSingle'] = False - - # data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) - # data_bytes = BytesIO(data.read()) - # image = Image.open(data_bytes) - # image_mode = image.mode - bucket_name = print_dict['print_path_list'][0].split("/", 1)[0] object_name = print_dict['print_path_list'][0].split("/", 1)[1] image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL") @@ -481,13 +459,6 @@ class PrintPainting(object): new_background.paste(image, mask=image.split()[3]) image = new_background print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) - - # file = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 
1)[1]).data - # print_dict['image'] = cv2.imdecode(np.fromstring(file, np.uint8), 1) - - # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) - # return image - return print_dict def crop_image(self, image, image_size_h, image_size_w, location, print_shape): From eede15950792d0ce46cc4c65329a531d423af679 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 4 Jul 2024 14:14:57 +0800 Subject: [PATCH 102/108] =?UTF-8?q?feat=20product=20image=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9Eproduct=20type=20=E5=8F=82=E6=95=B0=20=EF=BC=8C?= =?UTF-8?q?=E8=A7=A3=E5=86=B3single=20item=20=E6=97=A0=E6=B3=95=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=E5=A4=B4=E9=83=A8=E7=9A=84=E9=97=AE=E9=A2=98=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 10 ++- app/api/api_chat_robot.py | 4 +- app/api/api_design.py | 77 +++++++++++++----- app/api/api_design_pre_processing.py | 4 +- app/api/api_generate_image.py | 21 ++--- app/api/api_prompt_generation.py | 2 +- app/api/api_super_resolution.py | 4 +- app/api/api_test.py | 10 ++- app/core/config.py | 6 +- app/schemas/generate_image.py | 1 + app/service/chat_robot/script/main.py | 21 ++--- .../service_generate_product_image.py | 26 +++++-- logging_env.py | 78 +++++++++---------- 13 files changed, 163 insertions(+), 101 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index d9b210c..5c15efe 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -34,13 +34,14 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): ] """ try: - logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") + for item in request_item: + logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") if DEBUG: service = AttributeRecognition(const=local_debug_const, request_data=request_item) else: service = AttributeRecognition(const=const, 
request_data=request_item) data = service.get_result() - logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -65,10 +66,11 @@ def category_recognition(request_item: list[CategoryRecognitionModel]): ] """ try: - logger.info(f"category_recognition request item is : @@@@@@:{request_item}") + for item in request_item: + logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}") service = CategoryRecognition(request_data=request_item) data = service.get_result() - logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"category_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py index 6f3da16..c8bcf32 100644 --- a/app/api/api_chat_robot.py +++ b/app/api/api_chat_robot.py @@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel): } """ try: - logger.info(f"chat_robot request item is : @@@@@@:{request_data}") + logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}") data = chat(post_data=request_data) - logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"chat_robot Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_design.py b/app/api/api_design.py index ec618cc..5ce6096 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -66,22 +66,57 @@ def design(request_data: DesignModel): ], "path": 
"aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png", "print": { - "IfSingle": false, - "location": [ - [ - 512.0, - 512.0 + "single": { + "location": [ + [ + 200.0, + 200.0 + ] + ], + "print_angle_list": [ + 0.0 + ], + "print_path_list": [ + "aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png" + ], + "print_scale_list": [ + 1.0 ] - ], - "print_angle_list": [ - 0.0 - ], - "print_path_list": [ - "aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png" - ], - "print_scale_list": [ - 1.0 - ] + }, + "overall": { + "location": [ + [ + 512.0, + 512.0 + ] + ], + "print_angle_list": [ + 0.0 + ], + "print_path_list": [ + "aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png" + ], + "print_scale_list": [ + 1.0 + ] + }, + "element": { + "element_angle_list": [ + 0.0 + ], + "element_path_list": [ + "aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png" + ], + "element_scale_list": [ + 0.2731036750637755 + ], + "location": [ + [ + 228.63694825464364, + 406.4843844199667 + ] + ] + } }, "priority": 10, "resize_scale": [ @@ -102,9 +137,9 @@ def design(request_data: DesignModel): } """ try: - logger.info(f"design request item is : @@@@@@:{request_data.dict()}") + logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}") data = generate(request_data=request_data) - logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"design response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"design Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -124,13 +159,13 @@ def get_progress(request_data: DesignProgressModel): } """ try: - logger.info(f"get_progress request item is : @@@@@@:{request_data.dict()}") + logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}") process_id = request_data.process_id r = Redis() data = r.read(key=process_id) if data is None: raise 
ValueError(f"No progress ID: {process_id}") - logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") + logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}") except Exception as e: logger.warning(f"get_progress Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -150,10 +185,10 @@ def model_process(request_data: ModelProgressModel): } """ try: - logger.info(f"model_process request item is : @@@@@@:{request_data.dict()}") + logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}") data = model_transpose(image_path=request_data.model_path) - logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"model_process response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"model_process Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py index f6946dc..eb2f3ab 100644 --- a/app/api/api_design_pre_processing.py +++ b/app/api/api_design_pre_processing.py @@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel): } """ try: - logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}") + logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data)}") server = DesignPreprocessing() data = server.pipeline(image_list=request_data.sketches) - logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"design response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"design Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 82ad571..95d8c50 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -38,7 +38,7 @@ def generate_image(request_item: 
GenerateImageModel, background_tasks: Backgroun } """ try: - logger.info(f"generate_image request item is : @@@@@@:{request_item}") + logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -52,7 +52,7 @@ def generate_image(tasks_id: str): try: logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") data = generate_image_infer_cancel(tasks_id) - logger.info(f"generate_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -78,7 +78,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_ } """ try: - logger.info(f"generate_single_logo request item is : @@@@@@:{request_item}") + logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateSingleLogoImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -92,7 +92,7 @@ def generate_single_logo_image(tasks_id: str): try: logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}") data = generate_single_logo_cancel(tasks_id) - logger.info(f"generate_single_logo_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_single_logo_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -110,17 +110,20 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 - **image_strength**: 生成强度,越低越接近原图 + - **product_type**: 输入single item 还是 overall item + 示例参数: { "tasks_id": "123-89", "prompt": "the best quality, 
masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", - "image_strength": 0.8 + "image_strength": 0.8, + "product_type": "overall" } """ try: - logger.info(f"generate_product_image request item is : @@@@@@:{request_item}") + logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateProductImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -134,7 +137,7 @@ def generate_product_image(tasks_id: str): try: logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}") data = generate_product_image_cancel(tasks_id) - logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) @@ -162,7 +165,7 @@ def generate_relight_image(request_item: GenerateRelightImageModel, background_t } """ try: - logger.info(f"generate_relight_image request item is : @@@@@@:{request_item}") + logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}") service = GenerateRelightImage(request_item) background_tasks.add_task(service.get_result) except Exception as e: @@ -176,7 +179,7 @@ def generate_relight_image(tasks_id: str): try: logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}") data = generate_relight_image_cancel(tasks_id) - logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}") 
except Exception as e: logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index c7bcbcd..c227b07 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -27,7 +27,7 @@ def prompt_generation(request_data: PromptGenerationImageModel): try: logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") data = translate_to_en(request_data.text) - logger.info(f"prompt_generation response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"prompt_generation response @@@@@@:{data}") except Exception as e: logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_super_resolution.py b/app/api/api_super_resolution.py index 82b58f4..ce853fd 100644 --- a/app/api/api_super_resolution.py +++ b/app/api/api_super_resolution.py @@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg } """ try: - logger.info(f"super_resolution request item is : @@@@@@:{request_item}") + logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}") service = SuperResolution(request_item) background_tasks.add_task(service.sr_result) except Exception as e: @@ -41,7 +41,7 @@ def super_resolution(tasks_id: str): try: logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}") data = infer_cancel(tasks_id) - logger.info(f"sr_cancel response @@@@@@:{json.dumps(data, indent=4)}") + logger.info(f"sr_cancel response @@@@@@:{data}") except Exception as e: logger.warning(f"sr_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_test.py b/app/api/api_test.py index 0ff521a..1271f95 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,8 +1,10 @@ +import json import logging -from fastapi 
import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS -from fastapi import FastAPI, HTTPException +from fastapi import APIRouter +from fastapi import HTTPException + +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS from app.schemas.response_template import ResponseModel logger = logging.getLogger() @@ -18,7 +20,7 @@ def test(id: int): "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, "local_oss_server": OSS } - logger.info(data) + logger.info(json.dumps(data)) if id == 1: raise HTTPException(status_code=404, detail="Item not found") diff --git a/app/core/config.py b/app/core/config.py index 09e48fc..4caaf13 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -118,9 +118,11 @@ GSL_MINIO_BUCKET = "aida-users" GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") -# Generate Single Logo service config +# Generate Product service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") -GPI_MODEL_NAME = 'diffusion_ensemble_all' +GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' +GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' + GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 2a16442..7e7beb5 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -21,6 +21,7 @@ class GenerateProductImageModel(BaseModel): prompt: str image_url: str image_strength: float + product_type: str class GenerateRelightImageModel(BaseModel): diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py index 2a62664..1e64ca4 100644 --- a/app/service/chat_robot/script/main.py +++ b/app/service/chat_robot/script/main.py @@ -1,22 +1,23 
@@ +import json import logging -from loguru import logger + from langchain.agents import Tool -from langchain.utilities import SerpAPIWrapper -from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder -from langchain.schema import SystemMessage, AIMessage +from langchain.callbacks import FileCallbackHandler from langchain.chat_models import ChatOpenAI from langchain.llms.openai import OpenAI -from langchain.callbacks import FileCallbackHandler +from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder +from langchain.schema import SystemMessage, AIMessage +from langchain.utilities import SerpAPIWrapper +from loguru import logger + +from app.core.config import * from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler from app.service.chat_robot.script.database import CustomDatabase +from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool) -from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool -from app.core.config import * - -import os # os.environ["http_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890" @@ -110,5 +111,5 @@ def chat(post_data): 'completion_tokens': final_outputs['completion_tokens'], 'response_type': final_outputs['response_type'] } - logging.info(api_response) + logging.info(json.dumps(api_response)) return api_response diff --git a/app/service/generate_image/service_generate_product_image.py 
b/app/service/generate_image/service_generate_product_image.py index 6ee1bc6..12964ae 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -39,6 +39,7 @@ class GenerateProductImage: self.category = "product_image" self.image_strength = request_data.image_strength self.batch_size = 1 + self.product_type = request_data.product_type self.prompt = request_data.prompt self.image, self.image_size = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id @@ -54,7 +55,10 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") + if self.product_type == "single": + image = result.as_numpy("generated_cnet_image") + else: + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" @@ -73,9 +77,16 @@ class GenerateProductImage: self.image = cv2.resize(self.image, (512, 768)) images = [self.image.astype(np.uint8)] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape(1) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) + if self.product_type == "single": + text_obj = np.array(prompts, dtype="object").reshape(-1, 1) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) + else: + text_obj = np.array(prompts, dtype="object").reshape(1) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(self.image_strength, 
dtype=np.float32).reshape((1)) + + # 假设 prompts、images 和 self.image_strength 已经定义 input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") @@ -86,7 +97,11 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] input_image_strength.set_data_from_numpy(image_strength_obj) - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback) + if self.product_type == "single": + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() @@ -151,6 +166,7 @@ if __name__ == '__main__': image_strength=0.9, # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + product_type="single" ) server = GenerateProductImage(rd) print(server.get_result()) diff --git a/logging_env.py b/logging_env.py index d618e37..08873b0 100644 --- a/logging_env.py +++ b/logging_env.py @@ -1,51 +1,51 @@ from app.core.config import LOGS_PATH LOGGER_CONFIG_DICT = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "simple": {"format": "%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s"} + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'} }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "level": "INFO", - "formatter": "simple", - "stream": 
"ext://sys.stdout", + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'INFO', + 'formatter': 'simple', + 'stream': 'ext://sys.stdout', }, - "info_file_handler": { - "class": "logging.handlers.RotatingFileHandler", - "level": "INFO", - "formatter": "simple", - "filename": f"{LOGS_PATH}info.log", - "maxBytes": 10485760, - "backupCount": 50, - "encoding": "utf8", + 'info_file_handler': { + 'class': 'logging.handlers.RotatingFileHandler', + 'level': 'INFO', + 'formatter': 'simple', + 'filename': f'{LOGS_PATH}info.log', + 'maxBytes': 10485760, + 'backupCount': 50, + 'encoding': 'utf8', }, - "error_file_handler": { - "class": "logging.handlers.RotatingFileHandler", - "level": "ERROR", - "formatter": "simple", - "filename": f"{LOGS_PATH}error.log", - "maxBytes": 10485760, - "backupCount": 20, - "encoding": "utf8", + 'error_file_handler': { + 'class': 'logging.handlers.RotatingFileHandler', + 'level': 'ERROR', + 'formatter': 'simple', + 'filename': f'{LOGS_PATH}error.log', + 'maxBytes': 10485760, + 'backupCount': 20, + 'encoding': 'utf8', }, - "debug_file_handler": { - "class": "logging.handlers.RotatingFileHandler", - "level": "DEBUG", - "formatter": "simple", - "filename": f"{LOGS_PATH}debug.log", - "maxBytes": 10485760, - "backupCount": 50, - "encoding": "utf8", + 'debug_file_handler': { + 'class': 'logging.handlers.RotatingFileHandler', + 'level': 'DEBUG', + 'formatter': 'simple', + 'filename': f'{LOGS_PATH}debug.log', + 'maxBytes': 10485760, + 'backupCount': 50, + 'encoding': 'utf8', }, }, - "loggers": { - "my_module": {"level": "INFO", "handlers": ["console"], "propagate": "no"} + 'loggers': { + 'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'} }, - "root": { - "level": "INFO", - "handlers": ["error_file_handler", "info_file_handler", "debug_file_handler", "console"], + 'root': { + 'level': 'INFO', + 'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'], }, } From 
1d29d15a7cd2d9786e0eee0c5a5dfcd282b94a7d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 4 Jul 2024 15:22:40 +0800 Subject: [PATCH 103/108] =?UTF-8?q?feat=20fix=20=20=E4=BF=AE=E5=A4=8Dprodu?= =?UTF-8?q?ct=20img=20=E5=A4=B4=E9=83=A8=E7=BC=BA=E5=A4=B1=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 12964ae..5ea6f83 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -136,24 +136,39 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + # 原始图片的尺寸 + width, height = image.size - # resize 图片内尺寸 并贴到768-512的纯白图像上 - target_height = 768 - target_width = 512 - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 计算长宽比为 3:2 的新尺寸 + desired_ratio = 2 / 3 + current_ratio = width / height + + if current_ratio > desired_ratio: + # 原始图片更宽,需要在上下添加 padding + new_width = width + new_height = int(width / desired_ratio) + else: + # 原始图片更高或者长宽比已经为 3:2 + new_height = height + new_width = int(height * desired_ratio) + + # 
创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 + pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) + + # 将原始图片粘贴到新的画布中心 + left = (new_width - width) // 2 + top = (new_height - height) // 2 + pad_image.paste(image, (left, top)) + + # 将画布 resize 成宽度 500,长度 750 + resized_image = pad_image.resize((500, 750)) + image_size = (512, 768) + + if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) + background = Image.new("RGB", image_size, (255, 255, 255)) # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) + background.paste(resized_image, mask=resized_image.split()[3]) image = np.array(background) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return image, image_size @@ -162,11 +177,11 @@ def pre_processing_image(image_url): if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", - prompt="", + # prompt="", image_strength=0.9, - # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + prompt=" the best quality, masterpiece. 
detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", - product_type="single" + product_type="overall" ) server = GenerateProductImage(rd) print(server.get_result()) From 0fa59a4f6fdb5a0192a51b65c4bf235cf51ccc2e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 4 Jul 2024 17:22:45 +0800 Subject: [PATCH 104/108] =?UTF-8?q?feat=20fix=20=20=E4=BF=AE=E5=A4=8Dprodu?= =?UTF-8?q?ct=20img=20=E5=A4=B4=E9=83=A8=E7=BC=BA=E5=A4=B1=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/painting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 5936ccc..938bd5b 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -91,7 +91,8 @@ class PrintPainting(object): single_print = result['print']['single'] overall_print = result['print']['overall'] element_print = result['print']['element'] - + result['single_image'] = None + result['print_image'] = None if overall_print['print_path_list']: painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} result['print_image'] = result['pattern_image'] @@ -106,7 +107,7 @@ class PrintPainting(object): else: painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False) result['print_image'] = self.printpaint(result, painting_dict, print_=True) - result['single_image'] = result['pattern_image'] = result['print_image'] + result['single_image'] = result['pattern_image'] = result['print_image'] if single_print['print_path_list']: 
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) From cc69ff78721fbbaffb8fbe9c3950f75f53756a92 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 4 Jul 2024 17:43:21 +0800 Subject: [PATCH 105/108] =?UTF-8?q?feat=20fix=20=20=E4=BF=AE=E5=A4=8Dprodu?= =?UTF-8?q?ct=20img=20=E5=A4=B4=E9=83=A8=E7=BC=BA=E5=A4=B1=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/painting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 938bd5b..224e753 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -107,7 +107,7 @@ class PrintPainting(object): else: painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False) result['print_image'] = self.printpaint(result, painting_dict, print_=True) - result['single_image'] = result['pattern_image'] = result['print_image'] + result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image'] if single_print['print_path_list']: print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) From bc1c903d38d490fdeab1c716f23488c534caed39 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 5 Jul 2024 14:04:19 +0800 Subject: [PATCH 106/108] =?UTF-8?q?feat=20=20design=20=E7=BB=93=E6=9E=9C?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=B8=80=E4=B8=AA=E6=B2=A1=E6=9C=89=E8=B4=B4?= =?UTF-8?q?single=20print=E7=9A=84=E4=B8=AD=E9=97=B4=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E7=9A=84url=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/clothing.py | 7 +++++-- app/service/design/items/pipelines/split.py | 5 +++++ 
app/service/design/service.py | 5 +++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/app/service/design/items/clothing.py b/app/service/design/items/clothing.py index 5adcc70..f9f9561 100644 --- a/app/service/design/items/clothing.py +++ b/app/service/design/items/clothing.py @@ -36,7 +36,9 @@ class Clothing(object): position=start_point, resize_scale=self.result["resize_scale"], mask=cv2.resize(self.result['mask'], self.result["front_image"].size), - gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "" + gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "", + pattern_image_url=self.result['pattern_image_url'] + ) layer.insert(front_layer) @@ -51,7 +53,8 @@ class Clothing(object): position=start_point, resize_scale=self.result["resize_scale"], mask=cv2.resize(self.result['mask'], self.result["front_image"].size), - gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "" + gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "", + pattern_image_url=self.result['pattern_image_url'] ) layer.insert(back_layer) diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index 155347a..1e06712 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -71,6 +71,11 @@ class Split(object): result["back_image_url"] = None result["back_mask_url"] = None result['back_mask_image'] = None + + # 创建中间图层 + result_pattern_image_rgba = rgb_to_rgba((result['pattern_image'].shape[0], result['pattern_image'].shape[1]), result['pattern_image'], result['mask']) + result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA)) + _, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}') return result except Exception as e: 
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") diff --git a/app/service/design/service.py b/app/service/design/service.py index 0ba5e72..54cb45b 100644 --- a/app/service/design/service.py +++ b/app/service/design/service.py @@ -1,10 +1,10 @@ +import concurrent.futures + from app.core.config import PRIORITY_DICT from app.service.design.core.layer import Layer from app.service.design.items import build_item from app.service.design.utils.redis_utils import Redis from app.service.design.utils.synthesis_item import synthesis, synthesis_single -import concurrent.futures - from app.service.utils.decorator import RunTime @@ -96,6 +96,7 @@ def process_object(cfg, process_id, total): 'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "", 'mask_url': lay['mask_url'], 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, + 'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None, # 'image': lay['image'], # 'mask_image': lay['mask_image'], From 9d0689d98ea204c6a83bbd7f8f51d94aaa0bf598 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 5 Jul 2024 15:45:48 +0800 Subject: [PATCH 107/108] =?UTF-8?q?feat=20fix=20=20relight=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9Esingle=20item=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 5 ++- app/core/config.py | 3 +- app/schemas/generate_image.py | 1 + .../service_generate_relight_image.py | 33 ++++++++++++++----- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 95d8c50..3dee667 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -155,13 +155,16 @@ def generate_relight_image(request_item: GenerateRelightImageModel, background_t - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 - **direction**: 光源方向 Right Light 
Left Light Top Light Bottom Light + - **product_type**: 输入single item 还是 overall item + 示例参数: { "tasks_id": "123-89", "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", - "direction": "Right Light" + "direction": "Right Light", + "product_type": "overall" } """ try: diff --git a/app/core/config.py b/app/core/config.py index 4caaf13..a01a2c0 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -127,7 +127,8 @@ GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") -GRI_MODEL_NAME = 'diffusion_relight_ensemble' +GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble' +GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 7e7beb5..3dd7cf8 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -29,3 +29,4 @@ class GenerateRelightImageModel(BaseModel): prompt: str image_url: str direction: str + product_type: str diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 6f51435..e0729ba 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -38,6 +38,7 @@ class GenerateRelightImage: self.batch_size = 1 self.prompt = request_data.prompt self.seed = "1" + self.product_type = request_data.product_type self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' self.direction = request_data.direction self.image_url = request_data.image_url @@ -55,7 +56,11 @@ class GenerateRelightImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = 
result.as_numpy("generated_inpaint_image") + if self.product_type == 'single': + image = result.as_numpy("generated_relight_image") + else: + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -78,11 +83,18 @@ class GenerateRelightImage: nagetive_prompts = [self.negative_prompt] * self.batch_size directions = [self.direction] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) - seed_obj = np.array(seeds, dtype="object").reshape((1)) - direction_obj = np.array(directions, dtype="object").reshape((1)) + if self.product_type == 'single': + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) + seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) + direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") @@ -97,8 +109,11 @@ class GenerateRelightImage: input_direction.set_data_from_numpy(direction_obj) inputs = [input_text, input_natext, input_image, input_seed, input_direction] + if self.product_type == 'single': + ctx = 
self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) - ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() @@ -136,7 +151,9 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", prompt="Colorful black", - image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png' + image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png', + direction="Right Light", + product_type="single" ) server = GenerateRelightImage(rd) print(server.get_result()) From a86e885db889b2b345dc03383be3a5834a417eae Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 5 Jul 2024 17:30:05 +0800 Subject: [PATCH 108/108] =?UTF-8?q?feat=20fix=20=20=E4=BF=AE=E5=A4=8Ddesig?= =?UTF-8?q?n=20=E9=A2=84=E5=A4=84=E7=90=86=E9=83=A8=E5=88=86=E7=9A=84?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E6=97=A0=E6=B3=95=E8=BD=ACjson?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design_pre_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py index eb2f3ab..f260e22 100644 --- a/app/api/api_design_pre_processing.py +++ b/app/api/api_design_pre_processing.py @@ -30,7 +30,7 @@ def design_pre_processing(request_data: DesignPreProcessingModel): } """ try: - logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data)}") + logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}") server = DesignPreprocessing() data = server.pipeline(image_list=request_data.sketches) logger.info(f"design response @@@@@@:{json.dumps(data)}")