#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""Garment attribute recognition backed by classifier models served by Triton.

Each sketch image is fetched from object storage, resized/normalised into a
single-image batch tensor, and run through every attribute model configured
for the sketch's garment category.
"""
import functools
import logging
from pathlib import Path
from pprint import pprint

import torch
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import tritonclient.http as httpclient

from app.core.config import *
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image

logger = logging.getLogger(__name__)

# (width, height) passed to cv2.resize; every attribute model gets this size.
_INPUT_SIZE = (224, 224)
# Per-channel normalisation statistics passed to mmcv.imnormalize.
_NORM_MEAN = np.array([123.675, 116.28, 103.53])
_NORM_STD = np.array([58.395, 57.12, 57.375])
# Only the first this-many model scores are used to find the maximum score.
# NOTE(review): kept from the original implementation; confirm why the window
# is 5 rather than the model's full output width.
_SCORE_WINDOW = 5

# Garment category -> names of the (description list, model list) attribute
# pair on the ``const`` configuration object for that category.
_CATEGORY_GROUPS = {
    "Tops": ("top_description_list", "top_model_list"),
    "Blouse": ("top_description_list", "top_model_list"),
    "Trousers": ("bottom_description_list", "bottom_model_list"),
    "Skirt": ("bottom_description_list", "bottom_model_list"),
    "Bottoms": ("bottom_description_list", "bottom_model_list"),
    "Dress": ("dress_description_list", "dress_model_list"),
    "Outwear": ("outwear_description_list", "outwear_model_list"),
}


@functools.lru_cache(maxsize=64)
def _load_label_names(description):
    """Return the ``labelName`` column of the attribute CSV at *description*.

    Cached because the same handful of CSVs is consulted once per sketch per
    attribute; edits to a CSV after its first load are not picked up until
    the process restarts.
    """
    return tuple(pd.read_csv(description)['labelName'])


class AttributeRecognition:
    """Recognise garment attributes for a batch of sketches.

    Images are downloaded and preprocessed eagerly in ``__init__``; model
    inference happens in :meth:`get_result`.
    """

    def __init__(self, const, request_data):
        """Fetch and preprocess every sketch image.

        Args:
            const: Configuration object exposing the ``*_description_list`` /
                ``*_model_list`` sequences named in ``_CATEGORY_GROUPS``.
            request_data: Iterable of objects with ``sketch_img_url``,
                ``category`` and ``colony`` attributes.
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")

    def get_result(self):
        """Run attribute recognition on every sketch.

        Each sketch dict gains an ``attr_dict`` (attribute task name -> list of
        predicted label names) and loses its preprocessed ``obj`` tensor.
        Sketches whose category has no configured models get ``{}``. Sketches
        already processed by an earlier call are left unchanged.

        Returns:
            The mutated ``self.request_data`` list.
        """
        for sketch in self.request_data:
            if 'obj' not in sketch:
                continue  # processed by an earlier call; tensor already released
            sketch['attr_dict'] = self._recognize(sketch['category'], sketch['obj'])
            del sketch['obj']
        return self.request_data

    def _recognize(self, category, image):
        """Run every attribute model configured for *category* on *image*.

        Raises:
            ValueError: If the configured model list is shorter than the
                description list for this category.
        """
        group = _CATEGORY_GROUPS.get(category)
        if group is None:
            logger.warning("No attribute models configured for category %r", category)
            return {}
        description_attr, model_attr = group
        descriptions = getattr(self.const, description_attr)
        model_names = getattr(self.const, model_attr)
        if len(model_names) < len(descriptions):
            raise ValueError(
                f"const.{model_attr} has {len(model_names)} entries but "
                f"const.{description_attr} has {len(descriptions)}"
            )
        attr_dict = {}
        for description, model_name in zip(descriptions, model_names):
            attr_dict.update(self.get_attribute(model_name, description, image))
        return attr_dict

    def get_attribute(self, model_name, description, image):
        """Classify one attribute of *image* with a Triton model.

        Args:
            model_name: Name of the Triton model to query.
            description: Path to a CSV whose ``labelName`` column lists the
                model's output classes in index order; its file stem is used
                as the attribute task name.
            image: Single-image batch tensor produced by :meth:`preprocess`.

        Returns:
            ``{task: [label, ...]}`` -- every label whose score equals the
            maximum of the first ``_SCORE_WINDOW`` scores.

        Raises:
            RuntimeError: If the response carries no ``output__0`` tensor.
        """
        label_names = _load_label_names(description)
        infer_input = httpclient.InferInput("input__0", image.shape, datatype="FP32")
        infer_input.set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=[infer_input])
        scores = results.as_numpy("output__0")
        if scores is None:
            raise RuntimeError(f"Model {model_name!r} returned no 'output__0' tensor")
        row = scores[0]  # batch of one image, see preprocess()
        best = np.max(row[:_SCORE_WINDOW])
        winners = np.flatnonzero(row == best)
        task = Path(description).stem
        return {task: [label_names[i] for i in winners]}

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict2*'s entries overriding *dict1*'s."""
        return {**dict1, **dict2}

    def get_image(self, url):
        """Download ``<bucket>/<object_name>`` and convert it from BGR to RGB.

        Assumes ``oss_get_image(..., data_type="cv2")`` yields an OpenCV-style
        BGR array -- TODO confirm against oss_client.

        Raises:
            ValueError: If *url* has no ``/`` separating bucket and object name.
        """
        bucket, sep, object_name = url.partition("/")
        if not sep:
            raise ValueError(f"Expected '<bucket>/<object_name>', got {url!r}")
        img = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img):
        """Resize and normalise *img* into a ``(1, C, 224, 224)`` model input.

        NOTE(review): ``get_image`` already converts BGR -> RGB, and
        ``to_rgb=True`` below swaps the channels again, so the model receives
        BGR order normalised with statistics that look like RGB ImageNet
        means/stds. Left unchanged to match whatever the deployed models were
        trained with -- confirm against the training pipeline.
        """
        img = mmcv.imread(img)
        img = cv2.resize(img, _INPUT_SIZE)
        img = mmcv.imnormalize(img, mean=_NORM_MEAN, std=_NORM_STD, to_rgb=True)
        # HWC -> CHW, then add a leading batch axis; made contiguous so the
        # serialized buffer is laid out exactly as the declared shape.
        return np.ascontiguousarray(np.expand_dims(img.transpose(2, 0, 1), axis=0))


if __name__ == '__main__':
    from app.service.attribute.config import local_debug_const

    data = [
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"
        },
        {
            "category": "Blouse",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"
        },
        {
            "category": "Dress",
            "colony": "Female",
            "sketch_img_url": "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
        }
    ]
    rq_data = [
        AttributeRecognitionModel(category=d['category'], colony=d['colony'], sketch_img_url=d['sketch_img_url'])
        for d in data
    ]
    server = AttributeRecognition(local_debug_const, rq_data)
    pprint(server.get_result())