Files
2024-08-30 12:36:56 +08:00

211 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
from pprint import pprint

import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
from minio import Minio
from skimage import transform

from app.core.config import MINIO_IP, MINIO_PORT, MINIO_SECURE, MINIO_ACCESS, MINIO_SECRET, ATT_TRITON_IP, ATT_TRITON_PORT
from app.service.attribute_recognition.config.category_mapping import category_mapping
def Merge(dict1, dict2):
    """Return a new dict combining ``dict1`` and ``dict2``.

    Neither input is modified. When a key appears in both, the value from
    ``dict2`` is kept, and the key stays in the position where it first
    appeared in ``dict1``.
    """
    return {key: value for source in (dict1, dict2) for key, value in source.items()}
class Rescale:
    """Resize an image and scale its landmark coordinates to match.

    Args:
        output_size (tuple or int): Target size. A tuple is unpacked as
            ``(height, width)`` and used directly. An int becomes the length
            of the shorter image edge; the longer edge is scaled so the
            aspect ratio is preserved.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image, landmarks):
        old_h, old_w = image.shape[:2]
        target = self.output_size

        if not isinstance(target, int):
            scaled_h, scaled_w = target
        elif old_h > old_w:
            # Portrait: the width is the short edge.
            scaled_h, scaled_w = target * old_h / old_w, target
        else:
            # Landscape or square: the height is the short edge.
            scaled_h, scaled_w = target, target * old_w / old_h

        scaled_h = int(scaled_h)
        scaled_w = int(scaled_w)
        resized = transform.resize(image, (scaled_h, scaled_w), mode='constant')

        # Landmarks are (x, y) pairs. In image arrays x runs along axis 1
        # (width) and y along axis 0 (height), so the factors are ordered (w, h).
        factors = [scaled_w / old_w, scaled_h / old_h]
        return resized, landmarks * factors
class AttributeRecognition:
    """Garment attribute recognition backed by Triton-served classifiers.

    Images are fetched from MinIO and preprocessed into a normalized
    (1, 3, 224, 224) array. The array is sent as FP32 input to one Triton
    model per attribute. Each model has a CSV description file that supplies
    its label names (``labelName``) and its task name (``taskName``).
    """

    # Debug logging for per-image category decisions.
    _logger = logging.getLogger(__name__)

    # Mapped upload categories that are trusted as-is. Anything else (for
    # example an all-body garment) is classified by the category model.
    _DIRECT_CATEGORIES = ("tops", "bottoms", "skirt", "dress", "outerwear")

    # Resize target and per-channel normalization constants for preprocess().
    _INPUT_SIZE = (224, 224)
    _MEAN = np.array([123.675, 116.28, 103.53])
    _STD = np.array([58.395, 57.12, 57.375])

    # Category name -> (factory for the initial attribute dict, prefix of the
    # ``args`` attributes ``<prefix>_discription_list`` / ``<prefix>_model_list``).
    # Factories give each image fresh dicts and lists, so results never share
    # a mutable template. Key order inside each dict is part of the output.
    _CATEGORY_SPECS = {
        'tops': (lambda: {'Item': 'top'}, 'top'),
        'top': (lambda: {'Item': 'top'}, 'top'),
        'bottoms': (lambda: {'Item': 'bottom', 'Type': ['Pants']}, 'bottom'),
        'bottom': (lambda: {'Item': 'bottom', 'Type': ['Pants']}, 'bottom'),
        'skirt': (lambda: {'Type': ['Skirt'], 'Item': 'bottom'}, 'bottom'),
        'dress': (lambda: {'Item': 'dress'}, 'dress'),
        'outwear': (lambda: {'Item': 'outer'}, 'outwear'),
        'outerwear': (lambda: {'Item': 'outer'}, 'outwear'),
        'jumpsuit': (lambda: {'Item': 'jumpsuit'}, 'jumpsuit'),
    }

    def __init__(self):
        self.httpclient = httpclient.InferenceServerClient(url=f"{ATT_TRITON_IP}:{ATT_TRITON_PORT}")
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    def get_image(self, url):
        """Download ``url`` (``"<bucket>/<object name>"``) from MinIO and decode it.

        Returns:
            The decoded image, converted from BGR to RGB channel order.

        Raises:
            IndexError: if ``url`` contains no ``/`` separator.
            ValueError: if the downloaded bytes cannot be decoded as an image.
        """
        parts = url.split("/", 1)
        bucket_name, object_name = parts[0], parts[1]
        response = self.minio_client.get_object(bucket_name, object_name)
        try:
            data = response.data
        finally:
            # MinIO's documented cleanup for get_object: always close the
            # response and release its connection, even if reading fails.
            response.close()
            response.release_conn()
        # Interpret the raw bytes as unsigned 8-bit integers, then decode.
        buf = np.frombuffer(data, np.uint8)
        img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError(f"could not decode image from {url!r}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def preprocess(self, img_url):
        """Fetch ``img_url`` and turn it into a single-image model input batch.

        Returns:
            Array of shape (1, 3, 224, 224): resized, normalized, channels first.
        """
        img = mmcv.imread(self.get_image(img_url))
        img = mmcv.imresize(img, self._INPUT_SIZE)
        # NOTE(review): get_image() already returns RGB, and to_rgb=True asks
        # imnormalize for another BGR->RGB swap, so the model may actually be
        # fed BGR order. Left unchanged because the deployed models may depend
        # on it; confirm against the training pipeline before changing.
        img = mmcv.imnormalize(img, mean=self._MEAN, std=self._STD, to_rgb=True)
        # HWC -> CHW, then add a leading batch axis.
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)

    def attribute(self, args, request_data):
        """Recognize garment attributes for every uploaded image.

        Args:
            args: Config object providing ``category_model``,
                ``category_discription`` and, for each category prefix
                (``top``, ``bottom``, ``dress``, ``outwear``, ``jumpsuit``),
                ``<prefix>_discription_list`` / ``<prefix>_model_list``.
            request_data: Object with index-aligned sequences
                ``upload_img_path``, ``upload_img_category`` and
                ``upload_img_id``.

        Returns:
            dict mapping each image id to its attribute dict. The dict is
            empty when no category was found or the category is unsupported.
        """
        final_dict = {}
        # The category model and its description are the same for every image.
        category_model_path = args.category_model
        category_description = args.category_discription
        for idx, img_path in enumerate(request_data.upload_img_path):
            sample = self.preprocess(img_path)
            # For an all-body upload the mapped category is not specific
            # enough, so the category model decides (mainly jumpsuit vs. dress).
            mapping_category = category_mapping[request_data.upload_img_category[idx]]
            if mapping_category in self._DIRECT_CATEGORIES:
                category_list = [mapping_category]
            else:
                category_list = self.get_attribute(
                    category_model_path, category_description, sample)['category']

            attr_dict = {}
            if category_list:
                category = category_list[0]
                self._logger.debug("image %s classified as %s", img_path, category)
                attr_dict = self._attributes_for_category(args, category, sample)
            final_dict[request_data.upload_img_id[idx]] = attr_dict
        return final_dict

    def _attributes_for_category(self, args, category, sample):
        """Run every attribute model configured for ``category`` on ``sample``.

        Returns an empty dict for categories that have no entry in
        ``_CATEGORY_SPECS``.
        """
        spec = self._CATEGORY_SPECS.get(category)
        if spec is None:
            return {}
        make_initial, prefix = spec
        attr_dict = make_initial()
        descriptions = getattr(args, f"{prefix}_discription_list")
        models = getattr(args, f"{prefix}_model_list")
        for i, attr_description in enumerate(descriptions):
            # Results from later models overwrite keys with the same name.
            attr_dict.update(self.get_attribute(models[i], attr_description, sample))
        return attr_dict

    def get_attribute(self, model_name, description, image):
        """Run one Triton model on ``image`` and decode its scores into labels.

        Args:
            model_name: Name of the Triton model to query.
            description: Path to a CSV whose ``labelName`` column lists one
                label per output column and whose first ``taskName`` value
                names the task.
            image: Preprocessed input batch, sent as ``input__0`` (FP32).

        Returns:
            ``{task_name: [label, ...]}``.

        Raises:
            RuntimeError: if the model response has no ``output__0``.
        """
        attr_type = pd.read_csv(description)
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.httpclient.infer(model_name=model_name, inputs=inputs)
        scores = results.as_numpy("output__0")
        if scores is None:
            raise RuntimeError(f"model {model_name!r} returned no output__0")
        return self._decode_scores(attr_type['taskName'][0], list(attr_type['labelName']), scores)

    @staticmethod
    def _decode_scores(task, labels, scores):
        """Map raw model scores to label names.

        Tasks whose name, minus its last two characters, is ``Print`` or
        ``Material`` are multi-label: every column with a positive score is
        returned. All other tasks, including ``category``, return the
        column(s) holding the maximum score, or nothing if that maximum is
        not positive.

        ``scores`` is only read, never written, so read-only buffers are
        safe. Column indices are taken from axis 1, which assumes a
        (batch, n_labels) array.
        """
        scores = np.asarray(scores)
        if task[:-2] in ('Print', 'Material'):
            columns = np.argwhere(scores > 0)[:, 1]
        else:
            max_score = np.max(scores)
            columns = np.argwhere(scores == max_score)[:, 1] if max_score > 0 else []
        return {task: [labels[col] for col in columns]}
if __name__ == '__main__':
    from types import SimpleNamespace

    from app.service.attribute_recognition import const_debug

    # attribute() reads the request fields as attributes (request_data.upload_img_path,
    # etc.), so a plain dict would fail; the category field is upload_img_category.
    request_data = SimpleNamespace(
        upload_img_path=['./test_top1.jpg'],
        upload_img_id=["2"],
        upload_img_category=["TOP/ONE PIECE"],
    )
    service = AttributeRecognition()
    pprint(service.attribute(const_debug, request_data))