feat 代码整理

fix
This commit is contained in:
zhouchengrong
2024-09-26 14:23:54 +08:00
parent 84d207087b
commit d965352c20
2 changed files with 4 additions and 0 deletions

View File

@@ -5,6 +5,7 @@ from pymilvus import MilvusClient
from app.core.config import * from app.core.config import *
from app.service.design_fast.utils.design_ensemble import get_keypoint_result from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,6 +17,7 @@ class KeyPoint:
def get_name(cls): def get_name(cls):
return cls.name return cls.name
@ClassCallRunTime
def __call__(self, result): def __call__(self, result):
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新 if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result) # result['clothes_keypoint'] = self.infer_keypoint_result(result)

View File

@@ -6,6 +6,7 @@ import numpy as np
from app.core.config import SEG_CACHE_PATH from app.core.config import SEG_CACHE_PATH
from app.service.design_fast.utils.design_ensemble import get_seg_result from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger() logger = logging.getLogger()
@@ -15,6 +16,7 @@ class Segmentation:
def __init__(self, minio_client): def __init__(self, minio_client):
self.minio_client = minio_client self.minio_client = minio_client
@ClassCallRunTime
def __call__(self, result): def __call__(self, result):
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "": if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2") seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")