#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""Clothing-attribute recognition for sketch images via Triton inference models."""
from pprint import pprint

import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient

from app.core.config import *
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image


class AttributeRecognition:
    """Run the configured per-category attribute models on a batch of sketches.

    Every sketch image is fetched from OSS and preprocessed at construction
    time; ``get_result`` then queries the Triton models listed on ``const``
    for each sketch's category.
    """

    # Sketch category -> prefix of the ``const`` attributes
    # ``<prefix>_description_list`` / ``<prefix>_model_list`` used for it.
    # Categories not listed here get an empty attribute dict.
    _CATEGORY_GROUPS = {
        'Tops': 'top',
        'Blouse': 'top',
        'Trousers': 'bottom',
        'Skirt': 'bottom',
        'Bottoms': 'bottom',
        'Dress': 'dress',
        'Outwear': 'outwear',
    }

    def __init__(self, const, request_data):
        """Fetch and preprocess every sketch image and create the Triton client.

        Args:
            const: Object exposing ``<group>_description_list`` and
                ``<group>_model_list`` for each group in ``_CATEGORY_GROUPS``.
            request_data: Iterable of items with ``sketch_img_url``,
                ``category`` and ``colony`` attributes (e.g.
                ``AttributeRecognitionModel``).
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}")

    def get_result(self):
        """Recognize attributes for every sketch.

        Each entry of ``self.request_data`` gains an ``'attr_dict'`` key
        (task name -> list of predicted labels) and its preprocessed image
        (``'obj'``) is dropped to free memory. Entries already processed by an
        earlier call are left unchanged, so calling this more than once is safe.

        Returns:
            ``self.request_data`` -- the same list, updated in place.
        """
        for sketch in self.request_data:
            if 'obj' not in sketch:
                # Processed by a previous call; its image was already released.
                continue
            attr_dict = {}
            for model_name, description in self._attribute_specs(sketch['category']):
                present_dict = self.get_attribute(model_name, description, sketch['obj'])
                attr_dict = self.merge(attr_dict, present_dict)
            sketch['attr_dict'] = attr_dict
            del sketch['obj']
        return self.request_data

    def _attribute_specs(self, category):
        """Return the ``(model_name, description_csv)`` pairs configured for *category*."""
        group = self._CATEGORY_GROUPS.get(category)
        if group is None:
            return []
        descriptions = getattr(self.const, f'{group}_description_list')
        models = getattr(self.const, f'{group}_model_list')
        if len(models) < len(descriptions):
            raise ValueError(
                f"const.{group}_model_list has {len(models)} entries but "
                f"const.{group}_description_list has {len(descriptions)}"
            )
        # Pairing is positional; any extra models beyond the descriptions are ignored.
        return list(zip(models, descriptions))

    def get_attribute(self, model_name, description, image):
        """Run one Triton attribute model on *image* and decode its prediction.

        Args:
            model_name: Name of the Triton model to query.
            description: Path to a CSV with a ``labelName`` column; the label in
                row ``i`` names output column ``i``. The file name minus its last
                four characters (e.g. ``.csv``) is used as the task name.
            image: Preprocessed float32 array, as returned by ``preprocess``.

        Returns:
            ``{task: [label, ...]}`` -- the label(s) whose score equals the
            highest of the first five scores in the first output row.

        Raises:
            RuntimeError: if the model response has no ``output__0`` tensor.
        """
        labels = list(pd.read_csv(description)['labelName'])
        infer_input = httpclient.InferInput("input__0", image.shape, datatype="FP32")
        infer_input.set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=[infer_input])
        scores = results.as_numpy("output__0")
        if scores is None:
            raise RuntimeError(f"model {model_name!r} returned no 'output__0' tensor")
        task = description.split('/')[-1][:-4]
        # NOTE(review): only the first five scores compete for the maximum, yet
        # any column of ``scores`` holding exactly that value is reported.
        # Confirm the ``[:5]`` window is intentional for every model.
        best_score = np.max(scores[0][:5])
        hit_columns = np.argwhere(scores == best_score)[:, 1]
        return {task: [labels[column] for column in hit_columns]}

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict2*'s entries layered over *dict1*'s."""
        res = {**dict1, **dict2}
        return res

    def get_image(self, url):
        """Download ``<bucket>/<object_name>`` from OSS and return it in RGB order.

        Raises:
            ValueError: if *url* has no ``/`` separating bucket from object
                name, or if the download yields no image.
        """
        bucket, sep, object_name = url.partition("/")
        if not sep:
            raise ValueError(f"expected '<bucket>/<object_name>', got {url!r}")
        img = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        if img is None:
            raise ValueError(f"could not load image from {url!r}")
        # Assumes oss_get_image(data_type="cv2") returns a cv2-style BGR array.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    @staticmethod
    def preprocess(img):
        """Resize to 224x224, normalize, and return a ``(1, C, 224, 224)`` array."""
        img = mmcv.imread(img)
        img_scale = (224, 224)
        img = cv2.resize(img, img_scale, interpolation=cv2.INTER_LINEAR)
        # NOTE(review): get_image already converted the image to RGB, and
        # to_rgb=True swaps the channels once more, so the network receives BGR
        # channel order normalized with these statistics (which look like the
        # usual RGB-ordered ImageNet values). Confirm this matches the models'
        # training pipeline before changing it.
        img = mmcv.imnormalize(img,
                               mean=np.array([123.675, 116.28, 103.53]),
                               std=np.array([58.395, 57.12, 57.375]),
                               to_rgb=True)
        # HWC -> CHW, then add a leading batch axis of size 1.
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img


if __name__ == '__main__':
    # Local smoke test against the debug configuration.
    data = [
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"
        },
        {
            "category": "Blouse",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
        }
    ]
    from app.service.attribute.config import local_debug_const

    rq_data = [
        AttributeRecognitionModel(category=d['category'], colony=d['colony'],
                                  sketch_img_url=d['sketch_img_url'])
        for d in data
    ]
    server = AttributeRecognition(local_debug_const, rq_data)
    pprint(server.get_result())