From cfa2cd1987151cc84b2e05b956a9cdfa38cf98f8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 4 Sep 2024 15:05:05 +0800 Subject: [PATCH] =?UTF-8?q?feat=20fix=20=20=20design=20mask=20=E7=BA=A2?= =?UTF-8?q?=E7=BB=BF=E5=88=A4=E6=96=AD=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design/items/pipelines/segmentation.py | 17 ++++++----------- app/service/design/items/pipelines/split.py | 4 ++-- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index 2937df6..e7f09ed 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -19,22 +19,17 @@ class Segmentation(object): def __call__(self, result): if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "": seg_mask = oss_get_image(bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2") - seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0])) + seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST) # 转换颜色空间为 RGB(OpenCV 默认是 BGR) image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB) - # 定义红色和绿色的颜色范围 - # 红色范围: 下界 [R-10, G-10, B-10], 上界 [R+10, G+10, B+10] - red_lower = np.array([50, 0, 0], dtype=np.uint8) - red_upper = np.array([255, 50, 50], dtype=np.uint8) - - # 绿色范围: 下界 [R-10, G-10, B-10], 上界 [R+10, G+10, B+10] - green_lower = np.array([0, 50, 0], dtype=np.uint8) - green_upper = np.array([50, 255, 50], dtype=np.uint8) + r, g, b = cv2.split(image_rgb) + red_mask = r > g + green_mask = g > r # 创建红色和绿色掩码 - result['front_mask'] = cv2.inRange(image_rgb, red_lower, red_upper) - result['back_mask'] = cv2.inRange(image_rgb, green_lower, green_upper) + result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255 + result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255 result['mask'] = result['front_mask'] + result['back_mask'] else: # 本地查询seg 缓存是否存在 diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index 5fb568e..3485453 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -55,7 +55,7 @@ class Split(object): mask_pil.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.jpg", image_bytes=image_bytes) + req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes) result['mask_url'] = req.bucket_name + "/" + req.object_name else: rbga_mask = rgb_to_rgba(mask_image, front_mask) @@ -64,7 +64,7 @@ class Split(object): mask_pil.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.jpg", image_bytes=image_bytes) + req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes) result['mask_url'] = req.bucket_name + "/" + req.object_name result['back_image'] = None result["back_image_url"] = None