feat generate 升级 attribute retrieve 迁移
This commit is contained in:
170
app/service/attribute/service_att_recognition.py
Normal file
170
app/service/attribute/service_att_recognition.py
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
import logging
|
||||
from pprint import pprint
|
||||
import torch
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from minio import Minio
|
||||
import tritonclient.http as httpclient
|
||||
from app.core.config import *
|
||||
from app.schemas.attribute_retrieve import AttributeRecognitionModel
|
||||
|
||||
|
||||
class AttributeRecognition:
|
||||
def __init__(self, const, request_data):
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
logging.info("实例化完成")
|
||||
self.request_data = []
|
||||
for i, sketch in enumerate(request_data):
|
||||
self.request_data.append(
|
||||
{
|
||||
'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
|
||||
'category': sketch.category,
|
||||
'colony': sketch.colony,
|
||||
'sketch_img_url': sketch.sketch_img_url,
|
||||
}
|
||||
)
|
||||
self.const = const
|
||||
self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}")
|
||||
|
||||
def __del__(self):
|
||||
self.triton_client.close()
|
||||
|
||||
def get_result(self):
|
||||
for sketch in self.request_data:
|
||||
if sketch['category'] == "Tops" or sketch['category'] == "Blouse":
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.top_description_list)):
|
||||
attr_description = self.const.top_description_list[i]
|
||||
attr_model_path = self.const.top_model_list[i]
|
||||
present_dict = self.get_attribute(attr_model_path, attr_description, sketch['obj'])
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
|
||||
elif sketch['category'] == 'Trousers' or sketch['category'] == "Skirt" or sketch['category'] == "Bottoms":
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.bottom_description_list)):
|
||||
attr_description = self.const.bottom_description_list[i]
|
||||
attr_model_path = self.const.bottom_model_list[i]
|
||||
present_dict = self.get_attribute(attr_model_path, attr_description, sketch['obj'])
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
|
||||
elif sketch['category'] == 'Dress':
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.dress_description_list)):
|
||||
attr_description = self.const.dress_description_list[i]
|
||||
attr_model_path = self.const.dress_model_list[i]
|
||||
present_dict = self.get_attribute(attr_model_path, attr_description, sketch['obj'])
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
|
||||
elif sketch['category'] == 'Outwear':
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.outwear_description_list)):
|
||||
attr_description = self.const.outwear_description_list[i]
|
||||
attr_model_path = self.const.outwear_model_list[i]
|
||||
present_dict = self.get_attribute(attr_model_path, attr_description, sketch['obj'])
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
|
||||
else:
|
||||
attr_dict = {}
|
||||
sketch['attr_dict'] = attr_dict
|
||||
del sketch['obj']
|
||||
return self.request_data
|
||||
|
||||
def get_attribute(self, model_name, description, image):
|
||||
attr_type = pd.read_csv(description)
|
||||
inputs = [
|
||||
httpclient.InferInput("input__0", image.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||
results = self.triton_client.infer(model_name=model_name, inputs=inputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f"output__0"))
|
||||
scores = inference_output.detach().numpy()
|
||||
colattr = list(attr_type['labelName'])
|
||||
task = description.split('/')[-1][:-4]
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
attr = {
|
||||
task: []
|
||||
}
|
||||
for i in range(len(indexs)):
|
||||
atr = colattr[indexs[i]]
|
||||
attr[task].append(atr)
|
||||
return attr
|
||||
|
||||
@staticmethod
|
||||
def merge(dict1, dict2):
|
||||
res = {**dict1, **dict2}
|
||||
return res
|
||||
|
||||
def get_image(self, url):
|
||||
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
@staticmethod
|
||||
def preprocess(img):
|
||||
img = mmcv.imread(img)
|
||||
ori_shape = img.shape[:2]
|
||||
img_scale = (224, 224)
|
||||
scale_factor = []
|
||||
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
|
||||
scale_factor.append(x)
|
||||
scale_factor.append(y)
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = [
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
|
||||
},
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"
|
||||
},
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"
|
||||
},
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"
|
||||
},
|
||||
{
|
||||
"category": "Blouse",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"
|
||||
},
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"
|
||||
},
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"
|
||||
},
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||
}
|
||||
]
|
||||
from app.service.attribute.config import local_debug_const
|
||||
|
||||
rq_data = [AttributeRecognitionModel(category=d['category'], colony=d['colony'], sketch_img_url=d['sketch_img_url']) for d in data]
|
||||
server = AttributeRecognition(local_debug_const, rq_data)
|
||||
pprint(server.get_result())
|
||||
Reference in New Issue
Block a user