import logging
import time
from contextlib import closing

import numpy as np
from pymilvus import MilvusClient

from app.core.config import *
from app.service.utils.decorator import RunTime, ClassCallRunTime
from ..builder import PIPELINES
from ...utils.design_ensemble import get_keypoint_result

# Garment names that receive keypoints; any other name passes through untouched.
_KEYPOINT_GARMENTS = frozenset(
    ('blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms'))
# Garment names treated as the "up" (upper) site; every other garment is "down".
_UP_GARMENTS = frozenset(('blouse', 'outwear', 'dress', 'tops'))
# Stored vector layout: first 20 values belong to "up", last 4 to "down"
# (24 values total, read back as 12 pairs).
_UP_LEN = 20
_DOWN_LEN = 4


def _garment_site(name):
    """Return "up" for upper garments, otherwise "down"."""
    return 'up' if name in _UP_GARMENTS else 'down'


def _keypoint_dict(vector):
    """Convert a flat 24-value keypoint vector into a {field: [a, b]} mapping of ints."""
    # NOTE(review): zip relies on KEYPOINT_RESULT_TABLE_FIELD_SET having a stable
    # order; if it is a real ``set`` the field/point pairing is arbitrary — confirm.
    pairs = np.asarray(vector).reshape(12, 2).astype(int).tolist()
    return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, pairs))


def _milvus_client():
    """Open a new Milvus client using the configured URL, token and database."""
    return MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)


def _upsert_keypoints(data):
    """Best-effort upsert of ``data`` into the keypoint collection.

    Failures are logged and swallowed so a cache write never fails the pipeline;
    the client is always closed, including on error.
    """
    try:
        with closing(_milvus_client()) as client:
            client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
    except Exception as e:
        logging.warning("save keypoint cache milvus error : %s", e)


@PIPELINES.register_module()
class KeypointDetection(object):
    """Pipeline step that attaches clothing keypoints to ``result``.

    For supported garment names, keypoints are looked up in the Milvus
    collection ``MILVUS_TABLE_KEYPOINT`` by ``result['image_id']``. On a miss,
    or when only the other garment half is cached, they are inferred with
    ``get_keypoint_result`` and written back. The resulting mapping is stored
    in ``result['clothes_keypoint']``.

    Stored vector layout (24 values, read as 12 pairs): the first 20 values
    hold "up" keypoints and the last 4 hold "down" keypoints; the half that was
    not inferred is zero-padded. ``keypoint_site`` records which half is
    populated: "up", "down" or "all".

    NOTE(review): the original docstring read "path here: abstract path" —
    presumably image paths are expected to be absolute; confirm.
    """

    def __call__(self, result):
        """Add ``result['clothes_keypoint']`` for supported garments and return ``result``.

        Args:
            result: pipeline dict; reads ``name``, ``image_id`` and ``image``.

        Returns:
            The same ``result`` dict (mutated in place for supported garments).
        """
        if result['name'] not in _KEYPOINT_GARMENTS:
            return result

        # Check the cache for this image: reuse it if it covers the needed
        # site, otherwise infer and update.
        site = _garment_site(result['name'])
        keypoints = self.keypoint_cache(result, site)
        if keypoints is False:
            # The cache lookup failed; fall back to inference plus a best-effort write.
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            keypoints = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
        result['clothes_keypoint'] = keypoints
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run keypoint inference on ``result['image']``.

        Returns:
            A tuple ``(keypoints, site)`` where ``site`` is "up" or "down"
            derived from ``result['name']``.
        """
        site = _garment_site(result['name'])
        return get_keypoint_result(result["image"], site), site

    @staticmethod
    def save_keypoint_cache(keypoint_id, cache, site):
        """Zero-pad a single-site result to the full layout, cache it, and return it.

        Args:
            keypoint_id: cache key (the image id).
            cache: inferred keypoints for ``site``; presumably flattens to 20
                values for "up" and 4 for "down" — TODO confirm.
            site: "down" pads the "up" half; any other value pads the "down" half.

        Returns:
            The keypoint mapping built from the padded vector, even if the
            Milvus write fails.
        """
        flat = np.asarray(cache).flatten()
        if site == "down":
            vector = np.concatenate([np.zeros(_UP_LEN, dtype=int), flat])
        else:
            vector = np.concatenate([flat, np.zeros(_DOWN_LEN, dtype=int)])
        _upsert_keypoints([{
            "keypoint_id": keypoint_id,
            "keypoint_site": site,
            "keypoint_vector": vector.tolist(),
        }])
        return _keypoint_dict(vector)

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge freshly inferred keypoints with the cached other half and store as "all".

        Args:
            keypoint_id: cache key (the image id).
            infer_result: inferred keypoints for ``site``.
            search_result: the cached 24-value vector, which holds the other site.
            site: the site that was just inferred ("up" or otherwise "down").

        Returns:
            The keypoint mapping built from the merged vector, even if the
            Milvus write fails.
        """
        flat = np.asarray(infer_result).flatten()
        stored = np.asarray(search_result)
        if site == "up":
            # "up" was inferred, so the cached vector supplies the "down" tail.
            vector = np.concatenate([flat, stored[-_DOWN_LEN:]])
        else:
            # "down" was inferred, so the cached vector supplies the "up" head.
            vector = np.concatenate([stored[:_UP_LEN], flat])
        _upsert_keypoints([{
            "keypoint_id": keypoint_id,
            "keypoint_site": "all",
            "keypoint_vector": vector.tolist(),
        }])
        return _keypoint_dict(vector)

    def keypoint_cache(self, result, site):
        """Return cached keypoints for ``result``, inferring and caching as needed.

        Args:
            result: pipeline dict; reads ``image_id`` (and ``name``/``image``
                when inference is required).
            site: the site the caller needs ("up" or "down").

        Returns:
            A keypoint mapping, or ``False`` when anything in this path raises.
            The caller then retries via inference, so an inference failure here
            is attempted a second time by the caller.
        """
        try:
            keypoint_id = result['image_id']
            # Hold the connection only for the query, not during inference.
            with closing(_milvus_client()) as client:
                res = client.query(
                    collection_name=MILVUS_TABLE_KEYPOINT,
                    # NOTE(review): the id is interpolated unquoted, which is only a
                    # valid expression for numeric ids — confirm the schema.
                    filter=f"keypoint_id == {keypoint_id}",
                    output_fields=['keypoint_vector', 'keypoint_site'],
                )
            if not res:
                # Nothing cached: infer, then save.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(keypoint_id, keypoint_infer_result, site)
            cached = res[0]
            if cached["keypoint_site"] in ("all", site):
                # The cached entry already covers the requested site.
                return _keypoint_dict(cached['keypoint_vector'])
            # The cached entry holds only the other half: infer this half and merge into "all".
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            return self.update_keypoint_cache(
                keypoint_id, keypoint_infer_result, cached['keypoint_vector'], site)
        except Exception as e:
            logging.warning("search keypoint cache milvus error %s", e)
            return False