2024-05-28 15:22:11 +08:00
|
|
|
|
import logging
|
|
|
|
|
|
import time
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from pymilvus import MilvusClient
|
|
|
|
|
|
|
|
|
|
|
|
from app.core.config import *
|
|
|
|
|
|
from ..builder import PIPELINES
|
|
|
|
|
|
from ...utils.design_ensemble import get_keypoint_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Garment names handled by the keypoint stage, grouped by the body site
# whose keypoint model is used for them.
_UP_NAMES = ('blouse', 'outwear', 'dress', 'tops')
_DOWN_NAMES = ('skirt', 'trousers', 'bottoms')

# Layout of the flat keypoint vector stored in Milvus: the first
# _UP_VECTOR_LEN values belong to the "up" site and the trailing
# _DOWN_VECTOR_LEN values belong to the "down" site.
_UP_VECTOR_LEN = 20
_DOWN_VECTOR_LEN = 4
# The flat vector is reshaped to (_KEYPOINT_ROWS, 2) before it is zipped with
# the field names (20 + 4 == 12 * 2).
_KEYPOINT_ROWS = 12


def _site_for(name):
    """Return 'up' for upper-body garment names and 'down' for anything else."""
    return 'up' if name in _UP_NAMES else 'down'


def _keypoint_dict(vector, field_names=None):
    """Map field names onto the rows of a flat keypoint vector.

    Args:
        vector: flat sequence/array that can be reshaped to (12, 2).
        field_names: iterable of keys for the rows. When ``None`` the
            module-level ``KEYPOINT_RESULT_TABLE_FIELD_SET`` is used
            (presumably provided by the star import from ``app.core.config``
            -- confirm).

    Returns:
        dict mapping each field name to a two-element list of ints.
    """
    if field_names is None:
        field_names = KEYPOINT_RESULT_TABLE_FIELD_SET
    rows = np.asarray(vector).reshape(_KEYPOINT_ROWS, 2).astype(int).tolist()
    return dict(zip(field_names, rows))


def _close_quietly(client):
    """Close a Milvus client without letting a close failure mask other work."""
    if client is None:
        return
    try:
        client.close()
    except Exception as e:
        logging.info(f"close milvus client error : {e}")


def _upsert_keypoint_row(row):
    """Best-effort upsert of one row into the keypoint collection.

    Failures are logged and swallowed on purpose: the cache write must not
    prevent the caller from returning the freshly computed keypoints.
    The client is always closed, even when the upsert fails.
    """
    client = None
    try:
        client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
        client.upsert(
            collection_name=MILVUS_TABLE_KEYPOINT,
            data=[row],
        )
    except Exception as e:
        logging.info(f"save keypoint cache milvus error : {e}")
    finally:
        _close_quietly(client)


def _query_keypoint_rows(keypoint_id):
    """Query the keypoint collection for ``keypoint_id`` and close the client.

    Raises whatever the Milvus client raises; the caller decides how to
    handle lookup failures.
    """
    client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
    try:
        return client.query(
            collection_name=MILVUS_TABLE_KEYPOINT,
            # NOTE(review): the id is interpolated unquoted, which assumes a
            # numeric keypoint_id field -- confirm against the collection schema.
            filter=f"keypoint_id == {keypoint_id}",
            output_fields=['keypoint_vector', 'keypoint_site'],
        )
    finally:
        _close_quietly(client)


@PIPELINES.register_module()
class KeypointDetection(object):
    """Pipeline step that attaches clothes keypoints to a result dict.

    For supported garment names the keypoints are looked up in a Milvus
    collection keyed by ``result['image_id']``; on a miss (or when the stored
    entry covers a different site) the keypoint model is run and the cache is
    written/updated. The keypoints are stored under
    ``result['clothes_keypoint']``.
    """

    def __call__(self, result):
        """Populate ``result['clothes_keypoint']`` for supported garments.

        Args:
            result: dict that provides at least ``'name'``, ``'image_id'``
                and ``'image'``.

        Returns:
            The same ``result`` dict. It is returned unchanged when
            ``result['name']`` is not a supported garment name.
        """
        if result['name'] in _UP_NAMES + _DOWN_NAMES:
            # Look for cached data of the same site: read it directly when it
            # matches, otherwise run inference and update the cache.
            site = _site_for(result['name'])
            cached = self.keypoint_cache(result, site)
            # To bypass the Milvus lookup and always run the model, force
            # ``cached = False`` here.
            if cached is False:
                inferred, site = self.infer_keypoint_result(result)
                result['clothes_keypoint'] = self.save_keypoint_cache(
                    result["image_id"], inferred, site)
            else:
                result['clothes_keypoint'] = cached
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run the keypoint model for the garment in ``result``.

        Returns:
            Tuple ``(keypoints, site)`` where ``keypoints`` is whatever
            ``get_keypoint_result`` returns for ``result['image']`` and
            ``site`` is ``'up'`` or ``'down'``.
        """
        site = _site_for(result['name'])
        return get_keypoint_result(result["image"], site), site

    @staticmethod
    def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None):
        """Store freshly inferred keypoints for one site and return them as a dict.

        The inferred values are zero-padded to the full vector: for ``"down"``
        they are placed after 20 zeros, otherwise they are followed by 4 zeros.

        Args:
            keypoint_id: value written to the ``keypoint_id`` field.
            cache: inferred keypoints; must provide ``flatten()``.
            site: ``"up"`` or ``"down"``; written to ``keypoint_site``.
            KEYPOINT_RESULT_TABLE_FIELD_SET: optional iterable of field names
                for the returned dict. When ``None`` (the default) the
                module-level constant of the same name is used; previously the
                ``None`` default shadowed that constant and made the method
                raise ``TypeError``.

        Returns:
            dict mapping field names to two-element int lists. The dict is
            returned even when the Milvus write fails (the failure is logged).
        """
        flat = cache.flatten()
        if site == "down":
            vector = np.concatenate([np.zeros(_UP_VECTOR_LEN, dtype=int), flat])
        else:
            vector = np.concatenate([flat, np.zeros(_DOWN_VECTOR_LEN, dtype=int)])

        _upsert_keypoint_row({
            "keypoint_id": keypoint_id,
            "keypoint_site": site,
            "keypoint_vector": vector.tolist(),
        })
        return _keypoint_dict(vector, KEYPOINT_RESULT_TABLE_FIELD_SET)

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge new keypoints with a cached entry of the other site.

        Args:
            keypoint_id: value written to the ``keypoint_id`` field.
            infer_result: inferred keypoints for ``site``; must provide
                ``flatten()``.
            search_result: the cached flat ``keypoint_vector``.
            site: site that was just inferred (``"up"`` or ``"down"``).

        Returns:
            dict mapping field names to two-element int lists for the merged
            vector. The entry is stored with ``keypoint_site == "all"``; a
            failed Milvus write is logged and does not affect the return value.
        """
        flat = infer_result.flatten()
        if site == "up":
            # "up" was requested and inferred, so the cached entry holds the
            # "down" part: keep its trailing values.
            merged = np.concatenate([flat, search_result[-_DOWN_VECTOR_LEN:]])
        else:
            # "down" was requested and inferred, so the cached entry holds the
            # "up" part: keep its leading values.
            merged = np.concatenate([search_result[:_UP_VECTOR_LEN], flat])

        _upsert_keypoint_row({
            "keypoint_id": keypoint_id,
            "keypoint_site": "all",
            "keypoint_vector": merged.tolist(),
        })
        return _keypoint_dict(merged)

    def keypoint_cache(self, result, site):
        """Return keypoints for ``result`` using the Milvus cache when possible.

        Args:
            result: dict providing ``'image_id'``, ``'name'`` and ``'image'``.
            site: site required by the caller (``"up"`` or ``"down"``).

        Returns:
            The keypoint dict, or ``False`` when any step raised (the error is
            logged); the caller then falls back to plain inference.
        """
        try:
            rows = _query_keypoint_rows(result['image_id'])
            if len(rows) == 0:
                # Nothing cached: infer and store the result.
                inferred, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(result['image_id'], inferred, site)

            cached_site = rows[0]["keypoint_site"]
            if cached_site == "all" or cached_site == site:
                # The cached entry covers the requested site (or both sites):
                # return it directly.
                return _keypoint_dict(rows[0]['keypoint_vector'])

            # The cached entry covers the other site only: infer the missing
            # part and promote the entry to "all".
            inferred, site = self.infer_keypoint_result(result)
            return self.update_keypoint_cache(
                result["image_id"], inferred, rows[0]['keypoint_vector'], site)
        except Exception as e:
            logging.info(f"search keypoint cache milvus error {e}")
            return False
|