diff --git a/app/api/api_brand_dna.py b/app/api/api_brand_dna.py
new file mode 100644
index 0000000..6b19416
--- /dev/null
+++ b/app/api/api_brand_dna.py
@@ -0,0 +1,38 @@
+import json
+import logging
+
+from fastapi import APIRouter, HTTPException
+
+from app.schemas.brand_dna import BrandDnaModel
+from app.schemas.response_template import ResponseModel
+from app.service.brand_dna.service import BrandDna
+
+router = APIRouter()
+logger = logging.getLogger()
+
+
+# TODO(review): function name looks copy-pasted from the sketch endpoint; the
+# route actually performs product segmentation / brand-DNA extraction.
+@router.post("/seg_product")
+def image2sketch(request_item: BrandDnaModel):
+    """
+    Create a request body with the following parameters:
+    - **image_url**: url of the image to process ("<bucket>/<object>")
+    - **is_brand_dna**: whether to also extract garment attributes
+
+    Example:
+    {
+        "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
+        "is_brand_dna": False
+    }
+    """
+    try:
+        logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
+        service = BrandDna(request_item)
+        result_url = service.get_result()
+    except Exception as e:
+        logger.warning(f"brand dna Run Exception @@@@@@:{e}")
+        # NOTE(review): 404 for any failure is questionable (500 would be more
+        # accurate) but kept to avoid changing the API contract.
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    return ResponseModel(data=result_url)
diff --git a/app/schemas/brand_dna.py b/app/schemas/brand_dna.py
new file mode 100644
index 0000000..c5ae2ab
--- /dev/null
+++ b/app/schemas/brand_dna.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class BrandDnaModel(BaseModel):
+    image_url: str
+    is_brand_dna: bool
diff --git a/app/service/brand_dna/service.py b/app/service/brand_dna/service.py
new file mode 100644
index 0000000..012e682
--- /dev/null
+++ b/app/service/brand_dna/service.py
@@ -0,0 +1,229 @@
+import logging
+
+import cv2
+import mmcv
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+import tritonclient.http as httpclient
+from minio import Minio
+
+from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
+from app.schemas.brand_dna import BrandDnaModel
+from app.service.attribute.config import local_debug_const
+from app.service.utils.generate_uuid import generate_uuid
+from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
+
+logger = logging.getLogger()
+
+minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+
+class BrandDna:
+    """Segment a product image into garment parts and classify / describe each part."""
+
+    def __init__(self, request_item):
+        """request_item: BrandDnaModel with image_url ("<bucket>/<object>") and is_brand_dna flag."""
+        self.sketch_bucket = "test"
+        self.image_url = request_item.image_url
+        self.is_brand_dna = request_item.is_brand_dna
+        # NOTE(review): confirm CATEGORY_PATH is exported from app.core.config; the
+        # previous hard-coded local Windows path broke every other machine.
+        self.attr_type = pd.read_csv(CATEGORY_PATH)
+        self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
+        # TODO(review): move this hard-coded triton endpoint into app.core.config.
+        self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
+        # TODO(review): local_debug_const looks like a debug configuration; switch
+        # to the production const before release.
+        self.const = local_debug_const
+
+    @staticmethod
+    def _to_male_category(category):
+        """Map the fine-grained (female) category to the coarse (male) one."""
+        if category in ('Trousers', 'Skirt'):
+            return 'Bottoms'
+        if category in ('Blouse', 'Dress'):
+            return 'Tops'
+        return 'Outwear'
+
+    def get_result(self):
+        """
+        Segment the image into garment parts; for each part either run
+        attribute extraction (is_brand_dna=True) or upload crop/mask images.
+        Returns a list of per-part result dicts.
+        """
+        mask, image = self.get_seg_mask()
+        height, width, channels = image.shape
+        white_img = np.ones((height, width, channels), dtype=image.dtype) * 255
+        blank_mask = np.zeros((height, width, 3))
+        result_dict = []
+
+        # Class ids per the original per-branch names: 1=outwear, 2=tops, 3=bottoms.
+        for value in np.unique(mask):
+            if value not in (1, 2, 3):
+                continue
+            part_img = white_img.copy()
+            part_mask_img = blank_mask.copy()
+            part_img[mask == value] = image[mask == value]
+            part_mask_img[mask == value] = [0, 0, 255]
+
+            preprocess_img = self.category_preprocess(part_img)
+            category = self.recognition_category(preprocess_img)
+
+            attribute = {}
+            img_url = ""
+            mask_url = ""
+            if self.is_brand_dna:
+                attribute = self.get_recognition_attribute_result(category, preprocess_img)
+            else:
+                img_url = self.put_image(part_img, f"img/{generate_uuid()}")
+                mask_url = self.put_image(part_mask_img, f"mask/{generate_uuid()}")
+
+            result_dict.append(
+                {
+                    'category_female': category,
+                    'category_male': self._to_male_category(category),
+                    'mask_url': mask_url,
+                    'img_url': img_url,
+                    'attribute': attribute
+                }
+            )
+        return result_dict
+
+    def get_seg_mask(self):
+        """Run the product segmentation model; return (mask, original image)."""
+        input_image = self.get_image()
+        input_img, ori_shape = self.seg_product_preprocess(input_image)
+        transformed_img = input_img.astype(np.float32)
+
+        inputs = [httpclient.InferInput("seg_input__0", transformed_img.shape, datatype="FP32")]
+        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
+        outputs = [httpclient.InferRequestedOutput("seg_output__0", binary_data=True)]
+        results = self.seg_client.infer(model_name="seg_product", inputs=inputs, outputs=outputs)
+        inference_output1 = results.as_numpy("seg_output__0")
+        # NOTE(review): this indexes [0] into the post-processed output; confirm the
+        # model already emits an integer class map, otherwise an argmax is missing.
+        mask = self.product_postprocess(inference_output1, ori_shape)[0]
+        return mask, input_image
+
+    def get_image(self):
+        """Download image_url ("<bucket>/<object>") from minio as a 3-channel BGR cv2 image."""
+        bucket, _, object_name = self.image_url.partition('/')
+        image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="cv2")
+        # Normalise alpha / grayscale inputs to 3-channel BGR.
+        if len(image.shape) == 3 and image.shape[2] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
+        elif len(image.shape) == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+        return image
+
+    def put_image(self, image, object_name):
+        """JPEG-encode *image* and upload it; return its "<bucket>/<name>.jpg" url.
+
+        Best-effort: returns None (and logs a warning) when the upload fails.
+        """
+        try:
+            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
+            oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes)
+            return f"{self.sketch_bucket}/{object_name}.jpg"
+        except Exception as e:
+            logger.warning(e)
+            return None
+
+    @staticmethod
+    def seg_product_preprocess(image):
+        """Cap each side at 1024px, normalise and NCHW-pack the image.
+
+        Returns (preprocessed array, original (h, w) shape).
+        """
+        img = mmcv.imread(image)
+        ori_shape = img.shape[:2]
+        target_h, target_w = ori_shape
+        if ori_shape[0] > 1024:
+            target_h = 1024
+        if ori_shape[1] > 1024:
+            target_w = 1024
+        if ori_shape != (target_h, target_w):
+            # cv2.resize takes (width, height).
+            img = cv2.resize(img, (target_w, target_h))
+        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+        return preprocessed_img, ori_shape
+
+    @staticmethod
+    def product_postprocess(output, ori_shape):
+        """Bilinearly resize the raw model output back to the original image shape."""
+        seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
+        return seg_logit.cpu().numpy()[0]
+
+    @staticmethod
+    def category_preprocess(img):
+        """Resize to 224x224, normalise and NCHW-pack for the category/attribute models."""
+        img = mmcv.imread(img)
+        img = cv2.resize(img, (224, 224))
+        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+        return np.expand_dims(img.transpose(2, 0, 1), axis=0)
+
+    def recognition_category(self, image):
+        """Classify the garment crop; return the top label name from the category csv."""
+        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
+        inputs[0].set_data_from_numpy(image, binary_data=True)
+        results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs)
+        scores = torch.from_numpy(results.as_numpy('output__0')).detach().numpy()
+        colattr = list(self.attr_type['labelName'])
+        # NOTE(review): only the first 5 scores are searched for the max; confirm
+        # that is intentional and not a leftover truncation.
+        maxsc = np.max(scores[0][:5])
+        indexs = np.argwhere(scores == maxsc)[:, 1]
+        return colattr[indexs[0]]
+
+    def recognition_attribute(self, model_name, description, image):
+        """
+        Run one attribute model; return {task_name: [labels sharing the max score]}.
+
+        *description* is a csv path whose filename stem is the task name and
+        whose 'labelName' column lists the label vocabulary.
+        """
+        attr_type = pd.read_csv(description)
+        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
+        inputs[0].set_data_from_numpy(image, binary_data=True)
+        results = self.att_client.infer(model_name=model_name, inputs=inputs)
+        scores = torch.from_numpy(results.as_numpy("output__0")).detach().numpy()
+        colattr = list(attr_type['labelName'])
+        task = description.split('/')[-1][:-4]  # csv filename without extension
+        maxsc = np.max(scores[0][:5])
+        indexs = np.argwhere(scores == maxsc)[:, 1]
+        return {task: [colattr[i] for i in indexs]}
+
+    def get_recognition_attribute_result(self, category, input_img):
+        """Run every attribute model configured for *category*; merge their results."""
+        if category == 'Blouse':
+            descriptions, models = self.const.top_description_list, self.const.top_model_list
+        elif category in ('Trousers', 'Skirt'):
+            descriptions, models = self.const.bottom_description_list, self.const.bottom_model_list
+        elif category == 'Dress':
+            descriptions, models = self.const.dress_description_list, self.const.dress_model_list
+        elif category == 'Outwear':
+            descriptions, models = self.const.outwear_description_list, self.const.outwear_model_list
+        else:
+            return {}
+        attr_dict = {}
+        for model_path, attr_description in zip(models, descriptions):
+            attr_dict = self.merge(attr_dict, self.recognition_attribute(model_path, attr_description, input_img))
+        return attr_dict
+
+    @staticmethod
+    def merge(dict1, dict2):
+        """Shallow-merge two dicts; dict2 wins on key collisions."""
+        return {**dict1, **dict2}
+
+
+if __name__ == '__main__':
+    request_item = BrandDnaModel(
+        image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
+        is_brand_dna=True
+    )
+    service = BrandDna(request_item)
+    result_url = service.get_result()
+    print(result_url)