代码整理
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pymilvus import MilvusClient
|
# from pymilvus import MilvusClient
|
||||||
|
|
||||||
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
|
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
|
||||||
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
||||||
@@ -54,63 +54,64 @@ class KeyPoint:
|
|||||||
"keypoint_vector": result.tolist()
|
"keypoint_vector": result.tolist()
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
try:
|
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
|
||||||
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
|
||||||
client.close()
|
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"save keypoint cache milvus error : {e}")
|
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
|
||||||
|
|
||||||
@staticmethod
|
# try:
|
||||||
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||||
if site == "up":
|
# client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||||
# 需要的是up 即推理出来的是up 那么查询的就是down
|
# client.close()
|
||||||
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
# except Exception as e:
|
||||||
else:
|
# logger.info(f"save keypoint cache milvus error : {e}")
|
||||||
# 需要的是down 即推理出来的是down 那么查询的就是up
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
result = np.concatenate([search_result[:20], infer_result.flatten()])
|
|
||||||
data = [
|
|
||||||
{"keypoint_id": keypoint_id,
|
|
||||||
"keypoint_site": "all",
|
|
||||||
"keypoint_vector": result.tolist()
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
# @staticmethod
|
||||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
# def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
||||||
client.upsert(
|
# if site == "up":
|
||||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
# # 需要的是up 即推理出来的是up 那么查询的就是down
|
||||||
data=data
|
# result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
||||||
)
|
# else:
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
# # 需要的是down 即推理出来的是down 那么查询的就是up
|
||||||
except Exception as e:
|
# result = np.concatenate([search_result[:20], infer_result.flatten()])
|
||||||
logger.info(f"save keypoint cache milvus error : {e}")
|
# data = [
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
# {"keypoint_id": keypoint_id,
|
||||||
|
# "keypoint_site": "all",
|
||||||
|
# "keypoint_vector": result.tolist()
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||||
|
# client.upsert(
|
||||||
|
# collection_name=MILVUS_TABLE_KEYPOINT,
|
||||||
|
# data=data
|
||||||
|
# )
|
||||||
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.info(f"save keypoint cache milvus error : {e}")
|
||||||
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
|
|
||||||
@RunTime
|
# @RunTime
|
||||||
def keypoint_cache(self, result, site):
|
# def keypoint_cache(self, result, site):
|
||||||
try:
|
# try:
|
||||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||||
keypoint_id = result['image_id']
|
# keypoint_id = result['image_id']
|
||||||
res = client.query(
|
# res = client.query(
|
||||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
# collection_name=MILVUS_TABLE_KEYPOINT,
|
||||||
# ids=[keypoint_id],
|
# # ids=[keypoint_id],
|
||||||
filter=f"keypoint_id == {keypoint_id}",
|
# filter=f"keypoint_id == {keypoint_id}",
|
||||||
output_fields=['keypoint_vector', 'keypoint_site']
|
# output_fields=['keypoint_vector', 'keypoint_site']
|
||||||
)
|
# )
|
||||||
if len(res) == 0:
|
# if len(res) == 0:
|
||||||
# 没有结果 直接推理拿结果 并保存
|
# # 没有结果 直接推理拿结果 并保存
|
||||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
# keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||||
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
|
# return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
|
||||||
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
|
# elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
|
||||||
# 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
# # 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
||||||
elif res[0]["keypoint_site"] != site:
|
# elif res[0]["keypoint_site"] != site:
|
||||||
# 需要的类型和查询到的不一致,则更新类型为all
|
# # 需要的类型和查询到的不一致,则更新类型为all
|
||||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
# keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||||
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
|
# return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.info(f"search keypoint cache milvus error {e}")
|
# logger.info(f"search keypoint cache milvus error {e}")
|
||||||
return False
|
# return False
|
||||||
|
|||||||
Reference in New Issue
Block a user