# Source: AiDA_Python/app/service/attribute/service_att_recognition.py
# (168 lines, 6.9 KiB, Python -- header recovered from a web-viewer paste)

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from pprint import pprint
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.new_oss_client import oss_get_image
# Shared MinIO client used by AttributeRecognition.get_image to fetch sketch
# images from object storage; credentials and TLS flag come from app settings.
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
class AttributeRecognition:
    """Recognise garment attributes for a batch of sketch images.

    Each sketch image is fetched from object storage, preprocessed into an
    NCHW float tensor, then run through the category-specific attribute
    models on a Triton inference server. ``get_result`` returns the request
    payload enriched with an ``attr_dict`` per sketch.
    """

    def __init__(self, const, request_data):
        """Download and preprocess every sketch image up front.

        :param const: config object exposing per-category
            ``*_description_list`` / ``*_model_list`` pairs.
        :param request_data: iterable of objects with ``category``,
            ``colony`` and ``sketch_img_url`` attributes.
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")

    def get_result(self):
        """Run attribute recognition for every sketch.

        :return: the request payload list; each entry gains ``attr_dict``
            and loses the preprocessed image tensor (``obj``).
        """
        # Map each supported category to its (description csv list, model
        # name list) pair. Unknown categories fall back to empty lists and
        # therefore yield an empty attr_dict, matching the old else-branch.
        category_specs = {
            "Tops": (self.const.top_description_list, self.const.top_model_list),
            "Blouse": (self.const.top_description_list, self.const.top_model_list),
            "Trousers": (self.const.bottom_description_list, self.const.bottom_model_list),
            "Skirt": (self.const.bottom_description_list, self.const.bottom_model_list),
            "Bottoms": (self.const.bottom_description_list, self.const.bottom_model_list),
            "Dress": (self.const.dress_description_list, self.const.dress_model_list),
            "Outwear": (self.const.outwear_description_list, self.const.outwear_model_list),
        }
        for sketch in self.request_data:
            descriptions, models = category_specs.get(sketch['category'], ([], []))
            attr_dict = {}
            for description, model in zip(descriptions, models):
                present_dict = self.get_attribute(model, description, sketch['obj'])
                attr_dict = self.merge(attr_dict, present_dict)
            sketch['attr_dict'] = attr_dict
            del sketch['obj']  # large tensor, not serialisable in the response
        return self.request_data

    def get_attribute(self, model_name, description, image):
        """Infer one attribute group for *image* via Triton.

        :param model_name: Triton model to query.
        :param description: path to a csv with a ``labelName`` column; the
            file name without its extension becomes the attribute-group key.
        :param image: preprocessed NCHW float32 array.
        :return: ``{task: [label, ...]}`` listing every label whose score
            equals the maximum score.
        """
        attr_type = pd.read_csv(description)
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=inputs)
        # as_numpy already returns an ndarray; the old torch round-trip
        # (from_numpy(...).detach().numpy()) was a no-op and is dropped.
        scores = results.as_numpy("output__0")
        colattr = list(attr_type['labelName'])
        task = description.split('/')[-1][:-4]  # strip the ".csv" suffix
        # NOTE(review): the maximum is taken over the first 5 scores only,
        # but the argwhere lookup scans the whole row -- confirm intentional.
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return {task: [colattr[i] for i in indexs]}

    @staticmethod
    def merge(dict1, dict2):
        """Shallow-merge two dicts; keys from *dict2* win on conflict."""
        return {**dict1, **dict2}

    @staticmethod
    def get_image(url):
        """Fetch *url* (``"bucket/object"``) from object storage.

        :return: the decoded image converted from cv2's BGR order to RGB.
        """
        bucket, object_name = url.split("/", 1)
        img = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="cv2")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img):
        """Resize to 224x224, ImageNet-normalise, and add a batch dim.

        :param img: HxWxC image array (RGB, from ``get_image``).
        :return: float array of shape (1, C, 224, 224).
        """
        img = mmcv.imread(img)
        img = cv2.resize(img, (224, 224))
        # NOTE(review): get_image already converts to RGB, yet to_rgb=True
        # swaps channels again here -- confirm which order the model expects.
        img = mmcv.imnormalize(
            img,
            mean=np.array([123.675, 116.28, 103.53]),
            std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True,
        )
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)
if __name__ == '__main__':
    from app.service.attribute.config import local_debug_const

    # Local smoke test: one Blouse sketch among a batch of Dress sketches,
    # all belonging to the "Female" colony.
    sketches = [
        ("Dress", "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"),
        ("Blouse", "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"),
    ]
    rq_data = [
        AttributeRecognitionModel(category=category, colony="Female", sketch_img_url=url)
        for category, url in sketches
    ]
    server = AttributeRecognition(local_debug_const, rq_data)
    pprint(server.get_result())