2024-04-16 15:51:03 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# -*- coding: UTF-8 -*-
|
|
|
|
|
import logging
|
|
|
|
|
from pprint import pprint
|
|
|
|
|
import torch
|
|
|
|
|
import cv2
|
|
|
|
|
import mmcv
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pandas as pd
|
|
|
|
|
from minio import Minio
|
|
|
|
|
import tritonclient.http as httpclient
|
|
|
|
|
from app.core.config import *
|
|
|
|
|
from app.schemas.attribute_retrieve import AttributeRecognitionModel
|
2024-06-21 17:13:39 +08:00
|
|
|
from app.service.utils.oss_client import oss_get_image
|
2024-04-16 15:51:03 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AttributeRecognition:
    """Predict garment attributes for a batch of sketches with Triton-served classifiers.

    ``__init__`` downloads and preprocesses every sketch image up front.
    ``get_result`` then runs, for each sketch, every (Triton model, label CSV)
    pair configured on ``const`` for the sketch's category and collects the
    predicted label names per attribute task.
    """

    # Maps a request ``category`` to the prefix of the ``const`` attributes
    # ``<prefix>_description_list`` / ``<prefix>_model_list`` that hold its
    # label-CSV paths and Triton model names. Categories not listed here get
    # an empty attribute dict.
    _CATEGORY_GROUPS = {
        "Tops": "top",
        "Blouse": "top",
        "Trousers": "bottom",
        "Skirt": "bottom",
        "Bottoms": "bottom",
        "Dress": "dress",
        "Outwear": "outwear",
    }

    def __init__(self, const, request_data):
        """Fetch and preprocess all sketches, then open the Triton HTTP client.

        Args:
            const: Configuration object exposing ``<group>_description_list`` and
                ``<group>_model_list`` sequences (see ``_CATEGORY_GROUPS``).
            request_data: Iterable of request items with ``sketch_img_url``,
                ``category`` and ``colony`` attributes.
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        # Created after all images are fetched, so a failed download never
        # opens a client.
        self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}")

    def get_result(self):
        """Run attribute inference on every prepared sketch.

        Each entry of ``self.request_data`` gains an ``'attr_dict'`` key mapping
        task name -> list of predicted labels, and loses its preprocessed
        ``'obj'`` image. Because ``'obj'`` is removed, calling this a second
        time raises ``KeyError``.

        Returns:
            The mutated ``self.request_data`` list.

        Raises:
            IndexError: if a category's model list is shorter than its
                description list.
        """
        for sketch in self.request_data:
            attr_dict = {}
            for model_name, description in self._models_for(sketch['category']):
                present_dict = self.get_attribute(model_name, description, sketch['obj'])
                attr_dict = self.merge(attr_dict, present_dict)
            sketch['attr_dict'] = attr_dict
            del sketch['obj']
        return self.request_data

    def _models_for(self, category):
        """Return the ``(model_name, description_csv)`` pairs configured for *category*.

        Unknown categories yield an empty list. The ``const`` lists are only
        looked up for categories that are actually requested.
        """
        group = self._CATEGORY_GROUPS.get(category)
        if group is None:
            return []
        descriptions = getattr(self.const, f"{group}_description_list")
        models = getattr(self.const, f"{group}_model_list")
        if len(models) < len(descriptions):
            raise IndexError(
                f"{group}_model_list has {len(models)} entries but "
                f"{group}_description_list has {len(descriptions)}"
            )
        return list(zip(models, descriptions))

    def get_attribute(self, model_name, description, image):
        """Run one attribute classifier on *image* and decode the winning label(s).

        Args:
            model_name: Name of the Triton model to query.
            description: Path to a CSV file with a ``labelName`` column. Its file
                stem is used as the task name in the returned dict.
            image: Preprocessed input array, sent as ``input__0`` with FP32
                datatype (see ``preprocess``).

        Returns:
            ``{task: [label, ...]}``: the ``labelName`` entries at every
            position whose score equals the maximum score.
        """
        import os

        attr_type = pd.read_csv(description)
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=inputs)
        # Assumes output__0 is (batch, num_labels) with batch == 1 -- TODO confirm.
        scores = results.as_numpy("output__0")

        colattr = list(attr_type['labelName'])
        task = os.path.splitext(os.path.basename(description))[0]
        # NOTE(review): the maximum is taken over only the first five scores of
        # row 0, but the match below searches the whole ``scores`` array. Confirm
        # that limiting to five labels is intended.
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return {task: [colattr[idx] for idx in indexs]}

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict2*'s entries layered over *dict1*'s."""
        res = {**dict1, **dict2}
        return res

    def get_image(self, url):
        """Download a ``"<bucket>/<object_name>"`` image from OSS and convert it BGR -> RGB."""
        parts = url.split("/", 1)
        bucket, object_name = parts[0], parts[1]
        img = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        # NOTE(review): if oss_get_image can return None on failure, cvtColor
        # raises an opaque cv2.error here -- consider an explicit check.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    @staticmethod
    def preprocess(img):
        """Resize to 224x224, normalize, and return a batched CHW array of shape (1, C, 224, 224)."""
        img = mmcv.imread(img)
        img_scale = (224, 224)
        img = cv2.resize(img, img_scale)
        # NOTE(review): get_image has already converted the image to RGB, and
        # to_rgb=True swaps the channels again, so the model receives BGR data
        # normalized with what look like RGB ImageNet statistics. Verify against
        # the training pipeline before changing, because the fix alters predictions.
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: run attribute recognition on a few stored sketches
    # using the local debug configuration and print the annotated results.
    from app.service.attribute.config import local_debug_const

    # (category, sketch_img_url) for each sample; every sample is "Female".
    samples = [
        ("Dress", "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"),
        ("Blouse", "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"),
    ]

    rq_data = [
        AttributeRecognitionModel(category=category, colony="Female", sketch_img_url=url)
        for category, url in samples
    ]
    server = AttributeRecognition(local_debug_const, rq_data)
    pprint(server.get_result())
|