Files
sora_python/app/service/attribute_recognition/service.py

211 lines
8.8 KiB
Python
Raw Normal View History

2024-03-28 10:30:18 +08:00
from pprint import pprint
from types import SimpleNamespace

import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
from minio import Minio
from skimage import transform

from app.core.config import MINIO_IP, MINIO_PORT, MINIO_SECURE, MINIO_ACCESS, MINIO_SECRET, ATT_TRITON_IP, ATT_TRITON_PORT
from app.service.attribute_recognition.config.category_mapping import category_mapping
2024-03-28 10:30:18 +08:00
def Merge(dict1, dict2):
    """Return a new dict combining *dict1* and *dict2*.

    On key collisions the value from *dict2* wins. Neither input is mutated.
    """
    return {**dict1, **dict2}
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image, landmarks):
        h, w = image.shape[:2]
        size = self.output_size
        if isinstance(size, int):
            # Match the shorter edge to `size`, preserving the aspect ratio.
            new_h, new_w = (size * h / w, size) if h > w else (size, size * w / h)
        else:
            new_h, new_w = size
        new_h, new_w = int(new_h), int(new_w)
        resized = transform.resize(image, (new_h, new_w), mode='constant')
        # h and w are swapped for landmarks: image axes are (row, col) = (y, x),
        # while each landmark is an (x, y) pair.
        scaled_landmarks = landmarks * [new_w / w, new_h / h]
        return resized, scaled_landmarks
class AttributeRecognition:
    """Clothing attribute recognition backed by Triton inference + MinIO storage.

    Images are fetched from MinIO, preprocessed to 1x3x224x224 float32 tensors,
    and run through per-category attribute models hosted on a Triton server.
    """

    def __init__(self):
        # HTTP client for the Triton inference server hosting the models.
        self.httpclient = httpclient.InferenceServerClient(
            url=f"{ATT_TRITON_IP}:{ATT_TRITON_PORT}")
        # Object-storage client used to fetch uploaded images.
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    def get_image(self, url):
        """Download and decode an image stored in MinIO.

        Args:
            url: object key of the form "<bucket>/<object-path>".

        Returns:
            H x W x 3 uint8 ndarray in RGB channel order.
        """
        bucket, _, object_name = url.partition("/")
        response = self.minio_client.get_object(bucket, object_name)
        buf = np.frombuffer(response.data, np.uint8)  # raw bytes -> uint8 buffer
        img = cv2.imdecode(buf, cv2.IMREAD_COLOR)     # decode; OpenCV yields BGR
        # NOTE(review): channels are swapped to RGB here, yet preprocess() calls
        # mmcv.imnormalize(..., to_rgb=True) which swaps them again (back to BGR).
        # Confirm the deployed models really expect that double swap before
        # changing either site.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def preprocess(self, img_url):
        """Fetch an image and convert it into a 1x3x224x224 float32 model input."""
        img = mmcv.imread(self.get_image(img_url))  # pass-through for ndarrays
        img = mmcv.imresize(img, (224, 224))        # scale factors were unused
        img = mmcv.imnormalize(
            img,
            mean=np.array([123.675, 116.28, 103.53]),
            std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True)
        # HWC -> CHW, then add the batch dimension expected by the model.
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)

    def attribute(self, args, request_data):
        """Run attribute recognition for every uploaded image in the request.

        Args:
            args: config object exposing per-category model name lists and
                label-description CSV path lists (e.g. ``top_model_list`` /
                ``top_discription_list``).
            request_data: request object with parallel lists
                ``upload_img_path``, ``upload_img_category`` and ``upload_img_id``.

        Returns:
            dict mapping each upload_img_id to a {task: [labels]} attribute dict
            (empty dict when no category could be determined).
        """
        final_dict = {}
        for img_path, img_category, img_id in zip(
                request_data.upload_img_path,
                request_data.upload_img_category,
                request_data.upload_img_id):
            sample = self.preprocess(img_path)
            # If the mapped category is not one of the direct garment types
            # (e.g. an "all-body" item), run the category model — mainly to
            # tell jumpsuits and dresses apart.
            mapping_category = category_mapping[img_category]
            if mapping_category in ("tops", "bottoms", "skirt", "dress", "outerwear"):
                category_list = [mapping_category]
            else:
                category_list = self.get_attribute(
                    args.category_model, args.category_discription, sample)['category']

            attr_dict = {}
            if category_list:
                category = category_list[0]
                print(category)
                # Select the seed attributes and the (description CSVs, model
                # names) pair for this category. Branching stays lazy so only
                # the matched args.* lists are accessed.
                descriptions, models = [], []
                if category == 'tops':
                    attr_dict = {'Item': "top"}
                    descriptions, models = args.top_discription_list, args.top_model_list
                elif category == 'bottoms':
                    attr_dict = {'Item': "bottom", 'Type': ['Pants']}
                    descriptions, models = args.bottom_discription_list, args.bottom_model_list
                elif category == 'skirt':
                    # Skirts reuse the bottom models but report Item="bottom".
                    attr_dict = {'Type': ['Skirt'], 'Item': 'bottom'}
                    descriptions, models = args.bottom_discription_list, args.bottom_model_list
                elif category == 'dress':
                    attr_dict = {'Item': 'dress'}
                    descriptions, models = args.dress_discription_list, args.dress_model_list
                elif category == 'outerwear':
                    attr_dict = {'Item': 'outer'}
                    descriptions, models = args.outwear_discription_list, args.outwear_model_list
                elif category == 'jumpsuit':
                    attr_dict = {'Item': 'jumpsuit'}
                    descriptions, models = args.jumpsuit_discription_list, args.jumpsuit_model_list
                # Run every per-attribute model and merge its predictions in.
                for description, model in zip(descriptions, models):
                    attr_dict = Merge(
                        attr_dict, self.get_attribute(model, description, sample))
            final_dict[img_id] = attr_dict
        return final_dict

    def get_attribute(self, model_name, description, image):
        """Run one attribute model on a preprocessed image.

        Args:
            model_name: Triton model name to invoke.
            description: path to a CSV with 'taskName' and 'labelName' columns
                describing this model's output labels.
            image: 1x3xHxW float32 ndarray (output of preprocess()).

        Returns:
            {task_name: [predicted labels]} — the list may be empty when no
            score is positive.
        """
        attr_type = pd.read_csv(description)
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.httpclient.infer(model_name=model_name, inputs=inputs)
        # as_numpy already returns an ndarray; the former
        # torch.from_numpy(...).detach().numpy() round-trip was an identity op.
        scores = results.as_numpy("output__0")  # assumed shape (1, num_labels)
        labels = list(attr_type['labelName'])
        task = attr_type['taskName'][0]
        if task == 'category':
            max_score = np.max(scores)
            if max_score <= 0:
                # No confident category at all: return an empty prediction.
                return {'category': []}
            indexs = np.argwhere(scores == max_score)[:, 1]
        elif task[:-2] in ('Print', 'Material'):
            # Multi-label tasks (task names carry a 2-char suffix — presumably
            # a garment-part tag; verify against the CSVs): every positive
            # logit counts as a hit.
            indexs = np.argwhere(scores > 0)[:, 1]
        else:
            # Single-label tasks: keep the argmax, but only if it is positive.
            max_score = np.max(scores)
            if max_score > 0:
                indexs = np.argwhere(scores == max_score)[:, 1]
            else:
                indexs = []
        return {task: [labels[i] for i in indexs]}
if __name__ == '__main__':
    from app.service.attribute_recognition import const_debug

    # Bug fix: attribute() reads request fields as attributes
    # (request_data.upload_img_path), so a plain dict would raise
    # AttributeError — wrap the debug payload in a SimpleNamespace.
    # Also fixes the key typo: "update_img_category" -> "upload_img_category"
    # (attribute() reads upload_img_category).
    request_data = SimpleNamespace(
        upload_img_path=['./test_top1.jpg'],
        upload_img_id=["2"],
        upload_img_category=["TOP/ONE PIECE"])
    service = AttributeRecognition()
    pprint(service.attribute(const_debug, request_data))