feat
fix design pipeline 剔除边缘检测任务,直接用分割
This commit is contained in:
@@ -10,6 +10,7 @@ class Bottom(Clothing):
|
|||||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||||
dict(type='KeypointDetection'),
|
dict(type='KeypointDetection'),
|
||||||
dict(type='ContourDetection'),
|
dict(type='ContourDetection'),
|
||||||
|
# dict(type='Segmentation'),
|
||||||
dict(type='Painting', painting_flag=True),
|
dict(type='Painting', painting_flag=True),
|
||||||
dict(type='PrintPainting', print_flag=True),
|
dict(type='PrintPainting', print_flag=True),
|
||||||
dict(type='Scaling'),
|
dict(type='Scaling'),
|
||||||
|
|||||||
@@ -9,18 +9,22 @@ from ...utils.design_ensemble import get_seg_result
|
|||||||
|
|
||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class Segmentation(object):
|
class Segmentation(object):
|
||||||
def __init__(self, device='cpu', show=False, debug=None):
|
|
||||||
self.show = show
|
|
||||||
self.device = device
|
|
||||||
self.debug = debug
|
|
||||||
|
|
||||||
# @ClassCallRunTime
|
# @ClassCallRunTime
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
|
# 本地查询seg 缓存是否存在
|
||||||
_, seg_result = self.load_seg_result(result["image_id"])
|
_, seg_result = self.load_seg_result(result["image_id"])
|
||||||
result['seg_result'] = seg_result
|
result['seg_result'] = seg_result
|
||||||
if not _:
|
if not _:
|
||||||
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
|
# 推理获得seg 结果
|
||||||
self.save_seg_result(result['seg_result'][0], result['image_id'])
|
seg_result = get_seg_result(result["image_id"], result['image'])[0]
|
||||||
|
self.save_seg_result(seg_result, result['image_id'])
|
||||||
|
# 处理前片后片
|
||||||
|
temp_front = seg_result == 1.0
|
||||||
|
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
|
||||||
|
temp_back = seg_result == 2.0
|
||||||
|
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
|
||||||
|
result['mask'] = result['front_mask'] + result['back_mask']
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -22,29 +22,10 @@ class Split(object):
|
|||||||
# KNet
|
# KNet
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
try:
|
try:
|
||||||
if 'mask' not in result.keys():
|
|
||||||
raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
|
|
||||||
if 'seg_result' not in result.keys(): # 没过seg模型
|
|
||||||
result['front_mask'] = result['mask'].copy()
|
|
||||||
result['back_mask'] = np.zeros_like(result['mask'])
|
|
||||||
else:
|
|
||||||
temp_front = result['seg_result'] == 1.0
|
|
||||||
result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
|
|
||||||
temp_back = result['seg_result'] == 2.0
|
|
||||||
result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
|
|
||||||
|
|
||||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||||
if len(result['front_mask'].shape) > 2:
|
front_mask = result['front_mask']
|
||||||
front_mask = result['front_mask'][0]
|
back_mask = result['back_mask']
|
||||||
else:
|
|
||||||
front_mask = result['front_mask']
|
|
||||||
|
|
||||||
if len(result['back_mask'].shape) > 2:
|
|
||||||
back_mask = result['back_mask'][0]
|
|
||||||
else:
|
|
||||||
back_mask = result['back_mask']
|
|
||||||
|
|
||||||
# rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], front_mask + back_mask)
|
|
||||||
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
||||||
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
|
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
|
||||||
rgba_image = cv2.resize(rgba_image, new_size)
|
rgba_image = cv2.resize(rgba_image, new_size)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ class Top(Clothing):
|
|||||||
pipeline = [
|
pipeline = [
|
||||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||||
dict(type='KeypointDetection'),
|
dict(type='KeypointDetection'),
|
||||||
dict(type='ContourDetection'),
|
# dict(type='ContourDetection'),
|
||||||
dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
|
dict(type='Segmentation'),
|
||||||
dict(type='Painting', painting_flag=True),
|
dict(type='Painting', painting_flag=True),
|
||||||
dict(type='PrintPainting', print_flag=True),
|
dict(type='PrintPainting', print_flag=True),
|
||||||
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
|
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
|
||||||
|
|||||||
Reference in New Issue
Block a user