Files
AiDA_Python/app/service/design/items/pipelines/keypoints.py
2024-07-19 15:17:54 +08:00

141 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import time
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.utils.decorator import RunTime, ClassCallRunTime
from ..builder import PIPELINES
from ...utils.design_ensemble import get_keypoint_result
@PIPELINES.register_module()
class KeypointDetection(object):
"""
path here: abstract path
"""
# def __init__(self):
# self.client = MilvusClient(
# uri="http://10.1.1.240:19530",
# token="root:Milvus",
# db_name=MILVUS_ALIAS
# )
# def __del__(self):
# start_time = time.time()
# self.client.close()
# print(f"client close time : {time.time() - start_time}")
# @ClassCallRunTime
def __call__(self, result):
# logging.info("KeypointDetection run ")
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
keypoint_cache = self.keypoint_cache(result, site)
# 取消向量查询 直接过模型推理
# keypoint_cache = False
if keypoint_cache is False:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
else:
result['clothes_keypoint'] = keypoint_cache
return result
@staticmethod
def infer_keypoint_result(result):
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
start_time = time.time()
keypoint_infer_result = get_keypoint_result(result["image"], site) # 推理结果
# logging.info(f"infer keypoint time : {time.time() - start_time}")
return keypoint_infer_result, site
@staticmethod
# @ RunTime
def save_keypoint_cache(keypoint_id, cache, site):
if site == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, cache.flatten()])
else:
zeros = np.zeros(4, dtype=int)
result = np.concatenate([cache.flatten(), zeros])
# 取消向量保存 直接拿结果
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": site,
"keypoint_vector": result.tolist()
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# start_time = time.time()
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
if site == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": "all",
"keypoint_vector": result.tolist()
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @ RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
keypoint_id = result['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
# 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != site:
# 需要的类型和查询到的不一致则更新类型为all
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
except Exception as e:
logging.info(f"search keypoint cache milvus error {e}")
return False