From 4951fab71a1ec7b8313ba104c66f36bb6c270dc5 Mon Sep 17 00:00:00 2001 From: zcr Date: Tue, 30 Dec 2025 17:49:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=95=B4=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/keypoint.py | 117 ++++++++++--------- 1 file changed, 59 insertions(+), 58 deletions(-) diff --git a/app/service/design_fast/pipeline/keypoint.py b/app/service/design_fast/pipeline/keypoint.py index 2b2607a..51f1fbc 100644 --- a/app/service/design_fast/pipeline/keypoint.py +++ b/app/service/design_fast/pipeline/keypoint.py @@ -1,7 +1,7 @@ import logging import numpy as np -from pymilvus import MilvusClient +# from pymilvus import MilvusClient from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings from app.service.design_fast.utils.design_ensemble import get_keypoint_result @@ -54,63 +54,64 @@ class KeyPoint: "keypoint_vector": result.tolist() } ] - try: - client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) - client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data) - client.close() - return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - except Exception as e: - logger.info(f"save keypoint cache milvus error : {e}") - return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - @staticmethod - def update_keypoint_cache(keypoint_id, infer_result, search_result, site): - if site == "up": - # 需要的是up 即推理出来的是up 那么查询的就是down - result = np.concatenate([infer_result.flatten(), search_result[-4:]]) - else: - # 需要的是down 即推理出来的是down 那么查询的就是up - result = np.concatenate([search_result[:20], infer_result.flatten()]) - data = [ - {"keypoint_id": keypoint_id, - "keypoint_site": "all", - "keypoint_vector": result.tolist() - } - ] + # try: + # client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) + # client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data) + # client.close() + # except Exception as e: + # logger.info(f"save keypoint cache milvus error : {e}") + # return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - try: - client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) - client.upsert( - collection_name=MILVUS_TABLE_KEYPOINT, - data=data - ) - return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - except Exception as e: - logger.info(f"save keypoint cache milvus error : {e}") - return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + # @staticmethod + # def update_keypoint_cache(keypoint_id, infer_result, search_result, site): + # if site == "up": + # # 需要的是up 即推理出来的是up 那么查询的就是down + # result = np.concatenate([infer_result.flatten(), search_result[-4:]]) + # else: + # # 需要的是down 即推理出来的是down 那么查询的就是up + # result = np.concatenate([search_result[:20], infer_result.flatten()]) + # data = [ + # {"keypoint_id": keypoint_id, + # "keypoint_site": "all", + # "keypoint_vector": result.tolist() + # } + # ] + # + # try: + # client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) + # client.upsert( + # collection_name=MILVUS_TABLE_KEYPOINT, + # data=data + # ) + # return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) + # except Exception as e: + # logger.info(f"save keypoint cache milvus error : {e}") + # return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - @RunTime - def keypoint_cache(self, result, site): - try: - client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) - keypoint_id = result['image_id'] - res = client.query( - collection_name=MILVUS_TABLE_KEYPOINT, - # ids=[keypoint_id], - filter=f"keypoint_id == {keypoint_id}", - output_fields=['keypoint_vector', 'keypoint_site'] - ) - if len(res) == 0: - # 没有结果 直接推理拿结果 并保存 - keypoint_infer_result, site = self.infer_keypoint_result(result) - return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site) - elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site: - # 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果 - return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist())) - elif res[0]["keypoint_site"] != site: - # 需要的类型和查询到的不一致,则更新类型为all - keypoint_infer_result, site = self.infer_keypoint_result(result) - return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site) - except Exception as e: - logger.info(f"search keypoint cache milvus error {e}") - return False + # @RunTime + # def keypoint_cache(self, result, site): + # try: + # client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) + # keypoint_id = result['image_id'] + # res = client.query( + # collection_name=MILVUS_TABLE_KEYPOINT, + # # ids=[keypoint_id], + # filter=f"keypoint_id == {keypoint_id}", + # output_fields=['keypoint_vector', 'keypoint_site'] + # ) + # if len(res) == 0: + # # 没有结果 直接推理拿结果 并保存 + # keypoint_infer_result, site = self.infer_keypoint_result(result) + # return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site) + # elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site: + # # 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果 + # return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist())) + # elif res[0]["keypoint_site"] != site: + # # 需要的类型和查询到的不一致,则更新类型为all + # keypoint_infer_result, site = self.infer_keypoint_result(result) + # return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site) + # except Exception as e: + # logger.info(f"search keypoint cache milvus error {e}") + # return False