From b7e9131cf7556326f2e2d93e863f5013aa35513a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 24 Jul 2024 15:20:18 +0800 Subject: [PATCH] =?UTF-8?q?feat=20fix=20=20design=20=E6=81=A2=E5=A4=8Dseg?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=8E=A8=E7=90=86=EF=BC=8C=E5=B9=B6=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E6=9C=AC=E5=9C=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_pre_processing/service.py | 80 ++++++++------------ 1 file changed, 33 insertions(+), 47 deletions(-) diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index f69c3ee..b6d868d 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -5,11 +5,12 @@ import cv2 import numpy as np import torch import tritonclient.grpc as grpcclient +from pymilvus import MilvusClient from urllib3.exceptions import ResponseError from app.core.config import * from app.schemas.pre_processing import DesignPreProcessingModel -from app.service.design.utils.design_ensemble import get_keypoint_result +from app.service.design.utils.design_ensemble import get_keypoint_result, get_seg_result from app.service.utils.oss_client import oss_get_image, oss_upload_image @@ -124,9 +125,9 @@ class DesignPreprocessing: bucket_name = item['image_url'].split("/", 1)[0] object_name = item['image_url'].split("/", 1)[1] oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) - print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") + logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") except ResponseError as err: - print(f"Error: {err}") + logging.warning(f"Error: {err}") return image_list # @ RunTime @@ -138,6 +139,12 @@ class DesignPreprocessing: sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down' # 推理得到keypoint sketch['keypoint_result'] = self.keypoint_cache(sketch) + if sketch['site'] == 'up': + _, seg_cache = self.load_seg_result(sketch['image_id']) + if not _: + # 推理获得seg 结果 + seg_result = get_seg_result(sketch["image_id"], sketch['image_obj'])[0] + self.save_seg_result(seg_result, sketch['image_id']) if IF_DEBUG_SHOW: debug_show_image = sketch['obj'].copy() @@ -236,58 +243,37 @@ class DesignPreprocessing: return image_list @staticmethod - def select_seg_result(image_id, image_obj): + def load_seg_result(image_id): + file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: - # 如果shape不匹配 返回false - result = np.load(f"seg_result/{image_id}.npy").astype(np.int64) - if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]: - return result - else: - return False - except FileNotFoundError as e: - logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") - return False + seg_result = np.load(file_path) + return True, seg_result + except FileNotFoundError: + logging.info("文件不存在") + return False, None + except Exception as e: + logging.warning(f"加载失败: {e}") + return False, None @staticmethod - def search_seg_result(image_id, ori_shape): + def save_seg_result(seg_result, image_id): + file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: - # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) - # collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection. - # collection.load() - # start_time = time.time() - # res = collection.query( - # expr=f"seg_id == {image_id}", - # offset=0, - # limit=10, - # output_fields=["seg_cache"], - # ) - # logging.info(f"search seg cache time : {time.time() - start_time}") - - # if len(res): - # vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224)) - # array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False) - # array_2d_exact = array_2d_exact.squeeze().numpy() - # return array_2d_exact - # else: - return False + np.save(file_path, seg_result) + logging.info(f"保存成功,{os.path.abspath(file_path)}") except Exception as e: - logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") - return False + logging.warning(f"保存失败: {e}") def keypoint_cache(self, sketch): try: - # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) - # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. - # collection.load() - start_time = time.time() - # res = collection.query( - # expr=f"keypoint_id == {sketch['image_id']}", - # offset=0, - # limit=1, - # output_fields=["keypoint_cache", "keypoint_site"], - # ) - res = [] - logging.info(f"search keypoint time : {time.time() - start_time}") + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) + keypoint_id = sketch['image_id'] + res = client.query( + collection_name=MILVUS_TABLE_KEYPOINT, + # ids=[keypoint_id], + filter=f"keypoint_id == {keypoint_id}", + output_fields=['keypoint_vector', 'keypoint_site'] + ) if len(res) == 0: # 没有结果 直接推理拿结果 并保存 keypoint_infer_result = self.infer_keypoint_result(sketch)