Files
AiDA_Python/app/service/design_fast/pipeline/segmentation.py
zhouchengrong 60bf85bf88 feat 代码整理
fix
2024-09-25 11:40:11 +08:00

71 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.design.utils.design_ensemble import get_seg_result
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
class Segmentation:
def __init__(self, minio_client):
self.minio_client = minio_client
def __call__(self, result):
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
red_mask = r > g
green_mask = g > r
# 创建红色和绿色掩码
result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
result['mask'] = result['front_mask'] + result['back_mask']
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
result['seg_result'] = seg_result
if not _:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0]
self.save_seg_result(seg_result, result['image_id'])
# 处理前片后片
temp_front = seg_result == 1.0
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
temp_back = seg_result == 2.0
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
result['mask'] = result['front_mask'] + result['back_mask']
return result
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.info("保存成功", os.path.abspath(file_path))
except Exception as e:
logger.error(f"保存失败: {e}")
@staticmethod
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
logger.warning("文件不存在")
return False, None
except Exception as e:
logger.error(f"加载失败: {e}")
return False, None