Files
AiDA_Python/app/service/design_fast/pipeline/keypoint.py

118 lines
5.5 KiB
Python
Raw Normal View History

2024-09-19 14:20:56 +08:00
import logging
import numpy as np
2025-12-30 17:49:22 +08:00
# from pymilvus import MilvusClient
2024-09-19 14:20:56 +08:00
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
2024-09-26 06:09:05 +00:00
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
2024-09-26 14:26:30 +08:00
from app.service.utils.decorator import ClassCallRunTime, RunTime
2024-09-19 14:20:56 +08:00
# Module-level logger named after this module, so its records can be filtered by source.
logger = logging.getLogger(name=__name__)
class KeyPoint:
    """Pipeline step that attaches clothing keypoints to a design result.

    Calling an instance with a result dict runs keypoint inference for the
    supported garment categories and stores the keypoint mapping under
    ``result['clothes_keypoint']``. Results for any other category are
    returned unchanged.

    A Milvus-backed keypoint cache used to sit in front of the model: it
    looked up the cache, saved new results, and merged "up"/"down" vectors.
    That cache has been disabled, so inference now runs on every call. The
    old implementation is available in version-control history.
    """

    name = "KeyPoint"

    # Garment categories for which keypoints are computed.
    _SUPPORTED_CATEGORIES = ('blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms')
    # Categories routed to the "up" keypoint site; every other name is "down".
    _UPPER_CATEGORIES = ('blouse', 'outwear', 'dress', 'tops')

    # Zero entries padded onto the flattened inference output, so both sites
    # fill the same 24-value layout that is later reshaped to (12, 2).
    # "down" results are padded in front and "up" results at the end.
    # NOTE(review): this implies the "up" model yields 20 values and the
    # "down" model 4 values; confirm against get_keypoint_result.
    _DOWN_PADDING = 20
    _UP_PADDING = 4

    @classmethod
    def get_name(cls):
        """Return the name identifying this pipeline step."""
        return cls.name

    @ClassCallRunTime
    def __call__(self, result):
        """Attach keypoints to *result* when its garment category is supported.

        Args:
            result: Pipeline result dict. ``'name'`` (garment category) is
                always read. For supported categories, ``'image'`` and
                ``'image_id'`` are read as well.

        Returns:
            The same dict. For supported categories, ``'clothes_keypoint'``
            is set to the mapping produced by ``save_keypoint_cache``.
        """
        if result['name'] not in self._SUPPORTED_CATEGORIES:
            return result
        # The vector-store cache lookup is disabled; always run model inference.
        keypoint_infer_result, site = self.infer_keypoint_result(result)
        result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run the keypoint model on ``result['image']``.

        The site is ``'up'`` when ``result['name']`` is an upper-body category
        and ``'down'`` for any other name.

        Returns:
            Tuple ``(keypoint_infer_result, site)``, where the first element
            is the raw output of ``get_keypoint_result``.
        """
        site = 'up' if result['name'] in KeyPoint._UPPER_CATEGORIES else 'down'
        keypoint_infer_result = get_keypoint_result(result["image"], site)
        return keypoint_infer_result, site

    @staticmethod
    def save_keypoint_cache(keypoint_id, cache, site):
        """Pad an inference result to the full keypoint layout and label it.

        Despite the name, nothing is persisted any more. The Milvus upsert
        was disabled, so this method only formats the result.

        Args:
            keypoint_id: Image id the keypoints belong to. Unused since the
                cache write was removed; kept for interface compatibility.
            cache: Keypoint inference output. It must support ``flatten()``,
                e.g. a numpy array.
            site: ``'down'`` pads 20 zeros in front; any other value pads
                4 zeros at the end.

        Returns:
            dict pairing each entry of ``KEYPOINT_RESULT_TABLE_FIELD_SET``,
            in iteration order, with a 2-element list of ints (presumably an
            ``[x, y]`` point).

        Raises:
            ValueError: if the padded vector cannot be reshaped to (12, 2).
        """
        flat = cache.flatten()
        if site == "down":
            padded = np.concatenate([np.zeros(KeyPoint._DOWN_PADDING, dtype=int), flat])
        else:
            padded = np.concatenate([flat, np.zeros(KeyPoint._UP_PADDING, dtype=int)])
        # astype(int) truncates any float coordinates toward zero.
        points = padded.reshape(12, 2).astype(int).tolist()
        # NOTE(review): if KEYPOINT_RESULT_TABLE_FIELD_SET is a Python ``set``,
        # its iteration order is not guaranteed, so the field -> point pairing
        # could be unstable. Confirm it is an ordered sequence.
        return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, points))