#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""Sketch attribute recognition backed by Triton-served classification models.

Each sketch image is fetched from object storage, preprocessed into a
normalized ``(1, C, 224, 224)`` tensor, and run through every attribute model
configured for its garment category.
"""
import os
from pprint import pprint

import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
from minio import Minio

from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.new_oss_client import oss_get_image

minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET,
                     secure=settings.MINIO_SECURE)


class AttributeRecognition:
    """Predict garment attributes for a batch of sketches.

    Images are downloaded and preprocessed eagerly in ``__init__``; inference
    happens in ``get_result``. The instance owns a Triton HTTP client, so call
    ``close()`` (or use it as a context manager) when done.
    """

    # Maps a request category to the prefix of the matching ``const`` lists
    # (``<prefix>_description_list`` / ``<prefix>_model_list``). Categories not
    # listed here produce an empty attribute dict.
    _CATEGORY_GROUPS = {
        "Tops": "top",
        "Blouse": "top",
        "Trousers": "bottom",
        "Skirt": "bottom",
        "Bottoms": "bottom",
        "Dress": "dress",
        "Outwear": "outwear",
    }

    def __init__(self, const, request_data):
        """
        :param const: configuration object exposing ``top_``/``bottom_``/``dress_``/
            ``outwear_`` ``description_list`` and ``model_list`` sequences.
        :param request_data: iterable of items with ``sketch_img_url``, ``category``
            and ``colony`` attributes (e.g. ``AttributeRecognitionModel``).
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")

    def close(self):
        """Close the underlying Triton HTTP client."""
        self.triton_client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _models_for(self, category):
        """Return ``(description_paths, model_names)`` for *category*; empty lists if unsupported."""
        group = self._CATEGORY_GROUPS.get(category)
        if group is None:
            return [], []
        return (getattr(self.const, f"{group}_description_list"),
                getattr(self.const, f"{group}_model_list"))

    def get_result(self):
        """Run every configured attribute model on each sketch.

        :return: ``self.request_data``; each entry gains an ``'attr_dict'`` mapping
            attribute name -> list of predicted labels, and its ``'obj'`` image
            tensor is removed.
        """
        for sketch in self.request_data:
            # Entries handled by an earlier call no longer carry the image;
            # skip them instead of failing with KeyError.
            if 'obj' not in sketch:
                continue
            image = sketch['obj']
            descriptions, model_names = self._models_for(sketch['category'])
            attr_dict = {}
            for i, description in enumerate(descriptions):
                # Index (rather than zip) so a model list shorter than the
                # description list still fails loudly instead of dropping attributes.
                attr_dict.update(self.get_attribute(model_names[i], description, image))
            sketch['attr_dict'] = attr_dict
            del sketch['obj']
        return self.request_data

    def get_attribute(self, model_name, description, image):
        """Run one attribute model on *image* and return its top-scoring labels.

        :param model_name: Triton model name to query.
        :param description: path to a CSV whose ``labelName`` column maps output
            indices to label strings; the file stem becomes the attribute name.
        :param image: preprocessed FP32 input tensor (see ``preprocess``).
        :return: ``{attribute_name: [label, ...]}`` with every label whose score
            equals the maximum score.
        """
        label_names = list(pd.read_csv(description)['labelName'])
        task = os.path.splitext(os.path.basename(description))[0]

        infer_input = httpclient.InferInput("input__0", image.shape, datatype="FP32")
        infer_input.set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=[infer_input])
        # as_numpy already returns an ndarray. A torch round-trip adds nothing and
        # torch.from_numpy warns when the (binary-response) buffer is read-only.
        scores = results.as_numpy("output__0")

        # NOTE(review): the maximum is taken over the first five scores only,
        # while the match below scans all scores — confirm the [:5] is intended.
        max_score = np.max(scores[0][:5])
        indices = np.argwhere(scores == max_score)[:, 1]
        return {task: [label_names[i] for i in indices]}

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict2*'s entries overriding *dict1*'s."""
        return {**dict1, **dict2}

    @staticmethod
    def get_image(url):
        """Fetch ``"<bucket>/<object_name>"`` from object storage as an RGB array.

        :raises ValueError: if *url* has no object part, or the image could not be loaded.
        """
        bucket, sep, object_name = url.partition("/")
        if not sep or not object_name:
            raise ValueError(f"Expected '<bucket>/<object_name>', got {url!r}")
        img = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="cv2")
        if img is None:
            raise ValueError(f"Could not load image at {url!r}")
        # Assumes data_type="cv2" yields OpenCV's BGR channel order — TODO confirm.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img, to_rgb=False):
        """Resize and normalize an image into a ``(1, C, 224, 224)`` model input.

        :param img: HxWxC image array or a path readable by ``mmcv.imread``.
        :param to_rgb: swap BGR->RGB before normalizing. Defaults to False because
            ``get_image`` already returns RGB. Pass True for BGR input, e.g. a file
            path, which ``mmcv.imread`` loads as BGR.
        """
        img = mmcv.imread(img)
        img = cv2.resize(img, (224, 224))
        # Mean/std are the usual ImageNet statistics in RGB order. Previously
        # to_rgb=True was hard-coded, which swapped the already-RGB image back to
        # BGR before normalizing with these RGB statistics.
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]),
                               std=np.array([58.395, 57.12, 57.375]), to_rgb=to_rgb)
        # HWC -> CHW, then add a leading batch axis.
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)


if __name__ == '__main__':
    data = [
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"
        },
        {
            "category": "Blouse",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
        }
    ]
    from app.service.attribute.config import local_debug_const

    rq_data = [AttributeRecognitionModel(category=d['category'], colony=d['colony'],
                                         sketch_img_url=d['sketch_img_url'])
               for d in data]
    with AttributeRecognition(local_debug_const, rq_data) as server:
        pprint(server.get_result())