fix  重写所有resize代码,mmcv替换为cv
This commit is contained in:
zhouchengrong
2024-07-12 11:32:43 +08:00
parent cb4b2b4eef
commit db354cc02a
3 changed files with 24 additions and 22 deletions

View File

@@ -1,15 +1,14 @@
import logging
import time
import cv2
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
import torch.nn.functional as F
from app.core.config import *
import cv2
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd, upload_face_png_sd
from app.core.config import *
from app.service.generate_image.utils.upload_sd_image import upload_stain_png_sd, upload_face_png_sd
logger = logging.getLogger()
@@ -17,11 +16,14 @@ logger = logging.getLogger()
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
ori_shape = img.shape[:2]
img_scale = ori_shape
scale_factor = []
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
scale_factor.append(x)
scale_factor.append(y)
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
img_scale_w = 1024
if ori_shape[1] > 1024:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
img = cv2.resize(img, (img_scale_h, img_scale_w), interpolation=cv2.INTER_LINEAR)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
@@ -105,6 +107,7 @@ def seg_infer_image(image_obj):
seg_result = seg_postprocess(inference_output1, ori_shape)
return seg_result
# def seg_postprocess(output, ori_shape):
# seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
# seg_logit = F.softmax(seg_logit, dim=1)
@@ -120,6 +123,7 @@ def seg_postprocess(output, ori_shape):
# seg_pred = output.cpu().numpy()
return output[0]
def remove_background(image):
image_obj, mask = get_mask(image)
seg_result = seg_infer_image(image_obj)