feat
fix design pipeline 剔除边缘检测任务,直接用分割
This commit is contained in:
@@ -10,6 +10,7 @@ class Bottom(Clothing):
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
# dict(type='Segmentation'),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
dict(type='Scaling'),
|
||||
|
||||
@@ -9,18 +9,22 @@ from ...utils.design_ensemble import get_seg_result
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Segmentation(object):
|
||||
def __init__(self, device='cpu', show=False, debug=None):
|
||||
self.show = show
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
# 本地查询seg 缓存是否存在
|
||||
_, seg_result = self.load_seg_result(result["image_id"])
|
||||
result['seg_result'] = seg_result
|
||||
if not _:
|
||||
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
|
||||
self.save_seg_result(result['seg_result'][0], result['image_id'])
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])[0]
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# 处理前片后片
|
||||
temp_front = seg_result == 1.0
|
||||
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = seg_result == 2.0
|
||||
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -22,29 +22,10 @@ class Split(object):
|
||||
# KNet
|
||||
def __call__(self, result):
|
||||
try:
|
||||
if 'mask' not in result.keys():
|
||||
raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
|
||||
if 'seg_result' not in result.keys(): # 没过seg模型
|
||||
result['front_mask'] = result['mask'].copy()
|
||||
result['back_mask'] = np.zeros_like(result['mask'])
|
||||
else:
|
||||
temp_front = result['seg_result'] == 1.0
|
||||
result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = result['seg_result'] == 2.0
|
||||
result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
|
||||
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||
if len(result['front_mask'].shape) > 2:
|
||||
front_mask = result['front_mask'][0]
|
||||
else:
|
||||
front_mask = result['front_mask']
|
||||
|
||||
if len(result['back_mask'].shape) > 2:
|
||||
back_mask = result['back_mask'][0]
|
||||
else:
|
||||
back_mask = result['back_mask']
|
||||
|
||||
# rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], front_mask + back_mask)
|
||||
front_mask = result['front_mask']
|
||||
back_mask = result['back_mask']
|
||||
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
||||
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
|
||||
rgba_image = cv2.resize(rgba_image, new_size)
|
||||
|
||||
@@ -9,8 +9,8 @@ class Top(Clothing):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
|
||||
# dict(type='ContourDetection'),
|
||||
dict(type='Segmentation'),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
|
||||
|
||||
Reference in New Issue
Block a user