Files
AiDA_Python/app/service/brand_dna/service.py
zcr 863d9287dc
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 参数对齐
(cherry picked from commit ddef6af1cf)
2026-01-26 14:56:49 +08:00

337 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import DESIGN_MODEL_URL, SEG_PRODUCT_MODEL_URL
from app.core.config import settings
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
# Shared MinIO client used by all OSS download/upload helpers in this module.
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
# NOTE(review): this grabs the root logger; logging.getLogger(__name__) is the usual convention.
logger = logging.getLogger()
class BrandDna:
    """Brand-DNA pipeline for a single product image.

    Segments the image into garment regions, recognises each region's
    category via a Triton inference server and, depending on the request,
    either recognises attributes or uploads the cut-out image/mask to OSS.
    """

    def __init__(self, request_item):
        # OSS bucket that put_image() uploads garment cut-outs/masks into.
        self.sketch_bucket = "test"
        # "bucket/object" style path of the source image in OSS.
        self.image_url = request_item.image_url
        # True -> run attribute recognition; False -> upload images/masks instead.
        self.is_brand_dna = request_item.is_brand_dna
        # Category label table; recognition_category() reads its 'labelName' column.
        self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
        # Triton clients: one for category/attribute models, one for segmentation.
        self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        self.seg_client = httpclient.InferenceServerClient(url=SEG_PRODUCT_MODEL_URL)
        # Attribute model paths / description CSV lists.
        self.const = const
# 获取结果
def get_result(self):
mask, image = self.get_seg_mask()
# cv2.imshow("", image)
# cv2.waitKey(0)
height, width, channels = image.shape
result_dict = []
white_img = np.ones((height, width, channels), dtype=image.dtype) * 255
mask_image = np.zeros((height, width, 3))
for value in np.unique(mask):
if value == 1:
outwear_img = white_img.copy()
outwear_mask_img = mask_image.copy()
outwear_img[mask == value] = image[mask == value]
outwear_mask_img[mask == value] = [0, 0, 255]
# cv2.imshow("", outwear_img)
# cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(outwear_img)
# 类别检测
category = self.recognition_category(preprocess_img)
if category == 'Trousers' or category == 'Skirt':
male_category = 'Bottoms'
elif category == 'Blouse' or category == 'Dress':
male_category = 'Tops'
else:
male_category = 'Outwear'
attribute = {}
mask_url = ""
img_url = ""
# 属性检测
if self.is_brand_dna:
attribute = self.get_recognition_attribute_result(category, preprocess_img)
else:
img_url = self.put_image(outwear_img, f"img/{generate_uuid()}")
mask_url = self.put_image(outwear_mask_img, f"mask/{generate_uuid()}")
result_dict.append(
{
'category_female': category,
'category_male': male_category,
'mask_url': mask_url,
'img_url': img_url,
'attribute': attribute
}
)
if value == 2:
tops_img = white_img.copy()
tops_mask_img = mask_image.copy()
tops_img[mask == value] = image[mask == value]
tops_mask_img[mask == value] = [0, 0, 255]
# cv2.imshow("", tops_img)
# cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(tops_img)
# 类别检测
category = self.recognition_category(preprocess_img)
if category == 'Trousers' or category == 'Skirt':
male_category = 'Bottoms'
elif category == 'Blouse' or category == 'Dress':
male_category = 'Tops'
else:
male_category = 'Outwear'
# 属性检测
attribute = {}
img_url = ""
mask_url = ""
# 属性检测
if self.is_brand_dna:
attribute = self.get_recognition_attribute_result(category, preprocess_img)
else:
mask_url = self.put_image(tops_mask_img, f"mask/{generate_uuid()}")
img_url = self.put_image(tops_img, f"img/{generate_uuid()}")
result_dict.append(
{
'category_female': category,
'category_male': male_category,
'mask_url': mask_url,
'img_url': img_url,
'attribute': attribute
}
)
if value == 3:
bottoms_img = white_img.copy()
bottoms_mask_img = mask_image.copy()
bottoms_img[mask == value] = image[mask == value]
bottoms_mask_img[mask == value] = [0, 0, 255]
# cv2.imshow("", bottoms_img)
# cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(bottoms_img)
# 类别检测
category = self.recognition_category(preprocess_img)
if category == 'Trousers' or category == 'Skirt':
male_category = 'Bottoms'
elif category == 'Blouse' or category == 'Dress':
male_category = 'Tops'
else:
male_category = 'Outwear'
attribute = {}
img_url = ""
mask_url = ""
# 属性检测
if self.is_brand_dna:
attribute = self.get_recognition_attribute_result(category, preprocess_img)
else:
img_url = self.put_image(bottoms_img, f"img/{generate_uuid()}")
mask_url = self.put_image(bottoms_mask_img, f"mask/{generate_uuid()}")
result_dict.append(
{
'category_female': category,
'category_male': male_category,
'mask_url': mask_url,
'img_url': img_url,
'attribute': attribute
}
)
return result_dict
# 获取product mask
def get_seg_mask(self):
input_image = self.get_image()
input_img, ori_shape = self.seg_product_preprocess(input_image)
transformed_img = input_img.astype(np.float32)
inputs = [httpclient.InferInput(f"seg_input__0", transformed_img.shape, datatype="FP32")]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
outputs = [httpclient.InferRequestedOutput(f"seg_output__0", binary_data=True)]
results = self.seg_client.infer(model_name=f"seg_product", inputs=inputs, outputs=outputs)
inference_output1 = results.as_numpy("seg_output__0")
mask = self.product_postprocess(inference_output1, ori_shape)[0]
return mask, input_image
# 获取图片
def get_image(self):
image = oss_get_image(oss_client=minio_client, bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
# 将其转换为彩色图像
if len(image.shape) == 3 and image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
elif len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
return image
# return cv2.imread(self.image_url)
# 图片上传
def put_image(self, image, object_name):
try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes)
return f"{self.sketch_bucket}/{object_name}.jpg"
except Exception as e:
logger.warning(e)
# 服装分割预处理
@staticmethod
def seg_product_preprocess(image):
img = mmcv.imread(image)
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
img_scale_w = 1024
if ori_shape[1] > 1024:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
# 类别检测后处理
@staticmethod
def product_postprocess(output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]
# 类别检测模型预处理
@staticmethod
def category_preprocess(img):
img = mmcv.imread(img)
# ori_shape = img.shape[:2]
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img
# 类别检测
def recognition_category(self, image):
inputs = [
httpclient.InferInput("input__0", image.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(image, binary_data=True)
results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs)
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
scores = inference_output.detach().numpy()
colattr = list(self.attr_type['labelName'])
maxsc = np.max(scores[0][:5])
indexs = np.argwhere(scores == maxsc)[:, 1]
return colattr[indexs[0]]
# 属性检测
def recognition_attribute(self, model_name, description, image):
attr_type = pd.read_csv(description)
inputs = [
httpclient.InferInput("input__0", image.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(image, binary_data=True)
results = self.att_client.infer(model_name=model_name, inputs=inputs)
inference_output = torch.from_numpy(results.as_numpy(f"output__0"))
scores = inference_output.detach().numpy()
colattr = list(attr_type['labelName'])
task = description.split('/')[-1][:-4]
maxsc = np.max(scores[0][:5])
indexs = np.argwhere(scores == maxsc)[:, 1]
attr = {
task: []
}
for i in range(len(indexs)):
atr = colattr[indexs[i]]
attr[task].append(atr)
return attr
# 获取属性检测结果
def get_recognition_attribute_result(self, category, input_img):
if category == "Blouse":
attr_dict = {}
for i in range(len(self.const.top_description_list)):
attr_description = self.const.top_description_list[i]
attr_model_path = self.const.top_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
elif category == 'Trousers' or category == "Skirt":
attr_dict = {}
for i in range(len(self.const.bottom_description_list)):
attr_description = self.const.bottom_description_list[i]
attr_model_path = self.const.bottom_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
elif category == 'Dress':
attr_dict = {}
for i in range(len(self.const.dress_description_list)):
attr_description = self.const.dress_description_list[i]
attr_model_path = self.const.dress_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
elif category == 'Outwear':
attr_dict = {}
for i in range(len(self.const.outwear_description_list)):
attr_description = self.const.outwear_description_list[i]
attr_model_path = self.const.outwear_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
else:
attr_dict = {}
return attr_dict
@staticmethod
def merge(dict1, dict2):
res = {**dict1, **dict2}
return res
if __name__ == '__main__':
    # Ad-hoc manual smoke test: run the full pipeline on a known object in
    # the aida-results bucket and print the resulting garment entries.
    request_item = BrandDnaModel(
        image_url="aida-results/result_00006a48-e315-11ee-b7c8-b48351119060.png",
        is_brand_dna=True
    )
    print(BrandDna(request_item).get_result())