#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""Sketch attribute recognition backed by Triton-served classification models.

Each sketch image is fetched from object storage, preprocessed into a
``(1, C, 224, 224)`` array and run through every attribute model registered
for its garment category; the winning label names are collected per task.
"""
from functools import lru_cache
from pprint import pprint

import cv2
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
from minio import Minio

from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image

minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)


@lru_cache(maxsize=128)
def _load_labels(description):
    """Return the ``labelName`` column of the label CSV at *description*.

    Cached because the same label files are consulted for every sketch of a
    category; edits to a CSV are therefore only picked up after a restart.
    A tuple is returned so the cached value cannot be mutated by callers.
    """
    return tuple(pd.read_csv(description)['labelName'])


class AttributeRecognition:
    """Predict garment attributes for a batch of sketches.

    ``const`` must expose parallel ``<group>_description_list`` (label CSV
    paths) and ``<group>_model_list`` (Triton model names) attributes for the
    groups ``top``, ``bottom``, ``dress`` and ``outwear``.
    """

    # Sketch category -> attribute-group prefix on ``const``. Categories not
    # listed here get an empty attribute dict.
    _CATEGORY_GROUPS = {
        'Tops': 'top',
        'Blouse': 'top',
        'Trousers': 'bottom',
        'Skirt': 'bottom',
        'Bottoms': 'bottom',
        'Dress': 'dress',
        'Outwear': 'outwear',
    }

    def __init__(self, const, request_data):
        """Fetch and preprocess every sketch up front.

        Args:
            const: attribute configuration (see the class docstring).
            request_data: iterable of objects with ``sketch_img_url``,
                ``category`` and ``colony`` attributes
                (e.g. ``AttributeRecognitionModel``).
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")

    def get_result(self):
        """Run attribute recognition on every sketch and return the results.

        Each entry of ``self.request_data`` gains an ``attr_dict`` mapping
        task name -> list of predicted label names, and its preprocessed
        image (``obj``) is dropped so the result is plain data. Entries that
        were already processed are skipped, so repeated calls are safe.

        Returns:
            ``self.request_data``: list of dicts with keys ``category``,
            ``colony``, ``sketch_img_url`` and ``attr_dict``.
        """
        for sketch in self.request_data:
            if 'obj' not in sketch:
                # Already processed by an earlier call; keep its attr_dict.
                continue
            sketch['attr_dict'] = self._recognize(sketch['category'], sketch['obj'])
            # Drop the image only after success so a failed call can be retried.
            del sketch['obj']
        return self.request_data

    def _recognize(self, category, image):
        """Merge the predictions of every attribute model configured for *category*."""
        group = self._CATEGORY_GROUPS.get(category)
        if group is None:
            return {}
        descriptions = getattr(self.const, f'{group}_description_list')
        model_names = getattr(self.const, f'{group}_model_list')
        attr_dict = {}
        for i, description in enumerate(descriptions):
            # Index model_names rather than zip() so a model list shorter than
            # the description list fails loudly instead of silently skipping
            # attributes.
            present_dict = self.get_attribute(model_names[i], description, image)
            attr_dict = self.merge(attr_dict, present_dict)
        return attr_dict

    def get_attribute(self, model_name, description, image):
        """Classify *image* with one Triton model and decode the winning labels.

        Args:
            model_name: name of the Triton model to query.
            description: path to the label CSV (must have a ``labelName``
                column). Its file name minus the last four characters
                (presumably ``.csv``) is used as the task name.
            image: preprocessed input, as produced by :meth:`preprocess`.

        Returns:
            ``{task: [label, ...]}``; more than one label only on score ties.
        """
        labels = _load_labels(description)
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=inputs)
        scores = results.as_numpy("output__0")
        task = description.split('/')[-1][:-4]
        # NOTE(review): the maximum is taken over the first five scores only,
        # while the tie search below scans every score. If a model has more
        # than five labels, labels past index 4 can never win on their own;
        # confirm this cap is intended.
        max_score = np.max(scores[0][:5])
        winners = np.argwhere(scores == max_score)[:, 1]
        return {task: [labels[i] for i in winners]}

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict2*'s entries layered over *dict1*'s."""
        return {**dict1, **dict2}

    @staticmethod
    def get_image(url):
        """Fetch ``<bucket>/<object>`` from object storage as an RGB image.

        Raises:
            ValueError: if *url* has no ``/`` separating bucket and object
                name, or the storage helper returns no image.
        """
        bucket, sep, object_name = url.partition("/")
        if not sep:
            raise ValueError(f"expected '<bucket>/<object>' image path, got {url!r}")
        img = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="cv2")
        if img is None:
            # NOTE(review): assumes oss_get_image signals a failed fetch/decode
            # with None; without this check cv2.cvtColor fails opaquely.
            raise ValueError(f"could not load image {url!r}")
        # Assumes oss_get_image returns OpenCV's BGR channel order.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img):
        """Resize, normalise and batch an image into a ``(1, C, 224, 224)`` float32 array."""
        img = cv2.resize(img, (224, 224))
        # NOTE(review): get_image already converts BGR -> RGB. If
        # my_imnormalize's ``to_rgb=True`` performs another BGR -> RGB swap
        # (as mmcv's imnormalize does), the channels end up in BGR order here.
        # Confirm this matches the models' training preprocessing before
        # changing either side.
        img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        # HWC -> CHW plus a leading batch axis. The Triton input is declared
        # "FP32", so force a contiguous float32 buffer (a plain copy when
        # my_imnormalize already returns float32).
        return np.ascontiguousarray(np.expand_dims(img.transpose(2, 0, 1), axis=0), dtype=np.float32)


if __name__ == '__main__':
    data = [
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"
        },
        {
            "category": "Blouse",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
        }
    ]
    from app.service.attribute.config import local_debug_const

    rq_data = [AttributeRecognitionModel(category=d['category'], colony=d['colony'], sketch_img_url=d['sketch_img_url']) for d in data]
    server = AttributeRecognition(local_debug_const, rq_data)
    pprint(server.get_result())