fix: rewrite all resize code, replacing mmcv.imresize with cv2.resize
This commit is contained in:
zhouchengrong
2024-07-12 11:32:43 +08:00
parent cb4b2b4eef
commit db354cc02a
3 changed files with 24 additions and 22 deletions

View File

@@ -1,14 +1,14 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import logging
from pprint import pprint
import torch
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import torch
import tritonclient.http as httpclient
from app.core.config import *
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image
@@ -107,12 +107,8 @@ class AttributeRecognition:
@staticmethod
def preprocess(img):
    """Load an image and convert it to a normalized NCHW blob for the
    attribute-recognition model.

    Args:
        img: Image path / bytes / ndarray accepted by ``mmcv.imread``.

    Returns:
        np.ndarray of shape (1, 3, 224, 224): float32, RGB,
        mean/std-normalized, batch dimension prepended.
    """
    img = mmcv.imread(img)
    img_scale = (224, 224)
    # cv2.resize takes dsize as (width, height); the target is square,
    # so the order is immaterial here.
    img = cv2.resize(img, img_scale, interpolation=cv2.INTER_LINEAR)
    # ImageNet mean/std in BGR order (mmcv.imread yields BGR);
    # to_rgb=True flips the channels after normalization.
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    # HWC -> CHW, then add the batch axis.
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img

View File

@@ -27,7 +27,10 @@ from app.core.config import *
def keypoint_preprocess(img_path):
    """Prepare an image for the keypoint model.

    Args:
        img_path: Image path / bytes / ndarray accepted by ``mmcv.imread``.

    Returns:
        tuple: ``(blob, (w_scale, h_scale))`` where ``blob`` has shape
        (1, 3, 256, 256) and the scale factors map original-image
        coordinates to resized-image coordinates.
    """
    img = mmcv.imread(img_path)
    img_scale = (256, 256)
    # Compute the scale factors ourselves (mmcv.imresize used to return
    # them); callers use these to map keypoints back to the original image.
    h, w = img.shape[:2]
    w_scale = img_scale[0] / w
    h_scale = img_scale[1] / h
    img = cv2.resize(img, img_scale, interpolation=cv2.INTER_LINEAR)
    # ImageNet mean/std (BGR order); to_rgb=True converts BGR -> RGB.
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    # HWC -> CHW plus batch axis.
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, (w_scale, h_scale)
@@ -80,9 +83,8 @@ def seg_preprocess(img_path):
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# TODO 取消代码中所有 关于mmcv的resize
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = cv2.resize(img, (img_scale_h, img_scale_w), interpolation=cv2.INTER_LINEAR)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
@@ -130,12 +132,12 @@ def key_point_show(image_path, key_point_result=None):
if __name__ == '__main__':
    # Smoke test: run keypoint inference on a sample image and visualize
    # the predicted points on top of it.
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    # Round keypoint coordinates to integer pixel positions for drawing.
    new_list = [(int(i[0]), int(i[1])) for i in a[0]]
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)

View File

@@ -1,15 +1,14 @@
import logging
import time
import cv2
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
import torch.nn.functional as F
from app.core.config import *
import cv2
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd, upload_face_png_sd
from app.core.config import *
from app.service.generate_image.utils.upload_sd_image import upload_stain_png_sd, upload_face_png_sd
logger = logging.getLogger()
@@ -17,11 +16,14 @@ logger = logging.getLogger()
def seg_preprocess(img_path):
    """Prepare an image for the segmentation model.

    Any side longer than 1024 px is clamped to 1024 before inference;
    images already within bounds are passed through at native size.

    Args:
        img_path: Image path / bytes / ndarray accepted by ``mmcv.imread``.

    Returns:
        tuple: ``(blob, ori_shape)`` where ``blob`` has shape
        (1, 3, H, W) and ``ori_shape`` is the original (height, width),
        used later to resize the prediction back.
    """
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]  # (height, width)
    # Clamp each side to at most 1024 px.
    target_h = min(ori_shape[0], 1024)
    target_w = min(ori_shape[1], 1024)
    # Resize only if some side was actually clamped.
    # NOTE: cv2.resize expects dsize as (width, height).
    if ori_shape != (target_h, target_w):
        img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    # ImageNet mean/std (BGR order); to_rgb=True converts BGR -> RGB.
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    # HWC -> CHW plus batch axis.
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, ori_shape
@@ -105,6 +107,7 @@ def seg_infer_image(image_obj):
seg_result = seg_postprocess(inference_output1, ori_shape)
return seg_result
# def seg_postprocess(output, ori_shape):
# seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
# seg_logit = F.softmax(seg_logit, dim=1)
@@ -120,6 +123,7 @@ def seg_postprocess(output, ori_shape):
# seg_pred = output.cpu().numpy()
return output[0]
def remove_background(image):
image_obj, mask = get_mask(image)
seg_result = seg_infer_image(image_obj)