fix  重写所有resize代码,mmcv替换为cv
This commit is contained in:
zhouchengrong
2024-07-12 11:32:43 +08:00
parent cb4b2b4eef
commit db354cc02a
3 changed files with 24 additions and 22 deletions

View File

@@ -1,14 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
+import logging
 from pprint import pprint
-import torch
 import cv2
 import mmcv
 import numpy as np
 import pandas as pd
-from minio import Minio
+import torch
 import tritonclient.http as httpclient
 from app.core.config import *
 from app.schemas.attribute_retrieve import AttributeRecognitionModel
 from app.service.utils.oss_client import oss_get_image
@@ -107,12 +107,8 @@ class AttributeRecognition:
     @staticmethod
     def preprocess(img):
         img = mmcv.imread(img)
-        ori_shape = img.shape[:2]
         img_scale = (224, 224)
-        scale_factor = []
-        img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
-        scale_factor.append(x)
-        scale_factor.append(y)
+        img = cv2.resize(img, img_scale, interpolation=cv2.INTER_LINEAR)
         img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
         preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
         return preprocessed_img

View File

@@ -27,7 +27,10 @@ from app.core.config import *
 def keypoint_preprocess(img_path):
     img = mmcv.imread(img_path)
     img_scale = (256, 256)
-    img, w_scale, h_scale = mmcv.imresize(img, img_scale, return_scale=True)
+    h, w = img.shape[:2]
+    w_scale = img_scale[0] / w
+    h_scale = img_scale[1] / h
+    img = cv2.resize(img, img_scale, interpolation=cv2.INTER_LINEAR)
     img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
     preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
     return preprocessed_img, (w_scale, h_scale)
@@ -80,9 +83,8 @@ def seg_preprocess(img_path):
     img_scale_h = 1024
     # 如果图片size任意一边 大于 1024 则会resize 成1024
     if ori_shape != (img_scale_w, img_scale_h):
-        # TODO 取消代码中所有 关于mmcv的resize
         # mmcv.imresize(img, img_scale_h, img_scale_w)  # 老代码 引以为戒!哈哈哈~ h和w写反了
-        img = cv2.resize(img, (img_scale_h, img_scale_w))
+        img = cv2.resize(img, (img_scale_h, img_scale_w), interpolation=cv2.INTER_LINEAR)
     img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
     preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
     return preprocessed_img, ori_shape
@@ -130,12 +132,12 @@ def key_point_show(image_path, key_point_result=None):
 if __name__ == '__main__':
-    image = cv2.imread("./14162b58-f259-4833-98cb-89b9b496b251.jfif")
+    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
     a = get_keypoint_result(image, "up")
     new_list = []
     print(list)
     for i in a[0]:
         new_list.append((int(i[0]), int(i[1])))
-    key_point_show("./14162b58-f259-4833-98cb-89b9b496b251.jfif", new_list)
+    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
     # a = get_seg_result(1, image)
     print(a)

View File

@@ -1,15 +1,14 @@
 import logging
 import time
+import cv2
 import mmcv
 import numpy as np
 import torch
 import tritonclient.http as httpclient
-import torch.nn.functional as F
-from app.core.config import *
-import cv2
-from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd, upload_face_png_sd
+from app.core.config import *
+from app.service.generate_image.utils.upload_sd_image import upload_stain_png_sd, upload_face_png_sd
 logger = logging.getLogger()
@@ -17,11 +16,14 @@ logger = logging.getLogger()
 def seg_preprocess(img_path):
     img = mmcv.imread(img_path)
     ori_shape = img.shape[:2]
-    img_scale = ori_shape
-    scale_factor = []
-    img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
-    scale_factor.append(x)
-    scale_factor.append(y)
+    img_scale_w, img_scale_h = ori_shape
+    if ori_shape[0] > 1024:
+        img_scale_w = 1024
+    if ori_shape[1] > 1024:
+        img_scale_h = 1024
+    # 如果图片size任意一边 大于 1024 则会resize 成1024
+    if ori_shape != (img_scale_w, img_scale_h):
+        img = cv2.resize(img, (img_scale_h, img_scale_w), interpolation=cv2.INTER_LINEAR)
     img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
     preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
     return preprocessed_img, ori_shape
@@ -105,6 +107,7 @@ def seg_infer_image(image_obj):
     seg_result = seg_postprocess(inference_output1, ori_shape)
     return seg_result
 # def seg_postprocess(output, ori_shape):
 #     seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
 #     seg_logit = F.softmax(seg_logit, dim=1)
@@ -120,6 +123,7 @@ def seg_postprocess(output, ori_shape):
 #     seg_pred = output.cpu().numpy()
     return output[0]
 def remove_background(image):
     image_obj, mask = get_mask(image)
     seg_result = seg_infer_image(image_obj)