Files
AiDA_Python/app/service/design_fast/pipeline/segmentation.py
zhouchengrong f6b2e283f9 feat(新功能):
fix(修复bug): RunTime判定修改为0.05秒内触发打印
docs(文档变更):
refactor(重构):
test(增加测试):
2025-01-10 15:04:13 +08:00

86 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
# Module-level logger. getLogger() is called without a name, so this is the
# root logger rather than a logger named after this module.
logger = logging.getLogger()
class Segmentation:
    """Pipeline step that adds front/back segmentation masks to a result dict.

    Calling an instance fills ``result['front_mask']``, ``result['back_mask']``
    and ``result['mask']`` (their sum) with ``uint8`` arrays holding 0 or 255.
    The masks come either from a colour-coded mask image fetched from object
    storage (``result['seg_mask_url']``) or from the segmentation model,
    optionally going through an on-disk ``.npy`` cache keyed by
    ``result['image_id']``.
    """

    # Label values in the model output that select the front and back pieces.
    _FRONT_LABEL = 1.0
    _BACK_LABEL = 2.0

    def __init__(self, minio_client):
        """Keep the object-storage client used to download mask images.

        Args:
            minio_client: Client forwarded as ``oss_client`` to
                ``oss_get_image``.
        """
        self.minio_client = minio_client

    @ClassCallRunTime
    def __call__(self, result):
        """Populate ``result`` with front, back and combined masks.

        Args:
            result: Mutable dict for the current image. Keys read here are
                ``seg_mask_url`` (optional), ``img_shape``, ``preview_submit``
                (optional), ``image_id`` and ``image``.

        Returns:
            The same ``result`` object, updated in place with ``front_mask``,
            ``back_mask`` and ``mask``. When the model path is taken,
            ``seg_result`` is stored as well.
        """
        # Truthiness check: a missing key, an empty string and None all fall
        # through to the model path instead of failing on ``.split``.
        seg_mask_url = result.get("seg_mask_url")
        if seg_mask_url:
            front_mask, back_mask = self._masks_from_url(seg_mask_url, result['img_shape'])
        else:
            seg_result = self._segment(result)
            result['seg_result'] = seg_result
            # Split the label map into the front and back pieces.
            front_mask = self._to_mask(seg_result == self._FRONT_LABEL)
            back_mask = self._to_mask(seg_result == self._BACK_LABEL)

        result['front_mask'] = front_mask
        result['back_mask'] = back_mask
        # The two masks never overlap (a pixel cannot satisfy both
        # conditions), so the uint8 sum cannot exceed 255.
        result['mask'] = front_mask + back_mask
        return result

    def _masks_from_url(self, seg_mask_url, img_shape):
        """Download a colour-coded mask image and return (front, back) masks."""
        # Text before the first '/' is used as the bucket, everything after
        # it as the object name.
        bucket = seg_mask_url.split('/')[0]
        object_name = seg_mask_url[seg_mask_url.find('/') + 1:]
        # NOTE(review): the return value is used without a None check; confirm
        # how oss_get_image reports a failed download.
        seg_mask = oss_get_image(
            oss_client=self.minio_client,
            bucket=bucket,
            object_name=object_name,
            data_type="cv2",
        )
        # cv2.resize takes (width, height); presumably img_shape is
        # (height, width, ...) - TODO confirm. Nearest-neighbour keeps the
        # mask colours pure instead of blending them at the edges.
        seg_mask = cv2.resize(seg_mask, (img_shape[1], img_shape[0]), interpolation=cv2.INTER_NEAREST)
        # Convert to RGB channel order; the image is treated as BGR here.
        image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
        red, green, _blue = cv2.split(image_rgb)
        # Redder pixels form the front mask, greener pixels the back mask.
        return self._to_mask(red > green), self._to_mask(green > red)

    def _segment(self, result):
        """Return the segmentation label map for ``result``, honouring the cache mode."""
        mode = result.get("preview_submit")
        if mode == "preview":
            # Preview: always run the model and do not cache the output.
            return self._infer(result)
        if mode == "submit":
            # Submit: always run the model and refresh the cache.
            return self._infer_and_cache(result)

        # Any other value: normal flow, prefer the local cache.
        cache_hit, cached = self.load_seg_result(result["image_id"])
        # A cached map is only usable when it matches the current image size.
        if cache_hit and result["image"].shape[:2] == cached.shape:
            return cached
        return self._infer_and_cache(result)

    @staticmethod
    def _infer(result):
        """Run the segmentation model and return the first element of its output."""
        return get_seg_result(result["image_id"], result['image'])[0]

    def _infer_and_cache(self, result):
        """Run the segmentation model and store its output in the local cache."""
        seg_result = self._infer(result)
        self.save_seg_result(seg_result, result['image_id'])
        return seg_result

    @staticmethod
    def _to_mask(flags):
        """Turn a boolean array into a uint8 mask with values 0 and 255."""
        return np.asarray(flags, dtype=np.uint8) * 255

    @staticmethod
    def _cache_path(image_id):
        """Build the cache file path for ``image_id``.

        SEG_CACHE_PATH is used as a plain string prefix, so the location only
        lies inside that directory if the constant ends with a path separator.
        """
        return f"{SEG_CACHE_PATH}{image_id}.npy"

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Write ``seg_result`` to the cache as a ``.npy`` file.

        Best effort: any failure is logged and swallowed so that a cache
        problem never aborts the pipeline.

        Args:
            seg_result: Array to store.
            image_id: Identifier used to build the cache file name.
        """
        file_path = Segmentation._cache_path(image_id)
        try:
            np.save(file_path, seg_result)
            logger.info(f"保存成功 {os.path.abspath(file_path)}")
        except Exception as e:
            logger.error(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Load a cached segmentation result.

        Args:
            image_id: Identifier used to build the cache file name.

        Returns:
            ``(True, array)`` when the cache file was loaded, otherwise
            ``(False, None)``.
        """
        file_path = Segmentation._cache_path(image_id)
        logger.info(f"load seg file name is :{file_path}")
        try:
            seg_result = np.load(file_path)
        except FileNotFoundError:
            # A missing cache file is the ordinary cache-miss case; it is
            # deliberately not logged.
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None
        return True, seg_result