import logging
import posixpath
import time

import cv2
import numpy as np
import torch
import tritonclient.grpc as grpcclient
from urllib3.exceptions import ResponseError

from app.core.config import *
from app.service.design.utils.design_ensemble import get_keypoint_result
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_get_image, oss_upload_image


class DesignPreprocessing:
    """Preprocess design sketches: crop, optionally super-resolve, find keypoints, pad.

    Each item of ``image_list`` is a dict that is mutated in place. Keys read:
    ``image_url`` ("<bucket>/<object key>"), ``image_category`` and ``image_id``.
    Keys written: ``site``, ``show_image_url`` and, when super-resolution runs,
    ``new_image_url``. Intermediate keys (``image_obj``, ``obj``, ``keypoint_result``)
    are removed again by :meth:`pipeline`.
    """

    # Categories treated as upper-body garments ("up"); everything else is "down".
    UPPER_CATEGORIES = frozenset({'blouse', 'outwear', 'dress', 'tops'})
    # Target ratio of waistband width to padded image width used by composing_image.
    WAIST_WIDTH_RATIO = 0.4
    # White border for 8-bit images (256 is out of range and only worked via saturation).
    BORDER_COLOR = (255, 255, 255)
    # Intermediate keys that pipeline() strips before returning.
    _INTERNAL_KEYS = ('image_obj', 'obj', 'keypoint_result')

    # @ RunTime
    def pipeline(self, image_list):
        """Run every preprocessing stage over ``image_list`` and return it.

        :param image_list: list of dicts with at least ``image_url``, ``image_category``
            and ``image_id``; mutated in place.
        :return: the same list, with intermediate image/keypoint keys removed.
        """
        sketches_list = self.read_image(image_list)
        logging.info("read image success")
        bounding_box_sketches_list = self.bounding_box(sketches_list)
        logging.info("bounding box image success")
        super_resolution_list = self.super_resolution(bounding_box_sketches_list)
        logging.info("super_resolution_list image success")
        infer_sketches_list = self.infer_image(super_resolution_list)
        logging.info("infer image success")
        result = self.composing_image(infer_sketches_list)
        logging.info("Replenish white edge image success")
        # Only URL / metadata fields are returned to the caller.
        for item in result:
            for key in self._INTERNAL_KEYS:
                item.pop(key, None)
        return result

    @staticmethod
    def _split_image_url(image_url):
        """Split ``"<bucket>/<object key>"`` into ``(bucket, object_key)``.

        Raises ValueError if ``image_url`` contains no "/".
        """
        bucket_name, object_name = image_url.split("/", 1)
        return bucket_name, object_name

    @staticmethod
    def _encode_jpeg(image):
        """Encode ``image`` as JPEG bytes; raise RuntimeError if OpenCV reports failure."""
        ok, buffer = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("cv2.imencode failed to encode image as JPEG")
        return buffer.tobytes()

    @staticmethod
    def _show_object_name(object_name):
        """Insert "-show" before the file extension of an object key.

        Only the last path component's extension is touched, e.g.
        ``"a.b/c.png" -> "a.b/c-show.png"`` and ``"a/c" -> "a/c-show"`` (so the
        show image can never overwrite the source object).
        """
        stem, extension = posixpath.splitext(object_name)
        return f"{stem}-show{extension}"

    def read_image(self, image_list):
        """Download each item's image from OSS into ``item['image_obj']``.

        Grayscale (2-D) images are expanded to 3 channels and 4-channel images have
        their 4th channel dropped.
        """
        for obj in image_list:
            bucket_name, object_name = self._split_image_url(obj['image_url'])
            image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
            # NOTE(review): guard in case oss_get_image returns None on a decode failure.
            if image is None:
                raise ValueError(f"failed to read image {obj['image_url']}")
            if image.ndim == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.shape[2] == 4:
                # Four-channel image: keep only the first three channels.
                image = image[:, :, :3]
            obj["image_obj"] = image
        return image_list

    # @ RunTime
    def bounding_box(self, image_list):
        """Crop each image to the bounding box of all its external Canny-edge contours.

        Sets ``item['obj']`` to the crop, or to the unmodified image when no contour
        is found.
        """
        for item in image_list:
            image = item['image_obj']
            # Detect object outlines with Canny, then take their external contours.
            edges = cv2.Canny(image, 50, 150)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                item['obj'] = image
                continue
            # The bounding rect of all contour points equals the union of the
            # per-contour bounding rects.
            x, y, w, h = cv2.boundingRect(np.vstack(contours))
            if IF_DEBUG_SHOW:
                image_with_big_rect = cv2.rectangle(image.copy(), (x, y), (x + w, y + h), (0, 255, 0), 2)
                cv2.imshow("bounding_box image", image_with_big_rect)
                cv2.waitKey(0)
            item['obj'] = image[y:y + h, x:x + w]
        return image_list

    def super_resolution(self, image_list):
        """Super-resolve small crops via Triton and upload them under a new object key.

        A crop is processed when both sides are <= 512 px and at least one side is
        <= 256 px (the model performs 4x super-resolution). Its ``obj`` is replaced by
        the model output and ``new_image_url`` is set after a successful upload.
        """
        triton_client = None
        try:
            for item in image_list:
                height, width = item['obj'].shape[:2]
                if height > 512 or width > 512 or (height > 256 and width > 256):
                    continue
                if triton_client is None:
                    # One client (one gRPC channel) per call, released in ``finally``.
                    triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
                item['obj'] = self._infer_super_resolution(triton_client, item['obj'])
                self._upload_super_resolved(item)
        finally:
            if triton_client is not None:
                triton_client.close()
        return image_list

    @staticmethod
    def _infer_super_resolution(triton_client, image):
        """Run the SR model on an 8-bit BGR (or 1-channel) HWC image; return uint8."""
        img = image.astype(np.float32) / 255.
        # HWC-BGR -> CHW-RGB (1-channel input is only transposed), then add batch dim.
        sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
        sample = np.ascontiguousarray(sample[np.newaxis], dtype=np.float32)
        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)
        result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
        output = np.clip(np.squeeze(result.as_numpy('output')[0]).astype(np.float32), 0, 1)
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
        return (output * 255.0).round().astype(np.uint8)

    def _upload_super_resolved(self, item):
        """Upload ``item['obj']`` as JPEG next to the source object under a UUID name."""
        try:
            image_bytes = self._encode_jpeg(item['obj'])
            bucket_name, object_name = self._split_image_url(item['image_url'])
            # Upload under a new key instead of overwriting: because of latency, the
            # bounding-boxed sketch can differ from the one cached by the front end.
            directory = object_name[:object_name.rfind('/') + 1]
            _, extension = posixpath.splitext(object_name)
            # NOTE(review): bytes are always JPEG even if the key keeps e.g. ".png".
            new_object = f"{directory}{generate_uuid()}{extension}"
            oss_upload_image(bucket=bucket_name, object_name=new_object, image_bytes=image_bytes)
            item['new_image_url'] = f"{bucket_name}/{new_object}"
            logging.info("Object '%s' super-resolved and uploaded as '%s'.", object_name, new_object)
        except ResponseError as err:
            # NOTE(review): urllib3's ResponseError is a leftover from the MinIO client;
            # confirm which exceptions oss_upload_image actually raises.
            logging.error("Error: %s", err)

    # @ RunTime
    def infer_image(self, image_list):
        """Classify each sketch as "up"/"down" and attach its keypoint result.

        Sets ``sketch['site']`` and ``sketch['keypoint_result']`` (a dict, or False
        when keypoint lookup/inference failed).
        """
        for sketch in image_list:
            image_category = sketch['image_category'].lower()
            sketch['site'] = 'up' if image_category in self.UPPER_CATEGORIES else 'down'
            sketch['keypoint_result'] = self.keypoint_cache(sketch)
            if IF_DEBUG_SHOW and sketch['keypoint_result']:
                debug_show_image = sketch['obj'].copy()
                for point in sketch['keypoint_result'].values():
                    # Points appear to be stored as (row, col); cv2 expects (x, y).
                    cv2.circle(debug_show_image, (int(point[1]), int(point[0])), 1, (0, 0, 255), 4)
                cv2.imshow("", debug_show_image)
                cv2.waitKey(0)
            # TODO: segmentation inference for "up" garments is disabled; it depended
            # on search_seg_result, which currently always returns False.
        return image_list

    # @ RunTime
    def composing_image(self, image_list):
        """Pad each sketch with white side borders and upload it as the "-show" image.

        Sets ``image['show_image_url']`` to ``"<bucket>/<key with -show suffix>"``.
        """
        for image in image_list:
            show_image = self._pad_to_waist_ratio(image)
            bucket_name, object_name = self._split_image_url(image['image_url'])
            show_object_name = self._show_object_name(object_name)
            oss_upload_image(bucket=bucket_name, object_name=show_object_name,
                             image_bytes=self._encode_jpeg(show_image))
            image['show_image_url'] = f"{bucket_name}/{show_object_name}"
        return image_list

    def _pad_to_waist_ratio(self, image):
        """Return ``image['obj']`` padded so the waistband spans ~WAIST_WIDTH_RATIO of it.

        Padding is only added when the waistband is at least WAIST_WIDTH_RATIO of the
        current width; equal white borders go on the left and right. Without a keypoint
        result the image is returned unpadded.
        """
        obj = image['obj']
        keypoints = image['keypoint_result']
        if not keypoints:
            logging.warning("No keypoint result for image %s; skipping white-edge padding",
                            image.get('image_id'))
            return obj
        image_width = obj.shape[1]
        waist_width = keypoints['waistband_right'][1] - keypoints['waistband_left'][1]
        target_width = waist_width / self.WAIST_WIDTH_RATIO
        if target_width < image_width:
            return obj
        add_width = int((target_width - image_width) / 2)
        padded = cv2.copyMakeBorder(obj, 0, 0, add_width, add_width, cv2.BORDER_CONSTANT,
                                    value=self.BORDER_COLOR)
        if IF_DEBUG_SHOW:
            cv2.imshow("composing_image", padded)
            cv2.waitKey(0)
        return padded

    @staticmethod
    def select_seg_result(image_id, image_obj):
        """Load ``seg_result/<image_id>.npy`` if it matches ``image_obj``'s height/width.

        Returns the array as int64, or False when the file is missing or its dims
        1 and 2 differ from ``image_obj.shape[:2]``.
        """
        try:
            result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
        except FileNotFoundError as e:
            logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
            return False
        if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
            return result
        return False

    @staticmethod
    def search_seg_result(image_id, ori_shape):
        """Look up a cached segmentation result; the cache is disabled, so returns False."""
        # TODO: the Milvus lookup (collection MILVUS_TABLE_SEG, filtered on seg_id and
        # resized to ori_shape) was commented out; restore it here if re-enabled.
        return False

    def keypoint_cache(self, sketch):
        """Return the keypoint dict for ``sketch``, inferring (and caching) as needed.

        Returns False on any failure (logged with traceback); infer_image and
        composing_image treat False as "no keypoints".
        """
        try:
            start_time = time.time()
            # NOTE: the Milvus query is disabled, so ``res`` is always empty and the
            # keypoints are always inferred. If re-enabled, reconcile field names: the
            # old query requested "keypoint_cache" but this code reads "keypoint_vector".
            res = []
            logging.info("search keypoint time : %s", time.time() - start_time)
            if not res:
                # No cached entry: infer and save.
                keypoint_infer_result = self.infer_keypoint_result(sketch)
                return self.save_keypoint_cache(sketch, keypoint_infer_result)
            cached = res[0]
            if cached["keypoint_site"] in ("all", sketch['site']):
                # Cached entry already covers the requested site.
                return self._to_keypoint_dict(cached['keypoint_vector'])
            # Cached entry is for the other site: infer this one and merge into "all".
            keypoint_infer_result = self.infer_keypoint_result(sketch)
            return self.update_keypoint_cache(sketch, keypoint_infer_result, cached['keypoint_vector'])
        except Exception:
            logging.exception("search keypoint cache milvus error")
            return False

    # @ RunTime
    def infer_keypoint_result(self, sketch):
        """Run keypoint inference on ``sketch['obj']`` for ``sketch['site']``."""
        return get_keypoint_result(sketch["obj"], sketch['site'])

    @staticmethod
    def _to_keypoint_dict(vector):
        """Map a flat 24-value keypoint vector to ``{field: [int, int]}`` (12 pairs)."""
        # NOTE(review): zip relies on KEYPOINT_RESULT_TABLE_FIELD_SET having a stable
        # order; if it is an actual ``set`` the name->point mapping is arbitrary.
        points = np.asarray(vector).astype(int).reshape(12, 2).tolist()
        return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, points))

    @staticmethod
    # @ RunTime
    def save_keypoint_cache(sketch, keypoint_infer_result):
        """Zero-pad an inferred keypoint vector to 24 values and return it as a dict.

        "down" results occupy the last 4 values and "up" results the first 20 (the
        reshape to 12x2 requires those sizes). Persisting to Milvus is disabled.
        """
        flat = keypoint_infer_result.flatten()
        if sketch['site'] == "down":
            result = np.concatenate([np.zeros(20, dtype=int), flat])
        else:
            result = np.concatenate([flat, np.zeros(4, dtype=int)])
        # Payload for the disabled Milvus insert; int() still validates image_id.
        data = [[int(sketch['image_id'])], [sketch['site']], [result.tolist()]]  # noqa: F841
        # TODO: Collection(MILVUS_TABLE_KEYPOINT).insert(data) when the cache returns.
        return DesignPreprocessing._to_keypoint_dict(result)

    @staticmethod
    def update_keypoint_cache(sketch, infer_result, search_result):
        """Merge newly inferred keypoints with a cached vector for the other site."""
        if sketch['site'] == "up":
            # Inferred "up" values (first 20) + cached "down" values (last 4).
            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # Cached "up" values (first 20) + inferred "down" values (last 4).
            result = np.concatenate([search_result[:20], infer_result.flatten()])
        # Payload for the disabled Milvus upsert; int() still validates image_id.
        data = [[int(sketch['image_id'])], ["all"], [result.tolist()]]  # noqa: F841
        # TODO: Collection(MILVUS_TABLE_KEYPOINT).upsert(data) when the cache returns.
        return DesignPreprocessing._to_keypoint_dict(result)