Files
AiDA_Python/app/service/design_batch/pipeline/segmentation.py
zcr 18024a2d70
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
2025-12-30 16:49:08 +08:00

86 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
import cv2
import numpy as np
from app.core.config import settings
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
# NOTE: no name is passed, so this is the *root* logger itself; records from
# this module are emitted under the root logger's name rather than __name__.
logger = logging.getLogger(name=None)
class Segmentation:
    """Pipeline step that attaches front/back piece masks to a per-image ``result`` dict.

    Masks come from one of two sources:

    * a pre-rendered colour mask stored in OSS (``result['seg_mask_url']``), in
      which red-dominant pixels are the front piece and green-dominant pixels
      the back piece; or
    * the segmentation model (``get_seg_result``), whose label map uses 1 for
      the front piece and 2 for the back piece. Label maps may be cached on
      local disk as ``<settings.SEG_CACHE_PATH><image_id>.npy``.
    """

    def __init__(self, minio_client):
        # Client passed to ``oss_get_image`` when a stored mask must be fetched.
        self.minio_client = minio_client

    @ClassCallRunTime
    def __call__(self, result):
        """Populate ``front_mask``, ``back_mask`` and ``mask`` in ``result``.

        Args:
            result: Per-image dict. Keys read: ``seg_mask_url`` (optional; a
                non-empty value selects the OSS path), ``img_shape`` (OSS path),
                ``image``, ``image_id`` and ``preview_submit`` (optional; model
                path).

        Returns:
            The same dict, mutated in place. ``front_mask`` and ``back_mask`` are
            uint8 arrays holding 0 or 255; ``mask`` is their sum, which stays
            0/255 because the two masks never overlap. ``seg_result`` is set
            only when the model/cache path is taken.
        """
        seg_mask_url = result.get("seg_mask_url")
        # A missing, empty or None URL means "no stored mask": use the model/cache.
        if seg_mask_url:
            front, back = self._masks_from_oss(seg_mask_url, result['img_shape'])
        else:
            seg_result = self._resolve_seg_result(result)
            result['seg_result'] = seg_result
            # Label 1 = front piece, label 2 = back piece.
            front = self._to_mask(seg_result == 1)
            back = self._to_mask(seg_result == 2)
        result['front_mask'] = front
        result['back_mask'] = back
        result['mask'] = front + back
        return result

    def _masks_from_oss(self, seg_mask_url, img_shape):
        """Fetch a colour-coded mask from OSS and split it into (front, back) masks.

        ``seg_mask_url`` has the form ``"<bucket>/<object name>"``; the object name
        may itself contain further ``/`` separators.
        """
        bucket, _, object_name = seg_mask_url.partition('/')
        seg_mask = oss_get_image(oss_client=self.minio_client, bucket=bucket,
                                 object_name=object_name, data_type="cv2")
        # cv2.resize takes (width, height); img_shape is indexed as (height, width, ...).
        # Nearest-neighbour avoids inventing blended colours at region borders.
        seg_mask = cv2.resize(seg_mask, (img_shape[1], img_shape[0]),
                              interpolation=cv2.INTER_NEAREST)
        # OpenCV images are BGR; convert so the channel names below are literal.
        r, g, _ = cv2.split(cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB))
        # Red-dominant pixels -> front piece, green-dominant pixels -> back piece.
        return self._to_mask(r > g), self._to_mask(g > r)

    def _resolve_seg_result(self, result):
        """Return the model label map for ``result['image']``, honouring ``preview_submit``.

        * ``"preview"``: run the model and do not touch the cache.
        * ``"submit"``: run the model and (over)write the cache.
        * anything else / absent: reuse the cached map when its shape equals the
          image's (height, width); otherwise run the model and refresh the cache.
        """
        mode = result.get("preview_submit")
        image = result['image']
        if mode == "preview":
            return get_seg_result(image)
        if mode == "submit":
            seg_result = get_seg_result(image)
            self.save_seg_result(seg_result, result['image_id'])
            return seg_result

        found, seg_result = self.load_seg_result(result["image_id"])
        # A cached map whose size differs from the current image is stale.
        if found and seg_result.shape == image.shape[:2]:
            return seg_result
        seg_result = get_seg_result(image)
        self.save_seg_result(seg_result, result['image_id'])
        return seg_result

    @staticmethod
    def _to_mask(selection):
        """Convert a boolean array into a uint8 mask holding 0 or 255."""
        return selection.astype(np.uint8) * 255

    @staticmethod
    def _cache_path(image_id):
        """Return the cache file path for ``image_id``.

        ``settings.SEG_CACHE_PATH`` is concatenated verbatim, so it presumably ends
        with a path separator — confirm in the settings.
        """
        return f"{settings.SEG_CACHE_PATH}{image_id}.npy"

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Best-effort write of ``seg_result`` to the local cache.

        Failures are logged and swallowed; caching is an optimisation only.
        """
        file_path = Segmentation._cache_path(image_id)
        try:
            np.save(file_path, seg_result)
        except Exception as e:
            logger.error(f"保存失败: {e}")
        else:
            logger.debug(f"保存成功 {os.path.abspath(file_path)}")

    @staticmethod
    def load_seg_result(image_id):
        """Read a cached label map for ``image_id``.

        Returns:
            ``(True, array)`` on success; ``(False, None)`` when the file does not
            exist (silently) or cannot be read (logged as an error).
        """
        file_path = Segmentation._cache_path(image_id)
        try:
            seg_result = np.load(file_path)
        except FileNotFoundError:
            # A cache miss is the normal case for a new image; stay quiet.
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None
        return True, seg_result