From 8363ec9ab3ce9b010e5ea9c07573e4daeafd4007 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 24 Jul 2024 15:21:06 +0800 Subject: [PATCH] =?UTF-8?q?feat=20fix=20=20design=20pipeline=20=E5=89=94?= =?UTF-8?q?=E9=99=A4=E8=BE=B9=E7=BC=98=E6=A3=80=E6=B5=8B=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=EF=BC=8C=E7=9B=B4=E6=8E=A5=E7=94=A8=E5=88=86=E5=89=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/bottom.py | 1 + .../design/items/pipelines/segmentation.py | 16 ++++++++----- app/service/design/items/pipelines/split.py | 23 ++----------------- app/service/design/items/top.py | 4 ++-- 4 files changed, 15 insertions(+), 29 deletions(-) diff --git a/app/service/design/items/bottom.py b/app/service/design/items/bottom.py index eb575fb..e01ec02 100644 --- a/app/service/design/items/bottom.py +++ b/app/service/design/items/bottom.py @@ -10,6 +10,7 @@ class Bottom(Clothing): dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']), dict(type='KeypointDetection'), dict(type='ContourDetection'), + # dict(type='Segmentation'), dict(type='Painting', painting_flag=True), dict(type='PrintPainting', print_flag=True), dict(type='Scaling'), diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index 8782e75..2966ee7 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -9,18 +9,22 @@ from ...utils.design_ensemble import get_seg_result @PIPELINES.register_module() class Segmentation(object): - def __init__(self, device='cpu', show=False, debug=None): - self.show = show - self.device = device - self.debug = debug # @ClassCallRunTime def __call__(self, result): + # 本地查询seg 缓存是否存在 _, seg_result = self.load_seg_result(result["image_id"]) result['seg_result'] = seg_result if not _: - result['seg_result'] = get_seg_result(result["image_id"], result['image']) - self.save_seg_result(result['seg_result'][0], result['image_id']) + # 推理获得seg 结果 + seg_result = get_seg_result(result["image_id"], result['image'])[0] + self.save_seg_result(seg_result, result['image_id']) + # 处理前片后片 + temp_front = seg_result == 1.0 + result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) + temp_back = seg_result == 2.0 + result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8)) + result['mask'] = result['front_mask'] + result['back_mask'] return result @staticmethod diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index f3da4e7..5b7f1bc 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -22,29 +22,10 @@ class Split(object): # KNet def __call__(self, result): try: - if 'mask' not in result.keys(): - raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.') - if 'seg_result' not in result.keys(): # 没过seg模型 - result['front_mask'] = result['mask'].copy() - result['back_mask'] = np.zeros_like(result['mask']) - else: - temp_front = result['seg_result'] == 1.0 - result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8)) - temp_back = result['seg_result'] == 2.0 - result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8)) if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): - if len(result['front_mask'].shape) > 2: - front_mask = result['front_mask'][0] - else: - front_mask = result['front_mask'] - - if len(result['back_mask'].shape) > 2: - back_mask = result['back_mask'][0] - else: - back_mask = result['back_mask'] - - # rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], front_mask + back_mask) + front_mask = result['front_mask'] + back_mask = result['back_mask'] rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask) new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1])) rgba_image = cv2.resize(rgba_image, new_size) diff --git a/app/service/design/items/top.py b/app/service/design/items/top.py index 135328f..fc0d2a5 100644 --- a/app/service/design/items/top.py +++ b/app/service/design/items/top.py @@ -9,8 +9,8 @@ class Top(Clothing): pipeline = [ dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']), dict(type='KeypointDetection'), - dict(type='ContourDetection'), - dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']), + # dict(type='ContourDetection'), + dict(type='Segmentation'), dict(type='Painting', painting_flag=True), dict(type='PrintPainting', print_flag=True), # dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),