import logging

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio

from app.core.config import DESIGN_MODEL_URL, SEG_PRODUCT_MODEL_URL
from app.core.config import settings
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image

minio_client = Minio(settings.MINIO_URL,
                     access_key=settings.MINIO_ACCESS,
                     secret_key=settings.MINIO_SECRET,
                     secure=settings.MINIO_SECURE)
logger = logging.getLogger()


class BrandDna:
    """Segment a garment image into regions, classify each region's category,
    and either extract design attributes (brand-DNA mode) or upload the
    cropped region + mask images to object storage.

    Inference is delegated to two Triton HTTP endpoints: a garment
    segmentation model and a category/attribute recognition model.
    """

    def __init__(self, request_item):
        """Build clients and load the category label table.

        :param request_item: request object exposing ``image_url`` (str,
            "<bucket>/<object-path>") and ``is_brand_dna`` (bool).
        """
        self.sketch_bucket = "test"
        self.image_url = request_item.image_url
        self.is_brand_dna = request_item.is_brand_dna
        # Category id -> label-name table (column 'labelName').
        self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
        self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        self.seg_client = httpclient.InferenceServerClient(url=SEG_PRODUCT_MODEL_URL)
        self.const = const

    def get_result(self):
        """Run the full pipeline and return one result dict per garment region.

        Mask label values 1/2/3 are the garment regions of interest
        (presumably outwear / tops / bottoms — TODO confirm against the
        segmentation model's label map); value 0 (background) is skipped.

        :return: list of dicts with keys ``category_female``,
            ``category_male``, ``mask_url``, ``img_url``, ``attribute``.
        """
        mask, image = self.get_seg_mask()
        height, width, channels = image.shape
        # Blank canvases reused (copied) for every region.
        white_img = np.ones((height, width, channels), dtype=image.dtype) * 255
        mask_image = np.zeros((height, width, 3))
        result_dict = []
        for value in np.unique(mask):
            if value in (1, 2, 3):
                result_dict.append(
                    self._build_segment_result(value, mask, image, white_img, mask_image))
        return result_dict

    def _build_segment_result(self, value, mask, image, white_img, mask_image):
        """Classify one mask region and build its result entry.

        Crops the region onto a white background, recognizes its category,
        then either runs attribute recognition (brand-DNA mode) or uploads
        the crop and a red-on-black mask image.
        """
        seg_img = white_img.copy()
        seg_mask_img = mask_image.copy()
        seg_img[mask == value] = image[mask == value]
        seg_mask_img[mask == value] = [0, 0, 255]  # region drawn in red (BGR)
        preprocess_img = self.category_preprocess(seg_img)
        category = self.recognition_category(preprocess_img)
        # Collapse the fine-grained ("female") category into a coarse
        # ("male") category.
        if category in ('Trousers', 'Skirt'):
            male_category = 'Bottoms'
        elif category in ('Blouse', 'Dress'):
            male_category = 'Tops'
        else:
            male_category = 'Outwear'
        attribute = {}
        img_url = ""
        mask_url = ""
        if self.is_brand_dna:
            attribute = self.get_recognition_attribute_result(category, preprocess_img)
        else:
            img_url = self.put_image(seg_img, f"img/{generate_uuid()}")
            mask_url = self.put_image(seg_mask_img, f"mask/{generate_uuid()}")
        return {
            'category_female': category,
            'category_male': male_category,
            'mask_url': mask_url,
            'img_url': img_url,
            'attribute': attribute,
        }

    def get_seg_mask(self):
        """Run garment segmentation on the request image via Triton.

        :return: ``(mask, image)`` — the per-pixel label mask resized back to
            the original image shape, and the original BGR image.
        """
        input_image = self.get_image()
        input_img, ori_shape = self.seg_product_preprocess(input_image)
        transformed_img = input_img.astype(np.float32)
        inputs = [httpclient.InferInput("seg_input__0", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("seg_output__0", binary_data=True)]
        results = self.seg_client.infer(model_name="seg_product", inputs=inputs, outputs=outputs)
        inference_output1 = results.as_numpy("seg_output__0")
        mask = self.product_postprocess(inference_output1, ori_shape)[0]
        return mask, input_image

    def get_image(self):
        """Download the request image from object storage as a BGR cv2 array.

        ``image_url`` is split at the first '/' into bucket and object name.
        BGRA and grayscale inputs are converted to 3-channel BGR.
        """
        image = oss_get_image(oss_client=minio_client,
                              bucket=self.image_url.split('/')[0],
                              object_name=self.image_url[self.image_url.find('/') + 1:],
                              data_type="cv2")
        if len(image.shape) == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image

    def put_image(self, image, object_name):
        """JPEG-encode *image* and upload it; return its "<bucket>/<name>.jpg" URL.

        Best-effort: on failure the error is logged and ``None`` is returned.
        """
        try:
            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
            oss_upload_image(oss_client=minio_client,
                             bucket=self.sketch_bucket,
                             object_name=f"{object_name}.jpg",
                             image_bytes=image_bytes)
            return f"{self.sketch_bucket}/{object_name}.jpg"
        except Exception as e:
            logger.warning(e)
            return None

    @staticmethod
    def seg_product_preprocess(image):
        """Cap each side at 1024 px, normalize, and shape to NCHW.

        :return: ``(preprocessed_img, ori_shape)`` where ``ori_shape`` is the
            original ``(height, width)`` used later to resize the mask back.
        """
        img = image
        ori_shape = img.shape[:2]  # (height, width)
        target_h, target_w = ori_shape
        if ori_shape[0] > 1024:
            target_h = 1024
        if ori_shape[1] > 1024:
            target_w = 1024
        if ori_shape != (target_h, target_w):
            # cv2.resize takes (width, height) — note the argument order.
            img = cv2.resize(img, (target_w, target_h))
        img = my_imnormalize(img,
                             mean=np.array([123.675, 116.28, 103.53]),
                             std=np.array([58.395, 57.12, 57.375]),
                             to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img, ori_shape

    @staticmethod
    def product_postprocess(output, ori_shape):
        """Bilinearly resize the segmentation logits back to the original size.

        :param output: raw model output (numpy, NCHW).
        :param ori_shape: target ``(height, width)``.
        :return: first batch element of the resized logits as a numpy array.
        """
        seg_logit = F.interpolate(torch.tensor(output).float(),
                                  size=ori_shape,
                                  scale_factor=None,
                                  mode='bilinear',
                                  align_corners=False)
        seg_pred = seg_logit.cpu().numpy()
        return seg_pred[0]

    @staticmethod
    def category_preprocess(img):
        """Resize to 224x224, normalize, and shape to NCHW for classification."""
        img = cv2.resize(img, (224, 224))
        img = my_imnormalize(img,
                             mean=np.array([123.675, 116.28, 103.53]),
                             std=np.array([58.395, 57.12, 57.375]),
                             to_rgb=True)
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)

    def recognition_category(self, image):
        """Classify the garment category of a preprocessed image.

        Only the first 5 scores are considered when picking the maximum
        (presumably the model emits more classes than are in use — TODO
        confirm).

        :return: label name from the category CSV's ``labelName`` column.
        """
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs)
        inference_output = torch.from_numpy(results.as_numpy('output__0'))
        scores = inference_output.detach().numpy()
        colattr = list(self.attr_type['labelName'])
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return colattr[indexs[0]]

    def recognition_attribute(self, model_name, description, image):
        """Run one attribute model and return ``{task_name: [labels...]}``.

        The task name is the description CSV's basename without the ``.csv``
        extension. All labels tied at the (top-5) maximum score are returned.
        """
        attr_type = pd.read_csv(description)
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.att_client.infer(model_name=model_name, inputs=inputs)
        inference_output = torch.from_numpy(results.as_numpy("output__0"))
        scores = inference_output.detach().numpy()
        colattr = list(attr_type['labelName'])
        task = description.split('/')[-1][:-4]  # strip directory and ".csv"
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return {task: [colattr[i] for i in indexs]}

    def get_recognition_attribute_result(self, category, input_img):
        """Run every attribute model configured for *category* and merge results.

        :return: merged ``{task: [labels]}`` dict; empty for unknown categories.
        """
        if category == "Blouse":
            descriptions = self.const.top_description_list
            models = self.const.top_model_list
        elif category in ('Trousers', 'Skirt'):
            descriptions = self.const.bottom_description_list
            models = self.const.bottom_model_list
        elif category == 'Dress':
            descriptions = self.const.dress_description_list
            models = self.const.dress_model_list
        elif category == 'Outwear':
            descriptions = self.const.outwear_description_list
            models = self.const.outwear_model_list
        else:
            return {}
        attr_dict = {}
        for attr_description, attr_model_path in zip(descriptions, models):
            present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
            attr_dict = self.merge(attr_dict, present_dict)
        return attr_dict

    @staticmethod
    def merge(dict1, dict2):
        """Shallow-merge two dicts; on key collision *dict2* wins."""
        return {**dict1, **dict2}


if __name__ == '__main__':
    request_item = BrandDnaModel(
        image_url="aida-results/result_00006a48-e315-11ee-b7c8-b48351119060.png",
        is_brand_dna=True
    )
    service = BrandDna(request_item)
    result_url = service.get_result()
    print(result_url)