fix attribute 接口新增类别隐射 不完全依赖模型推理产品类别,以映射文件为主

This commit is contained in:
zhouchengrong
2024-04-08 15:23:23 +08:00
parent fda602b6fe
commit 3960a19014
5 changed files with 749 additions and 151 deletions

View File

@@ -14,6 +14,7 @@ import pandas as pd
from minio import Minio
from app.core.config import MINIO_IP, MINIO_PORT, MINIO_SECURE, MINIO_ACCESS, MINIO_SECRET, ATT_TRITON_IP, ATT_TRITON_PORT
from app.service.attribute_recognition.config.category_mapping import category_mapping
def Merge(dict1, dict2):
@@ -92,22 +93,28 @@ class AttributeRecognition:
sample = self.preprocess(img_path)
category_model_path = args.category_model
category_description = args.category_discription
category_list = self.get_attribute(category_model_path, category_description, sample)['category']
# 如果category为all-body 则通过category获取类别 主要用去区分为连体裤和连衣裙
mapping_category = category_mapping[request_data.upload_img_category[l]]
if mapping_category not in ["tops", "bottoms", "skirt", "dress", "outerwear"]:
category_list = self.get_attribute(category_model_path, category_description, sample)['category']
else:
category_list = [mapping_category]
attr_dict = {}
if len(category_list) >= 1:
category = category_list[0]
print(category)
if category == 'top':
if category == 'tops':
attr_dict = {'Item': category}
for i in range(len(args.top_discription_list)):
# print('top: ', i)
attr_description = args.top_discription_list[i]
attr_model_path = args.top_model_list[i]
present_dict = self.get_attribute(attr_model_path, attr_description, sample)
attr_dict = Merge(attr_dict, present_dict)
elif category == 'pants':
elif category == 'bottoms':
attr_dict = {}
category = 'bottom'
attr_dict['Item'] = category
@@ -136,7 +143,7 @@ class AttributeRecognition:
present_dict = self.get_attribute(attr_model_path, attr_description, sample)
attr_dict = Merge(attr_dict, present_dict)
elif category == 'outwear':
elif category == 'outerwear':
attr_dict = {'Item': 'outer'}
for i in range(len(args.outwear_discription_list)):
@@ -152,8 +159,6 @@ class AttributeRecognition:
attr_model_path = args.jumpsuit_model_list[i]
present_dict = self.get_attribute(attr_model_path, attr_description, sample)
attr_dict = Merge(attr_dict, present_dict)
else:
attr_dict = {}
print('attr_dict: ', attr_dict)
final_dict[request_data.upload_img_id[l]] = attr_dict
@@ -210,6 +215,6 @@ class AttributeRecognition:
if __name__ == '__main__':
from app.service.attribute_recognition import const_debug
request_data = {'upload_img_path': ['./test_top1.jpg'], 'upload_img_id': ["2"]}
request_data = {'upload_img_path': ['./test_top1.jpg'], 'upload_img_id': ["2"], "update_img_category": ["TOP/ONE PIECE"]}
service = AttributeRecognition()
pprint(service.attribute(const_debug, request_data))