fix attribute 接口新增类别隐射 不完全依赖模型推理产品类别,以映射文件为主
This commit is contained in:
@@ -14,6 +14,7 @@ import pandas as pd
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_IP, MINIO_PORT, MINIO_SECURE, MINIO_ACCESS, MINIO_SECRET, ATT_TRITON_IP, ATT_TRITON_PORT
|
||||
from app.service.attribute_recognition.config.category_mapping import category_mapping
|
||||
|
||||
|
||||
def Merge(dict1, dict2):
|
||||
@@ -92,22 +93,28 @@ class AttributeRecognition:
|
||||
sample = self.preprocess(img_path)
|
||||
category_model_path = args.category_model
|
||||
category_description = args.category_discription
|
||||
category_list = self.get_attribute(category_model_path, category_description, sample)['category']
|
||||
|
||||
# 如果category为all-body 则通过category获取类别 , 主要用去区分为连体裤和连衣裙
|
||||
mapping_category = category_mapping[request_data.upload_img_category[l]]
|
||||
if mapping_category not in ["tops", "bottoms", "skirt", "dress", "outerwear"]:
|
||||
category_list = self.get_attribute(category_model_path, category_description, sample)['category']
|
||||
else:
|
||||
category_list = [mapping_category]
|
||||
|
||||
attr_dict = {}
|
||||
if len(category_list) >= 1:
|
||||
category = category_list[0]
|
||||
print(category)
|
||||
|
||||
if category == 'top':
|
||||
if category == 'tops':
|
||||
attr_dict = {'Item': category}
|
||||
for i in range(len(args.top_discription_list)):
|
||||
# print('top: ', i)
|
||||
attr_description = args.top_discription_list[i]
|
||||
attr_model_path = args.top_model_list[i]
|
||||
present_dict = self.get_attribute(attr_model_path, attr_description, sample)
|
||||
attr_dict = Merge(attr_dict, present_dict)
|
||||
|
||||
elif category == 'pants':
|
||||
elif category == 'bottoms':
|
||||
attr_dict = {}
|
||||
category = 'bottom'
|
||||
attr_dict['Item'] = category
|
||||
@@ -136,7 +143,7 @@ class AttributeRecognition:
|
||||
present_dict = self.get_attribute(attr_model_path, attr_description, sample)
|
||||
attr_dict = Merge(attr_dict, present_dict)
|
||||
|
||||
elif category == 'outwear':
|
||||
elif category == 'outerwear':
|
||||
attr_dict = {'Item': 'outer'}
|
||||
|
||||
for i in range(len(args.outwear_discription_list)):
|
||||
@@ -152,8 +159,6 @@ class AttributeRecognition:
|
||||
attr_model_path = args.jumpsuit_model_list[i]
|
||||
present_dict = self.get_attribute(attr_model_path, attr_description, sample)
|
||||
attr_dict = Merge(attr_dict, present_dict)
|
||||
else:
|
||||
attr_dict = {}
|
||||
|
||||
print('attr_dict: ', attr_dict)
|
||||
final_dict[request_data.upload_img_id[l]] = attr_dict
|
||||
@@ -210,6 +215,6 @@ class AttributeRecognition:
|
||||
if __name__ == '__main__':
|
||||
from app.service.attribute_recognition import const_debug
|
||||
|
||||
request_data = {'upload_img_path': ['./test_top1.jpg'], 'upload_img_id': ["2"]}
|
||||
request_data = {'upload_img_path': ['./test_top1.jpg'], 'upload_img_id': ["2"], "update_img_category": ["TOP/ONE PIECE"]}
|
||||
service = AttributeRecognition()
|
||||
pprint(service.attribute(const_debug, request_data))
|
||||
|
||||
Reference in New Issue
Block a user