Files
AiDA_Python/app/service/attribute/service_att_recognition.py
zhouchengrong 2df1518a99 feat
fix  minio and s3
2024-06-21 17:13:39 +08:00

169 lines
7.0 KiB
Python

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import functools
import logging
import os
from pprint import pprint

import torch
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import tritonclient.http as httpclient

from app.core.config import *
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image
class AttributeRecognition:
    """Predict garment attributes for a batch of sketch images.

    Images are fetched from object storage and preprocessed eagerly in
    ``__init__``; ``get_result`` then queries every attribute model configured
    for each sketch's category on a Triton inference server and attaches the
    predicted labels to the sketch record.
    """

    # Request category -> prefix of the ``const`` attributes
    # ``<prefix>_description_list`` / ``<prefix>_model_list`` used for it.
    # Categories not listed here get an empty ``attr_dict``.
    _CATEGORY_GROUPS = {
        "Tops": "top",
        "Blouse": "top",
        "Trousers": "bottom",
        "Skirt": "bottom",
        "Bottoms": "bottom",
        "Dress": "dress",
        "Outwear": "outwear",
    }

    def __init__(self, const, request_data):
        """Fetch and preprocess every sketch image up front.

        :param const: configuration object exposing, for each group in
            ``top``/``bottom``/``dress``/``outwear``, a
            ``<group>_description_list`` (paths of label CSVs with a
            ``labelName`` column) and a parallel ``<group>_model_list``
            (Triton model names).
        :param request_data: iterable of sketch requests exposing
            ``category``, ``colony`` and ``sketch_img_url`` attributes.
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}")

    def get_result(self):
        """Run the category's attribute models on every sketch.

        Each record gains an ``attr_dict`` mapping task name -> predicted
        labels (empty for unknown categories) and loses its preprocessed
        ``obj`` image. Records already processed by an earlier call are left
        as they are, so calling this again returns the same results.

        :return: the list of sketch records (``self.request_data``).
        """
        for sketch in self.request_data:
            if 'obj' not in sketch:
                continue  # already processed by a previous call
            image = sketch['obj']
            attr_dict = {}
            for description, model_name in self._attribute_specs(sketch['category']):
                # Later tasks overwrite earlier ones on a key clash (same as merge()).
                attr_dict.update(self.get_attribute(model_name, description, image))
            sketch['attr_dict'] = attr_dict
            # Drop the image only after inference succeeded so a failed call
            # can be retried.
            del sketch['obj']
        return self.request_data

    def _attribute_specs(self, category):
        """Return ``[(label_csv_path, model_name), ...]`` for *category*.

        :raises ValueError: if the configured model list is shorter than the
            description list.
        """
        group = self._CATEGORY_GROUPS.get(category)
        if group is None:
            return []
        descriptions = getattr(self.const, f"{group}_description_list")
        models = getattr(self.const, f"{group}_model_list")
        if len(models) < len(descriptions):
            raise ValueError(
                f"const.{group}_model_list has fewer entries than "
                f"const.{group}_description_list"
            )
        return list(zip(descriptions, models))

    def get_attribute(self, model_name, description, image):
        """Classify one attribute of *image* with a single Triton model.

        :param model_name: Triton model to query.
        :param description: path of the label CSV; its file name without
            extension becomes the task key of the result.
        :param image: preprocessed input batch as produced by ``preprocess``.
        :return: ``{task: [label, ...]}`` holding the best-scoring label(s).
        """
        labels = self._load_labels(description)
        scores = self._infer(model_name, image)
        task = os.path.splitext(os.path.basename(description))[0]
        return {task: self._select_labels(scores, labels)}

    def _infer(self, model_name, image):
        """Send *image* as FP32 ``input__0`` to *model_name*; return ``output__0``."""
        infer_input = httpclient.InferInput("input__0", image.shape, datatype="FP32")
        infer_input.set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=[infer_input])
        return results.as_numpy("output__0")

    @staticmethod
    @functools.lru_cache(maxsize=128)
    def _load_labels(description):
        """Return the ``labelName`` column of the CSV at *description* as a tuple.

        Cached process-wide so each label file is parsed once instead of once
        per sketch; assumes the label CSVs do not change while the process runs.
        """
        return tuple(pd.read_csv(description)['labelName'])

    @staticmethod
    def _select_labels(scores, labels, window=5):
        """Return the labels whose score ties for the max of the first *window* scores.

        :param scores: 2-D score array; only row 0 is inspected (``preprocess``
            builds a batch holding a single image).
        :param labels: label names indexed like the score columns.
        :param window: number of leading score columns considered.
        :return: list of matching label names, in column order.
        """
        # NOTE(review): why only the first 5 columns may win is not visible
        # here -- confirm against the label CSVs / models.
        candidates = np.asarray(scores)[0][:window]
        best = np.max(candidates)
        # Search the same window the maximum came from, so columns outside it
        # cannot be reported just because they happen to equal that value.
        return [labels[i] for i in np.flatnonzero(candidates == best)]

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict2*'s entries overriding *dict1*'s."""
        return {**dict1, **dict2}

    def get_image(self, url):
        """Download ``<bucket>/<object name>`` from object storage as an RGB array.

        :raises ValueError: if *url* has no ``/`` separator or no object name.
        """
        bucket, sep, object_name = url.partition("/")
        if not sep or not object_name:
            raise ValueError(f"expected '<bucket>/<object name>', got {url!r}")
        img = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        # Presumably BGR like cv2.imdecode output -- confirm in oss_client.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img):
        """Resize and normalise an image into a single-item channels-first batch.

        :param img: image array (anything ``mmcv.imread`` accepts).
        :return: for a 3-channel input, an array of shape ``(1, 3, 224, 224)``.
        """
        img = mmcv.imread(img)
        img = mmcv.imresize(img, (224, 224))
        # NOTE(review): get_image already converts BGR->RGB, and to_rgb=True
        # swaps the channels again, so the model sees BGR order while the
        # mean/std look like the usual RGB-ordered ImageNet constants. Confirm
        # against the training pipeline before changing either side.
        img = mmcv.imnormalize(
            img,
            mean=np.array([123.675, 116.28, 103.53]),
            std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True,
        )
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)
if __name__ == '__main__':
    # Local smoke test: recognise attributes for a few stored sketches and
    # pretty-print the per-sketch results.
    from app.service.attribute.config import local_debug_const

    samples = [
        ("Dress", "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"),
        ("Blouse", "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"),
    ]
    requests = [
        AttributeRecognitionModel(category=category, colony="Female", sketch_img_url=url)
        for category, url in samples
    ]
    recognizer = AttributeRecognition(local_debug_const, requests)
    pprint(recognizer.get_result())