Files
AiDA_Python/app/service/brand_dna/service.py

336 lines
14 KiB
Python
Raw Normal View History

2024-12-20 09:48:10 +08:00
import logging
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
2024-12-20 09:48:10 +08:00
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import local_debug_const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
# Module-level logger. Named after this module (instead of the root logger) so
# its records can be filtered per module; they still propagate to root handlers.
logger = logging.getLogger(__name__)
# Shared MinIO client used by this module for image downloads and uploads.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class BrandDna:
    """Split a product image into garment regions and describe each region.

    The image is fetched from object storage, segmented by the ``seg_product``
    model, and every region labelled 1, 2 or 3 is classified.  Depending on
    ``is_brand_dna`` each region is then either run through the attribute
    models or cropped and uploaded (image + mask) to object storage.
    """

    # Bucket that cropped garment images and masks are uploaded to.
    SKETCH_BUCKET = "test"
    # Address of the inference server that hosts the ``seg_product`` model.
    SEG_MODEL_URL = '10.1.1.243:30000'
    # Largest height/width sent to the segmentation model; longer sides are clamped.
    MAX_SEG_SIDE = 1024
    # (width, height) expected by the category / attribute models.
    CATEGORY_INPUT_SIZE = (224, 224)
    # Per-channel normalisation constants handed to ``mmcv.imnormalize``.
    IMG_MEAN = (123.675, 116.28, 103.53)
    IMG_STD = (58.395, 57.12, 57.375)
    # Mask labels that are processed; the original code named the regions
    # outwear (1), tops (2) and bottoms (3).
    GARMENT_LABELS = (1, 2, 3)
    # Only the first SCORE_WINDOW scores of the first row are considered when
    # looking for the best label.
    # NOTE(review): presumably the number of category labels - confirm that
    # the same window is intended for the attribute models as well.
    SCORE_WINDOW = 5
    # Set to True for local debugging to pop up OpenCV preview windows.  This
    # must stay False in a service: ``cv2.waitKey(0)`` blocks until a key is
    # pressed and ``cv2.imshow`` needs a display.
    DEBUG_SHOW = False

    # Category -> names of the (description list, model list) attributes on
    # ``self.const`` that drive attribute recognition for that category.
    _ATTRIBUTE_LISTS = {
        'Blouse': ('top_description_list', 'top_model_list'),
        'Trousers': ('bottom_description_list', 'bottom_model_list'),
        'Skirt': ('bottom_description_list', 'bottom_model_list'),
        'Dress': ('dress_description_list', 'dress_model_list'),
        'Outwear': ('outwear_description_list', 'outwear_model_list'),
    }

    def __init__(self, request_item):
        """Store the request fields and create the inference clients.

        Args:
            request_item: object exposing ``image_url`` (``"<bucket>/<object>"``)
                and ``is_brand_dna`` (truthy -> run attribute recognition,
                falsy -> upload the cropped region and its mask instead).
        """
        self.sketch_bucket = self.SKETCH_BUCKET
        self.image_url = request_item.image_url
        self.is_brand_dna = request_item.is_brand_dna
        # Category label table; its ``labelName`` column maps score index -> name.
        self.attr_type = pd.read_csv(CATEGORY_PATH)
        self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        self.seg_client = httpclient.InferenceServerClient(url=self.SEG_MODEL_URL)
        # Source of the per-category description/model lists.
        self.const = local_debug_const

    def get_result(self):
        """Segment the image and build one result entry per garment region.

        Returns:
            list[dict]: one dict per label in ``GARMENT_LABELS`` present in the
            mask, with keys ``category_female``, ``category_male``,
            ``mask_url``, ``img_url`` and ``attribute``.
        """
        mask, image = self.get_seg_mask()
        self._debug_show(image)
        height, width, channels = image.shape
        # White canvas that each garment region is pasted onto.
        white_img = np.full((height, width, channels), 255, dtype=image.dtype)
        results = []
        for value in np.unique(mask):
            if value in self.GARMENT_LABELS:
                results.append(self._describe_region(image, mask == value, white_img))
        return results

    def _describe_region(self, image, region, white_img):
        """Classify one garment region and collect its attributes or uploads.

        Args:
            image: the full source image.
            region: boolean array selecting the pixels of this garment.
            white_img: white canvas with the same shape and dtype as ``image``.

        Returns:
            dict: the result entry for this region (see ``get_result``).
        """
        garment_img = white_img.copy()
        garment_img[region] = image[region]
        self._debug_show(garment_img)

        # Input for the category / attribute models.
        preprocess_img = self.category_preprocess(garment_img)
        category = self.recognition_category(preprocess_img)

        attribute = {}
        mask_url = ""
        img_url = ""
        if self.is_brand_dna:
            attribute = self.get_recognition_attribute_result(category, preprocess_img)
        else:
            # Mask image: black canvas with the region painted [0, 0, 255].
            # uint8 is used explicitly instead of the float64 default so the
            # array is already in an encodable depth.
            height, width = image.shape[:2]
            garment_mask_img = np.zeros((height, width, 3), dtype=np.uint8)
            garment_mask_img[region] = [0, 0, 255]
            img_url = self.put_image(garment_img, f"img/{generate_uuid()}")
            mask_url = self.put_image(garment_mask_img, f"mask/{generate_uuid()}")

        return {
            'category_female': category,
            'category_male': self._to_male_category(category),
            'mask_url': mask_url,
            'img_url': img_url,
            'attribute': attribute
        }

    @staticmethod
    def _to_male_category(category):
        """Map a recognised category name to its ``category_male`` value."""
        if category in ('Trousers', 'Skirt'):
            return 'Bottoms'
        if category in ('Blouse', 'Dress'):
            return 'Tops'
        return 'Outwear'

    def _debug_show(self, image):
        """Show ``image`` in a blocking OpenCV window when ``DEBUG_SHOW`` is set."""
        if self.DEBUG_SHOW:
            cv2.imshow("", image)
            cv2.waitKey(0)

    def get_seg_mask(self):
        """Run the ``seg_product`` model on the request image.

        Returns:
            tuple: ``(mask, image)`` where ``image`` is the fetched source image
            and ``mask`` is the first channel of the post-processed model
            output, resized to the image's original height and width.
        """
        input_image = self.get_image()
        input_img, ori_shape = self.seg_product_preprocess(input_image)
        transformed_img = input_img.astype(np.float32)
        infer_input = httpclient.InferInput("seg_input__0", transformed_img.shape, datatype="FP32")
        infer_input.set_data_from_numpy(transformed_img, binary_data=True)
        requested = [httpclient.InferRequestedOutput("seg_output__0", binary_data=True)]
        response = self.seg_client.infer(model_name="seg_product", inputs=[infer_input], outputs=requested)
        raw_output = response.as_numpy("seg_output__0")
        mask = self.product_postprocess(raw_output, ori_shape)[0]
        return mask, input_image

    def get_image(self):
        """Download the request image and make sure it has three channels.

        ``image_url`` is read as ``"<bucket>/<object name>"``.

        Returns:
            The image as returned by ``oss_get_image`` with ``data_type="cv2"``;
            4-channel images are converted BGRA -> BGR and single-channel
            images GRAY -> BGR.
        """
        url = self.image_url
        bucket = url.split('/')[0]
        object_name = url[url.find('/') + 1:]
        image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="cv2")
        if len(image.shape) == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image

    def put_image(self, image, object_name):
        """Encode ``image`` as JPEG and upload it to the sketch bucket.

        Args:
            image: image array to encode.
            object_name: object key without extension; ``.jpg`` is appended.

        Returns:
            ``"<bucket>/<object_name>.jpg"`` on success, or ``None`` when
            encoding or uploading fails (the error is logged, not raised, so
            one failed upload does not abort the whole request).
        """
        try:
            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
            oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes)
            return f"{self.sketch_bucket}/{object_name}.jpg"
        except Exception as e:
            logger.warning(e)
            return None

    @staticmethod
    def seg_product_preprocess(image):
        """Prepare an image for the garment segmentation model.

        Each side longer than ``MAX_SEG_SIDE`` is clamped to ``MAX_SEG_SIDE``
        independently (the other side is left as is), then the image is
        normalised and reshaped to ``(1, channels, height, width)``.

        Returns:
            tuple: ``(preprocessed_img, ori_shape)`` with ``ori_shape`` being the
            ``(height, width)`` of the image before any resizing.
        """
        img = mmcv.imread(image)
        ori_shape = img.shape[:2]  # (height, width)
        limit = BrandDna.MAX_SEG_SIDE
        target_h = min(ori_shape[0], limit)
        target_w = min(ori_shape[1], limit)
        if ori_shape != (target_h, target_w):
            # cv2.resize takes the size as (width, height).  An earlier
            # version of this code had the two swapped.
            img = cv2.resize(img, (target_w, target_h))
        img = mmcv.imnormalize(img, mean=np.array(BrandDna.IMG_MEAN), std=np.array(BrandDna.IMG_STD), to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img, ori_shape

    @staticmethod
    def product_postprocess(output, ori_shape):
        """Resize the segmentation output back to the original image size.

        Args:
            output: 4-D model output (batch first).
            ori_shape: target ``(height, width)``.

        Returns:
            numpy array: the first batch item after bilinear resizing.
        """
        # NOTE(review): if seg_output__0 holds class indices rather than
        # logits, bilinear interpolation blends label values along region
        # borders whenever the image was resized; mode='nearest' would avoid
        # that - confirm the model's output format before changing.
        seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
        seg_pred = seg_logit.cpu().numpy()
        return seg_pred[0]

    @staticmethod
    def category_preprocess(img):
        """Resize to ``CATEGORY_INPUT_SIZE``, normalise, and reshape to (1, C, H, W)."""
        img = mmcv.imread(img)
        img = cv2.resize(img, BrandDna.CATEGORY_INPUT_SIZE)
        img = mmcv.imnormalize(img, mean=np.array(BrandDna.IMG_MEAN), std=np.array(BrandDna.IMG_STD), to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img

    def _infer_scores(self, model_name, image):
        """Send ``image`` to ``model_name`` on the attribute server; return ``output__0``."""
        infer_input = httpclient.InferInput("input__0", image.shape, datatype="FP32")
        infer_input.set_data_from_numpy(image, binary_data=True)
        response = self.att_client.infer(model_name=model_name, inputs=[infer_input])
        return response.as_numpy("output__0")

    def recognition_category(self, image):
        """Classify a preprocessed garment image.

        Returns:
            The ``labelName`` entry whose score is highest among the first
            ``SCORE_WINDOW`` scores of the first output row.
        """
        scores = self._infer_scores("attr_retrieve_category", image)
        label_names = list(self.attr_type['labelName'])
        # argmax yields the first position of the maximum, the same index the
        # previous max + argwhere lookup selected.
        best = int(np.argmax(scores[0][:self.SCORE_WINDOW]))
        return label_names[best]

    def recognition_attribute(self, model_name, description, image):
        """Run one attribute model and return the winning label(s).

        Args:
            model_name: name of the model on the attribute server.
            description: path of the CSV label table for this attribute; the
                file name without its last four characters is the result key.
            image: preprocessed model input.

        Returns:
            dict: ``{task: [label, ...]}``.
        """
        label_names = list(pd.read_csv(description)['labelName'])
        scores = self._infer_scores(model_name, image)
        task = description.split('/')[-1][:-4]
        # The reference score comes from the first SCORE_WINDOW entries of row
        # 0, but every position in the whole output equal to it is reported.
        best_score = np.max(scores[0][:self.SCORE_WINDOW])
        columns = np.argwhere(scores == best_score)[:, 1]
        return {task: [label_names[col] for col in columns]}

    def get_recognition_attribute_result(self, category, input_img):
        """Run every attribute model configured for ``category``.

        Args:
            category: category name as returned by ``recognition_category``.
            input_img: preprocessed model input.

        Returns:
            dict: merged ``{task: [label, ...]}`` results; empty when the
            category has no configured attribute models.
        """
        list_names = self._ATTRIBUTE_LISTS.get(category)
        if list_names is None:
            return {}
        description_attr, model_attr = list_names
        attr_dict = {}
        for idx, attr_description in enumerate(getattr(self.const, description_attr)):
            # Model list is indexed in step with the description list.
            attr_model_path = getattr(self.const, model_attr)[idx]
            present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
            attr_dict = self.merge(attr_dict, present_dict)
        return attr_dict

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with ``dict1``'s items overridden by ``dict2``'s."""
        res = {**dict1, **dict2}
        return res
if __name__ == '__main__':
    # Manual smoke run against a single stored product image.
    # To sweep a local folder instead, build one BrandDnaModel per file under
    # ./test_img (image_url=<file path>, is_brand_dna=True) and print the
    # get_result() of a BrandDna created from each.
    demo_request = BrandDnaModel(
        image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
        is_brand_dna=True,
    )
    print(BrandDna(demo_request).get_result())