diff --git a/app/service/design_fast/pipeline/keypoint.py b/app/service/design_fast/pipeline/keypoint.py index 45debc2..dd2ebe5 100644 --- a/app/service/design_fast/pipeline/keypoint.py +++ b/app/service/design_fast/pipeline/keypoint.py @@ -5,6 +5,7 @@ from pymilvus import MilvusClient from app.core.config import * from app.service.design_fast.utils.design_ensemble import get_keypoint_result +from app.service.utils.decorator import ClassCallRunTime logger = logging.getLogger(__name__) @@ -16,6 +17,7 @@ class KeyPoint: def get_name(cls): return cls.name + @ClassCallRunTime def __call__(self, result): if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新 # result['clothes_keypoint'] = self.infer_keypoint_result(result) diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 686e7b5..3884a48 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -6,6 +6,7 @@ import numpy as np from app.core.config import SEG_CACHE_PATH from app.service.design_fast.utils.design_ensemble import get_seg_result +from app.service.utils.decorator import ClassCallRunTime from app.service.utils.new_oss_client import oss_get_image logger = logging.getLogger() @@ -15,6 +16,7 @@ class Segmentation: def __init__(self, minio_client): self.minio_client = minio_client + @ClassCallRunTime def __call__(self, result): if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "": seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")