Files
AiDA_Python/app/service/attribute/service_att_recognition.py
2024-04-16 17:15:00 +08:00

168 lines
6.9 KiB
Python

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import logging
from pprint import pprint
import torch
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import tritonclient.http as httpclient
from app.core.config import *
from app.schemas.attribute_retrieve import AttributeRecognitionModel
class AttributeRecognition:
    """Recognise garment attributes for sketches stored in MinIO via Triton models.

    The constructor downloads and preprocesses every sketch up front.
    :meth:`get_result` then runs, for each sketch, every attribute model
    configured for its category and attaches the merged result as
    ``attr_dict``.
    """

    # Category name -> prefix of the ``<prefix>_description_list`` /
    # ``<prefix>_model_list`` attribute pair on ``const``. Categories not listed
    # here get an empty ``attr_dict``.
    _CATEGORY_GROUPS = {
        "Tops": "top",
        "Blouse": "top",
        "Trousers": "bottom",
        "Skirt": "bottom",
        "Bottoms": "bottom",
        "Dress": "dress",
        "Outwear": "outwear",
    }

    def __init__(self, const, request_data):
        """Download and preprocess every requested sketch.

        Args:
            const: Configuration object exposing parallel
                ``<group>_description_list`` (label CSV paths) and
                ``<group>_model_list`` (Triton model names) attributes for the
                groups ``top``, ``bottom``, ``dress`` and ``outwear``.
            request_data: Iterable of request items with ``sketch_img_url``
                (``"<bucket>/<object>"``), ``category`` and ``colony``.
        """
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        # Label CSVs are shared by every sketch of a category; read each only once.
        self._label_cache = {}
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}")
        # "Instantiation complete" - logged only once construction has actually finished.
        logging.info("实例化完成")

    def get_result(self):
        """Run attribute recognition on every loaded sketch.

        Each entry of ``self.request_data`` gains an ``attr_dict`` key mapping
        attribute task name -> list of predicted labels. Its preprocessed
        ``obj`` array is removed so the returned payload stays small. Calling
        this again is safe: entries already processed keep their previous
        ``attr_dict`` instead of raising ``KeyError``.

        Returns:
            ``self.request_data`` (the same list object, updated in place).
        """
        for sketch in self.request_data:
            if 'obj' not in sketch:
                # Processed by an earlier call; the input array is already gone.
                continue
            sketch['attr_dict'] = self._recognize(sketch['category'], sketch['obj'])
            del sketch['obj']
        return self.request_data

    def _recognize(self, category, image):
        """Run every attribute model configured for *category* on *image* and merge the results."""
        group = self._CATEGORY_GROUPS.get(category)
        if group is None:
            return {}
        descriptions = getattr(self.const, f"{group}_description_list")
        models = getattr(self.const, f"{group}_model_list")
        attr_dict = {}
        # The two lists are parallel; indexing (rather than zip) keeps a shorter
        # model list failing loudly instead of silently skipping attributes.
        for i, description in enumerate(descriptions):
            attr_dict = self.merge(attr_dict, self.get_attribute(models[i], description, image))
        return attr_dict

    def get_attribute(self, model_name, description, image):
        """Run one Triton attribute model on a preprocessed image.

        Args:
            model_name: Name of the model on the Triton server.
            description: Path to a CSV whose ``labelName`` column lists the
                model's labels in output-column order. Its file name without
                extension becomes the task key.
            image: FP32 input array, as produced by :meth:`preprocess`.

        Returns:
            ``{task: [label, ...]}`` with every label selected by
            :meth:`_select_labels`.
        """
        labels = self._load_labels(description)
        infer_input = httpclient.InferInput("input__0", image.shape, datatype="FP32")
        infer_input.set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=[infer_input])
        scores = results.as_numpy("output__0")
        return {self._task_name(description): self._select_labels(scores, labels)}

    def _load_labels(self, description):
        """Return the ``labelName`` column of the CSV at *description*, cached per path."""
        cache = getattr(self, '_label_cache', None)
        if cache is None:
            # Instance created without __init__ (e.g. via __new__); start a cache lazily.
            cache = self._label_cache = {}
        labels = cache.get(description)
        if labels is None:
            labels = list(pd.read_csv(description)['labelName'])
            cache[description] = labels
        return labels

    @staticmethod
    def _task_name(description):
        """Return the file name of *description* without its extension (e.g. ``a/b/sleeve.csv`` -> ``sleeve``)."""
        file_name = description.split('/')[-1]
        stem, dot, _ = file_name.rpartition('.')
        return stem if dot else file_name

    @staticmethod
    def _select_labels(scores, labels, window=5):
        """Return the labels whose score equals the best of the first *window* scores of row 0.

        NOTE(review): the maximum is taken only over ``scores[0][:window]``,
        but matches are searched across every row and column of ``scores``.
        A label beyond index ``window - 1`` is therefore reported only if its
        score exactly ties that maximum. Confirm this cap is intentional for
        models with more than ``window`` labels.
        """
        best = np.max(scores[0][:window])
        hit_columns = np.argwhere(scores == best)[:, 1]
        return [labels[col] for col in hit_columns]

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict2*'s entries overriding *dict1*'s."""
        res = {**dict1, **dict2}
        return res

    def get_image(self, url):
        """Download ``"<bucket>/<object>"`` from MinIO and decode it as an RGB image.

        Raises:
            ValueError: if *url* has no ``/`` separator, or the downloaded
                bytes cannot be decoded as an image.
        """
        bucket, sep, object_name = url.partition("/")
        if not sep:
            raise ValueError(f"expected '<bucket>/<object>' but got {url!r}")
        response = self.minio_client.get_object(bucket, object_name)
        try:
            payload = response.data
        finally:
            # The response holds a pooled HTTP connection until it is released.
            response.close()
            response.release_conn()
        buffer = np.frombuffer(payload, np.uint8)  # raw bytes as 8-bit unsigned ints
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)  # decode to a BGR image
        if img is None:
            raise ValueError(f"could not decode image at {url!r}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img):
        """Resize to 224x224, normalise, and lay out as a batch of one ``(1, C, 224, 224)`` array.

        *img* may be an HWC array or anything ``mmcv.imread`` accepts.

        NOTE(review): when fed from :meth:`get_image`, *img* is already RGB.
        Per the mmcv docs, ``to_rgb=True`` applies a BGR->RGB conversion on
        top of that, which swaps the channels back. The model therefore sees
        BGR order normalised with RGB-ordered mean/std. Confirm what the
        deployed models expect before changing either conversion.
        """
        img = mmcv.imread(img)
        resized = mmcv.imresize(img, (224, 224))
        normalized = mmcv.imnormalize(resized, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        # HWC -> CHW, then add the leading batch axis.
        return np.expand_dims(normalized.transpose(2, 0, 1), axis=0)
if __name__ == '__main__':
    # Local smoke run: recognise attributes for a few sketches stored in MinIO
    # using the debug configuration, then pretty-print the results.
    from app.service.attribute.config import local_debug_const

    # (category, colony, sketch_img_url) for each sample sketch.
    sample_rows = [
        ("Dress", "Female", "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"),
        ("Dress", "Female", "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"),
        ("Dress", "Female", "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"),
        ("Dress", "Female", "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"),
        ("Blouse", "Female", "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"),
        ("Dress", "Female", "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"),
        ("Dress", "Female", "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"),
        ("Dress", "Female", "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"),
    ]
    sketch_requests = [
        AttributeRecognitionModel(category=category, colony=colony, sketch_img_url=url)
        for category, colony, url in sample_rows
    ]
    recognizer = AttributeRecognition(local_debug_const, sketch_requests)
    pprint(recognizer.get_result())