diff --git a/app/api/api_design_pre_processing.py b/app/api/api_design_pre_processing.py new file mode 100644 index 0000000..0c0089d --- /dev/null +++ b/app/api/api_design_pre_processing.py @@ -0,0 +1,29 @@ +import logging +import time + +from fastapi import APIRouter + +from app.schemas.pre_processing import DesignPreProcessingModel +from app.service.design_pre_processing.service import DesignPreprocessing + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/design_pre_processing") +def design_pre_processing(request_data: DesignPreProcessingModel): + try: + logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}") + code = 200 + message = "access" + start_time = time.time() + server = DesignPreprocessing() + data = server.pipeline(image_list=request_data.sketches) + logger.info(f"design_pre_processing Run time is @@@@@@:{time.time() - start_time}") + except Exception as e: + code = 400 + message = str(e) + data = str(e) + logger.warning(f"design Run Exception @@@@@@:{e}") + logger.info({"code": code, "message": message, "data": data}) + return {"code": code, "message": message, "data": data} diff --git a/app/api/api_route.py b/app/api/api_route.py index c1add93..c2bd2d2 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -7,6 +7,7 @@ from app.api import api_attribute_retrieve from app.api import api_design from app.api import api_chat_robot from app.api import api_prompt_generation +from app.api import api_design_pre_processing router = APIRouter() @@ -18,3 +19,4 @@ router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"] router.include_router(api_design.router, tags=['design'], prefix="/api") router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api") router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api") +router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") diff --git a/app/core/config.py b/app/core/config.py index 5744dec..cca1de0 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -118,6 +118,9 @@ AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') +# DESIGN 预处理 +IF_DEBUG_SHOW = False + # 优先级 PRIORITY_DICT = { 'earring_front': 99, diff --git a/app/schemas/pre_processing.py b/app/schemas/pre_processing.py new file mode 100644 index 0000000..47d9297 --- /dev/null +++ b/app/schemas/pre_processing.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class DesignPreProcessingModel(BaseModel): + sketches: list[dict] diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index fc59b61..4d0a081 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -34,8 +34,8 @@ class KeypointDetection(object): site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - # 取消向量查询 直接过模型推理 keypoint_cache = self.keypoint_cache(result, site) + # 取消向量查询 直接过模型推理 # keypoint_cache = False if keypoint_cache is False: diff --git a/app/service/design/utils/design_ensemble.py b/app/service/design/utils/design_ensemble.py index e1df56a..a1021e9 100644 --- a/app/service/design/utils/design_ensemble.py +++ b/app/service/design/utils/design_ensemble.py @@ -37,7 +37,7 @@ def get_keypoint_result(image, site): keypoint_result = None try: image, scale_factor = keypoint_preprocess(image) - client = httpclient.InferenceServerClient(url=KEYPOINT_MODEL_URL) + client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) transformed_img = image.astype(np.float32) inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")] inputs[0].set_data_from_numpy(transformed_img, binary_data=True) diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py new file mode 100644 index 0000000..e655087 --- /dev/null +++ b/app/service/design_pre_processing/service.py @@ -0,0 +1,320 @@ +import logging +import time + +import cv2 +import numpy as np +import torch +from minio import Minio +from pymilvus import connections, Collection +from urllib3.exceptions import ResponseError +import torch.nn.functional as F +import tritonclient.grpc as grpcclient +import io + +from app.core.config import * +from app.service.design.utils.design_ensemble import get_keypoint_result + + +class DesignPreprocessing: + def __init__(self): + self.minio_client = Minio( + MINIO_URL, + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + + # @ RunTime + def pipeline(self, image_list): + sketches_list = self.read_image(image_list) + logging.info("read image success") + + bounding_box_sketches_list = self.bounding_box(sketches_list) + logging.info("bounding box image success") + + super_resolution_list = self.super_resolution(bounding_box_sketches_list) + logging.info("super_resolution_list image success") + + infer_sketches_list = self.infer_image(super_resolution_list) + logging.info("infer image success") + + result = self.composing_image(infer_sketches_list) + logging.info("Replenish white edge image success") + + for d in result: + if 'image_obj' in d: + del d['image_obj'] + if 'obj' in d: + del d['obj'] + if 'keypoint_result' in d: + del d['keypoint_result'] + return result + + def read_image(self, image_list): + for obj in image_list: + file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data + image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + elif image.shape[2] == 4: # 如果是四通道 mask + image = image[:, :, :3] + obj["image_obj"] = image + return image_list + + # @ RunTime + def bounding_box(self, image_list): + for item in image_list: + image = item['image_obj'] + # 使用Canny边缘检测来检测物体的轮廓 + edges = cv2.Canny(image, 50, 150) + # 查找轮廓 + contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + # 初始化包围所有外接矩形的大矩形的坐标 + x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1 + # 遍历所有外接矩形,更新大矩形的坐标 + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + x_min = min(x_min, x) + y_min = min(y_min, y) + x_max = max(x_max, x + w) + y_max = max(y_max, y + h) + + if IF_DEBUG_SHOW: + image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2) + cv2.imshow("bounding_box image", image_with_big_rect) + cv2.waitKey(0) + + # 根据大矩形的坐标来裁剪原始图像 + if len(contours) > 0: + cropped_image = image[y_min:y_max, x_min:x_max] + item['obj'] = cropped_image # 新shape图像 + # 取消直接覆盖,新增size判断 + # try: + # # 覆盖到minio + # image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes() + # self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + # print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") + # except ResponseError as err: + # print(f"Error: {err}") + else: + item['obj'] = image + return image_list + + def super_resolution(self, image_list): + for item in image_list: + # 判断 两边是否同时都小于512 因为此处做四倍超分 + if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512: + # 如果任意一边小于256则超分 + if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256: + # 超分 + img = item['obj'].astype(np.float32) / 255. + sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1)) + sample = torch.from_numpy(sample).float().unsqueeze(0).numpy() + inputs = [ + grpcclient.InferInput("input", sample.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(sample) + triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL) + result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs) + result_image = result.as_numpy(f'output')[0] + sr_output = torch.tensor(result_image) + output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + if output.ndim == 3: + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR + output = (output * 255.0).round().astype(np.uint8) + item['obj'] = output + try: + # 覆盖到minio + image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes() + self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") + except ResponseError as err: + print(f"Error: {err}") + return image_list + + # @ RunTime + def infer_image(self, image_list): + for sketch in image_list: + # 小写 + image_category = sketch['image_category'].lower() + # 判断上下装 + sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down' + # 推理得到keypoint + sketch['keypoint_result'] = self.keypoint_cache(sketch) + + if IF_DEBUG_SHOW: + debug_show_image = sketch['obj'].copy() + points_list = [] + point_size = 1 + point_color = (0, 0, 255) # BGR + thickness = 4 # 可以为 0 、4、8 + for i in sketch['keypoint_result'].values(): + points_list.append((int(i[1]), int(i[0]))) + for point in points_list: + cv2.circle(debug_show_image, point, point_size, point_color, thickness) + cv2.imshow("", debug_show_image) + cv2.waitKey(0) + # # 关键点在上部则推理seg + # if sketch["site"] == "up": + # # 判断seg缓存是否存在,是否与当前图片shape一致 + # seg_result = self.search_seg_result(sketch["image_id"], sketch["obj"].shape) + # if seg_result is False: + # # 推理seg + 保存 + # seg_result = get_seg_result(sketch['image_id'], sketch['obj']) + return image_list + + # @ RunTime + def composing_image(self, image_list): + for image in image_list: + if image['site'] == 'down': + image_width = image['obj'].shape[1] + waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + scale = 0.4 + if waist_width / scale >= image['obj'].shape[1]: + add_width = int((waist_width / scale - image_width) / 2) + ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + if IF_DEBUG_SHOW: + cv2.imshow("composing_image", ret) + cv2.waitKey(0) + image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + else: + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + else: + scale = 0.4 + image_width = image['obj'].shape[1] + waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1] + if waist_width / scale >= image_width: + add_width = int((waist_width / scale - image_width) / 2) + ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + if IF_DEBUG_SHOW: + cv2.imshow("composing_image", ret) + cv2.waitKey(0) + image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + else: + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + return image_list + + @staticmethod + def select_seg_result(image_id, image_obj): + try: + # 如果shape不匹配 返回false + result = np.load(f"seg_result/{image_id}.npy").astype(np.int64) + if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]: + return result + else: + return False + except FileNotFoundError as e: + logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") + return False + + @staticmethod + def search_seg_result(image_id, ori_shape): + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + # collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection. + # collection.load() + # start_time = time.time() + # res = collection.query( + # expr=f"seg_id == {image_id}", + # offset=0, + # limit=10, + # output_fields=["seg_cache"], + # ) + # logging.info(f"search seg cache time : {time.time() - start_time}") + + # if len(res): + # vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224)) + # array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False) + # array_2d_exact = array_2d_exact.squeeze().numpy() + # return array_2d_exact + # else: + return False + except Exception as e: + logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") + return False + + def keypoint_cache(self, sketch): + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. + # collection.load() + start_time = time.time() + # res = collection.query( + # expr=f"keypoint_id == {sketch['image_id']}", + # offset=0, + # limit=1, + # output_fields=["keypoint_cache", "keypoint_site"], + # ) + res = [] + logging.info(f"search keypoint time : {time.time() - start_time}") + if len(res) == 0: + # 没有结果 直接推理拿结果 并保存 + keypoint_infer_result = self.infer_keypoint_result(sketch) + return self.save_keypoint_cache(sketch, keypoint_infer_result) + elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']: + # 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果 + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist())) + elif res[0]["keypoint_site"] != sketch['site']: + # 需要的类型和查询到的不一致,则更新类型为all + keypoint_infer_result = self.infer_keypoint_result(sketch) + return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector']) + except Exception as e: + logging.info(f"search keypoint cache milvus error {e}") + return False + + # @ RunTime + def infer_keypoint_result(self, sketch): + keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site']) # 推理结果 + return keypoint_infer_result + + @staticmethod + # @ RunTime + def save_keypoint_cache(sketch, keypoint_infer_result): + if sketch['site'] == "down": + zeros = np.zeros(20, dtype=int) + result = np.concatenate([zeros, keypoint_infer_result.flatten()]) + else: + zeros = np.zeros(4, dtype=int) + result = np.concatenate([keypoint_infer_result.flatten(), zeros]) + data = [ + [int(sketch['image_id'])], + [sketch['site']], + [result.tolist()] + ] + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + start_time = time.time() + # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. + # mr = collection.insert(data) + # logging.info(f"save keypoint time : {time.time() - start_time}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + except Exception as e: + logging.info(f"save keypoint cache milvus error : {e}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + + @staticmethod + def update_keypoint_cache(sketch, infer_result, search_result): + if sketch['site'] == "up": + # 需要的是up 即推理出来的是up 那么查询的就是down + result = np.concatenate([infer_result.flatten(), search_result[-4:]]) + else: + # 需要的是down 即推理出来的是down 那么查询的就是up + result = np.concatenate([search_result[:20], infer_result.flatten()]) + data = [ + [int(sketch['image_id'])], + ["all"], + [result.tolist()] + ] + try: + # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) + start_time = time.time() + # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. + # mr = collection.upsert(data) + # logging.info(f"save keypoint time : {time.time() - start_time}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + except Exception as e: + logging.info(f"save keypoint cache milvus error : {e}") + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))