diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 8447514..786cf03 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -33,14 +33,20 @@ class Segmentation: result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255 result['mask'] = result['front_mask'] + result['back_mask'] else: - # 本地查询seg 缓存是否存在 - _, seg_result = self.load_seg_result(result["image_id"]) - result['seg_result'] = seg_result - # 判断缓存和实际图片size是否相同 - if not _ or result["image"].shape[:2] != seg_result.shape: + # design信号判断 preview 不保存seg缓存 + if "preview_submit" in result.keys() and result['preview_submit'] == "preview": # 推理获得seg 结果 seg_result = get_seg_result(result["image_id"], result['image'])[0] - self.save_seg_result(seg_result, result['image_id']) + else: + # 本地查询seg 缓存是否存在 + _, seg_result = self.load_seg_result(result["image_id"]) + # 判断缓存和实际图片size是否相同 + if not _ or result["image"].shape[:2] != seg_result.shape: + # 推理获得seg 结果 + seg_result = get_seg_result(result["image_id"], result['image'])[0] + self.save_seg_result(seg_result, result['image_id']) + result['seg_result'] = seg_result + # 处理前片后片 temp_front = seg_result == 1.0 result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))