import logging
import os
import time

import cv2
import numpy as np
import torch
import tritonclient.grpc as grpcclient
from minio import Minio
from pymilvus import MilvusClient
from urllib3.exceptions import ResponseError

from app.core.config import settings, SR_MODEL_NAME, SR_TRITON_URL, MILVUS_TABLE_KEYPOINT, \
    KEYPOINT_RESULT_TABLE_FIELD_SET
from app.schemas.pre_processing import DesignPreProcessingModel
from app.service.design_fast.utils.design_ensemble import get_seg_result, get_keypoint_result
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image

logger = logging.getLogger()

# Shared MinIO/OSS client used by every pipeline stage.
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS,
                     secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)


class DesignPreprocessing:
    """Pre-processing pipeline for design sketch images.

    Stages (see :meth:`pipeline`): download from OSS -> crop/pad to the
    content bounding box -> keypoint/segmentation inference (with Milvus
    and on-disk caching) -> compose and upload a display image.
    """

    def pipeline(self, image_list):
        """Run the full pre-processing pipeline over ``image_list``.

        :param image_list: list of dicts, each with at least ``image_url``
            (``"<bucket>/<object>"``), ``image_id`` and ``image_category``.
        :return: the same list, enriched with ``site`` and
            ``show_image_url``; the transient keys ``image_obj``, ``obj``
            and ``keypoint_result`` are stripped before returning.
        """
        sketches_list = self.read_image(image_list)
        bounding_box_sketches_list = self.bounding_box(sketches_list)
        # NOTE: the super-resolution stage is currently disabled; see
        # self.super_resolution() to re-enable it between bounding_box
        # and infer_image.
        infer_sketches_list = self.infer_image(bounding_box_sketches_list)
        result = self.composing_image(infer_sketches_list)
        # Drop the heavy in-memory artifacts before handing the result back.
        for d in result:
            d.pop('image_obj', None)
            d.pop('obj', None)
            d.pop('keypoint_result', None)
        return result

    @staticmethod
    def read_image(image_list):
        """Download each image from OSS and normalize it to 3-channel BGR.

        Stores the decoded image under ``item['image_obj']``.
        """
        for obj in image_list:
            bucket, object_name = obj['image_url'].split("/", 1)
            image = oss_get_image(oss_client=minio_client, bucket=bucket,
                                  object_name=object_name, data_type="cv2")
            if len(image.shape) == 2:
                # Grayscale -> 3-channel.
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.shape[2] == 4:
                # 4-channel mask (BGRA): drop the alpha channel.
                image = image[:, :, :3]
            obj["image_obj"] = image
        return image_list

    @staticmethod
    def bounding_box(image_list):
        """Compute the content bounding box and pad to a 20px white margin.

        The padded image is stored under ``item['obj']``.
        """
        for item in image_list:
            image = item['image_obj']
            height, width = image.shape[:2]
            # Canny edges -> external contours of the sketch content.
            edges = cv2.Canny(image, 50, 150)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Union of all contour bounding rects.
            x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1
            for contour in contours:
                x, y, w, h = cv2.boundingRect(contour)
                x_min = min(x_min, x)
                y_min = min(y_min, y)
                x_max = max(x_max, x + w)
                y_max = max(y_max, y + h)
            if len(contours) > 0:
                # NOTE(review): this crop is immediately overwritten by the
                # padded *original* image below — preserved as-is because the
                # padding amounts guarantee a >=20px margin around the content
                # on the original frame; confirm whether cropping was intended.
                item['obj'] = image[y_min:y_max, x_min:x_max]
            else:
                item['obj'] = image
            # Pad only where the content sits closer than 20px to an edge.
            padding_top = max(20 - y_min, 0)
            padding_bottom = max(20 - (height - y_max), 0)
            padding_left = max(20 - x_min, 0)
            padding_right = max(20 - (width - x_max), 0)
            padded_image = cv2.copyMakeBorder(
                image,
                padding_top, padding_bottom, padding_left, padding_right,
                cv2.BORDER_CONSTANT, value=(255, 255, 255)
            )
            item['obj'] = padded_image
        return image_list

    @staticmethod
    def super_resolution(image_list):
        """4x super-resolve small sketches via Triton and overwrite them in OSS.

        Only images whose both sides are <= 512 *and* at least one side is
        <= 200 are upscaled (4x SR would otherwise exceed useful size).
        """
        for item in image_list:
            if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
                if item['obj'].shape[0] <= 200 or item['obj'].shape[1] <= 200:
                    # HWC-BGR uint8 -> NCHW-RGB float32 in [0, 1].
                    img = item['obj'].astype(np.float32) / 255.
                    sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
                    sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
                    inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
                    inputs[0].set_data_from_numpy(sample)
                    triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
                    result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
                    result_image = result.as_numpy('output')[0]
                    sr_output = torch.tensor(result_image)
                    output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                    if output.ndim == 3:
                        # CHW-RGB -> HWC-BGR.
                        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
                    output = (output * 255.0).round().astype(np.uint8)
                    item['obj'] = output
                    try:
                        # Overwrite the original object in OSS with the SR result.
                        image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
                        bucket_name, object_name = item['image_url'].split("/", 1)
                        oss_upload_image(oss_client=minio_client, bucket=bucket_name,
                                         object_name=object_name, image_bytes=image_bytes)
                        logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
                    except ResponseError as err:
                        logging.warning(f"Error: {err}")
        return image_list

    def infer_image(self, image_list):
        """Attach keypoints (via cache or inference) and warm the seg cache.

        Sets ``sketch['site']`` ('up'/'down') from the category and
        ``sketch['keypoint_result']`` from :meth:`keypoint_cache`.
        """
        for sketch in image_list:
            image_category = sketch['image_category'].lower()
            # Upper-body categories map to 'up', everything else to 'down'.
            sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
            sketch['keypoint_result'] = self.keypoint_cache(sketch)
            if sketch['site'] == 'up':
                cached, seg_cache = self.load_seg_result(sketch['image_id'])
                if not cached:
                    # No cached segmentation — run inference and persist it.
                    seg_result = get_seg_result(sketch['obj'])[0]
                    self.save_seg_result(seg_result, sketch['image_id'])
                    logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , "
                                f"seg cache size is :{seg_result.shape}")
                else:
                    logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , "
                                f"seg cache size is :{seg_cache.shape}")
        return image_list

    @staticmethod
    def composing_image(image_list):
        """Pad the image so the waistband occupies <= 40% of the width,
        upload it as a '-show' variant and record ``show_image_url``.
        """
        for image in image_list:
            image_width = image['obj'].shape[1]
            # Keypoints are (y, x) pairs; [1] is the x coordinate.
            waist_width = (image['keypoint_result']['waistband_right'][1]
                           - image['keypoint_result']['waistband_left'][1])
            scale = 0.4
            ret = image['obj']
            if waist_width / scale >= image_width:
                # Widen symmetrically so waist/width ratio becomes `scale`.
                add_width = int((waist_width / scale - image_width) / 2)
                # Bug fix: white for 8-bit images is 255, not 256.
                ret = cv2.copyMakeBorder(ret, 0, 0, add_width, add_width,
                                         cv2.BORDER_CONSTANT, value=(255, 255, 255))
            image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
            bucket_name, object_name = image['image_url'].split('/', 1)
            object_name = object_name.replace('.', '-show.')
            oss_upload_image(oss_client=minio_client, bucket=bucket_name,
                             object_name=object_name, image_bytes=image_bytes)
            image['show_image_url'] = f"{bucket_name}/{object_name}"
        return image_list

    @staticmethod
    def load_seg_result(image_id):
        """Load a cached segmentation result from disk.

        :return: ``(True, ndarray)`` on a cache hit, ``(False, None)`` when
            the file is missing or unreadable.
        """
        file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
        try:
            seg_result = np.load(file_path)
            return True, seg_result
        except FileNotFoundError:
            # Cache miss — caller will run inference.
            return False, None
        except Exception as e:
            logging.warning(f"加载失败: {e}")
            return False, None

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Persist a segmentation result to the on-disk cache (best-effort)."""
        file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
        try:
            np.save(file_path, seg_result)
            logging.debug(f"保存成功,{os.path.abspath(file_path)}")
        except Exception as e:
            logging.warning(f"保存失败: {e}")

    def keypoint_cache(self, sketch):
        """Fetch keypoints for ``sketch`` from the Milvus cache, inferring
        and caching them on a miss or a site mismatch.

        :return: dict mapping keypoint names to ``[y, x]`` int pairs, or
            ``False`` when the Milvus lookup fails.
        """
        try:
            client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN,
                                  db_name=settings.MILVUS_ALIAS)
            keypoint_id = sketch['image_id']
            res = client.query(
                collection_name=MILVUS_TABLE_KEYPOINT,
                filter=f"keypoint_id == {keypoint_id}",
                output_fields=['keypoint_vector', 'keypoint_site']
            )
            if len(res) == 0:
                # Miss: infer now and persist the result.
                keypoint_infer_result = self.infer_keypoint_result(sketch)
                return self.save_keypoint_cache(sketch, keypoint_infer_result)
            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']:
                # Hit with a compatible site ('all' covers both) — use it directly.
                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET,
                                np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
            elif res[0]["keypoint_site"] != sketch['site']:
                # Cached site differs from the requested one: infer the
                # missing half and upgrade the cache entry to 'all'.
                keypoint_infer_result = self.infer_keypoint_result(sketch)
                return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector'])
        except Exception as e:
            logging.info(f"search keypoint cache milvus error {e}")
            return False

    @staticmethod
    def infer_keypoint_result(sketch):
        """Run keypoint inference for the sketch's site and return the raw result."""
        keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site'])
        return keypoint_infer_result

    @staticmethod
    def save_keypoint_cache(sketch, keypoint_infer_result):
        """Build the full 24-value keypoint vector for a fresh inference.

        The 24-value layout is 20 upper-body values followed by 4 lower-body
        values; the half that was not inferred is zero-filled.
        Returns the result as a ``{field: [y, x]}`` dict.
        """
        if sketch['site'] == "down":
            zeros = np.zeros(20, dtype=int)
            result = np.concatenate([zeros, keypoint_infer_result.flatten()])
        else:
            zeros = np.zeros(4, dtype=int)
            result = np.concatenate([keypoint_infer_result.flatten(), zeros])
        # NOTE(review): the Milvus insert that persisted this vector is gone
        # (previously commented out) — this currently only shapes the result.
        return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))

    @staticmethod
    def update_keypoint_cache(sketch, infer_result, search_result):
        """Merge a fresh inference with the cached opposite-site vector.

        :param infer_result: newly inferred keypoints for ``sketch['site']``.
        :param search_result: the cached 24-value vector for the other site.
        :return: merged result as a ``{field: [y, x]}`` dict.
        """
        if sketch['site'] == "up":
            # Requested 'up' was inferred; keep the cached 'down' tail.
            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # Requested 'down' was inferred; keep the cached 'up' head.
            result = np.concatenate([search_result[:20], infer_result.flatten()])
        # NOTE(review): the Milvus upsert to site='all' is gone (previously
        # commented out) — the cache entry is not actually upgraded here.
        return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))


if __name__ == '__main__':
    data = {
        "sketches": [
            {
                "image_category": "blouse",
                "image_id": "123123123",
                "image_url": "test/0628000198.jpg"
            }
        ]
    }
    request_data = DesignPreProcessingModel(sketches=data["sketches"])
    server = DesignPreprocessing()
    data = server.pipeline(image_list=request_data.sketches)
    print(data)