代码整理
All checks were successful
CI: AiDA python develop branch build & deploy / scheduled_deploy (push) — has been skipped

This commit is contained in:
zcr
2025-12-30 17:49:22 +08:00
parent aa57478852
commit 4951fab71a

View File

@@ -1,7 +1,7 @@
import logging import logging
import numpy as np import numpy as np
from pymilvus import MilvusClient # from pymilvus import MilvusClient
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
from app.service.design_fast.utils.design_ensemble import get_keypoint_result from app.service.design_fast.utils.design_ensemble import get_keypoint_result
@@ -54,63 +54,64 @@ class KeyPoint:
"keypoint_vector": result.tolist() "keypoint_vector": result.tolist()
} }
] ]
try: return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logger.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
    """Merge the freshly inferred half with the cached opposite half and
    upsert the combined vector into Milvus under site "all".

    Args:
        keypoint_id: primary key of the cached keypoint record.
        infer_result: numpy array with the newly inferred keypoints
            (flattened before concatenation; assumed 20 values for "up" and
            4 for "down" -- TODO confirm against infer_keypoint_result).
        search_result: flat cached vector fetched from Milvus (24 values total
            -- presumably 20 "up" + 4 "down"; verify against the schema).
        site: which side was just inferred, "up" or "down".

    Returns:
        dict mapping KEYPOINT_RESULT_TABLE_FIELD_SET names to 12 (x, y) int
        pairs. Returned even when the Milvus upsert fails (best effort cache).
    """
    if site == "up":
        # "up" was requested, so inference produced "up"; the cached vector
        # holds "down" -- keep its tail (last 4 values).
        result = np.concatenate([infer_result.flatten(), search_result[-4:]])
    else:
        # "down" was requested, so the cached vector holds "up" -- keep its
        # head (first 20 values).
        result = np.concatenate([search_result[:20], infer_result.flatten()])
    data = [
        {"keypoint_id": keypoint_id,
         "keypoint_site": "all",
         "keypoint_vector": result.tolist()
         }
    ]
    client = None
    try:
        client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
        client.upsert(
            collection_name=MILVUS_TABLE_KEYPOINT,
            data=data
        )
    except Exception as e:
        # Best effort: a failed cache write must not break the caller.
        logger.info(f"save keypoint cache milvus error : {e}")
    finally:
        # Fix: the original never closed the connection on any path, leaking
        # it (the sibling save_keypoint_cache does call client.close()).
        if client is not None:
            client.close()
    return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# "keypoint_site": "all",
# "keypoint_vector": result.tolist()
# }
# ]
#
# try:
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
# client.upsert(
# collection_name=MILVUS_TABLE_KEYPOINT,
# data=data
# )
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# except Exception as e:
# logger.info(f"save keypoint cache milvus error : {e}")
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@RunTime
def keypoint_cache(self, result, site):
    """Return cached keypoints for ``result['image_id']``, inferring and
    refreshing the Milvus cache when the requested ``site`` is missing.

    Args:
        result: dict carrying at least 'image_id' (used as the cache key);
            also passed through to infer_keypoint_result on a miss.
        site: requested keypoint side, "up" or "down" -- TODO confirm values
            against infer_keypoint_result.

    Returns:
        dict mapping KEYPOINT_RESULT_TABLE_FIELD_SET names to 12 (x, y) int
        pairs on success, or False when the Milvus lookup fails.
    """
    client = None
    try:
        client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
        keypoint_id = result['image_id']
        res = client.query(
            collection_name=MILVUS_TABLE_KEYPOINT,
            # ids=[keypoint_id],
            filter=f"keypoint_id == {keypoint_id}",
            output_fields=['keypoint_vector', 'keypoint_site']
        )
        if len(res) == 0:
            # Cache miss: run inference directly and persist the result.
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
        elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
            # Cached side matches the request (or covers both sides):
            # return the cached vector as 12 integer (x, y) pairs.
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
        elif res[0]["keypoint_site"] != site:
            # Cached side differs from the request: infer the missing side
            # and upgrade the cached record to "all".
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
    except Exception as e:
        logger.info(f"search keypoint cache milvus error {e}")
        return False
    finally:
        # Fix: the original never closed the Milvus connection on any path
        # (success, miss, or error), leaking it on every call.
        if client is not None:
            client.close()