Files
AiDA_Python/app/service/attribute/service_att_recognition.py
zcr c03b7e263e
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
feat:
fix:  替换项目中所有mmcv的依赖
2026-02-10 11:17:31 +08:00

167 lines
6.9 KiB
Python

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from pprint import pprint
import cv2
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image
# Module-level MinIO client built from the application settings; it is the
# client AttributeRecognition.get_image passes to oss_get_image.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class AttributeRecognition:
    """Recognise garment attributes on sketch images with Triton-served classifiers.

    Every sketch image is downloaded and preprocessed eagerly in ``__init__``.
    ``get_result`` then runs each attribute model configured for the sketch's
    category and attaches the predicted labels under ``'attr_dict'``.
    """

    # (categories, group) pairs. ``group`` is the prefix of the
    # ``<group>_description_list`` / ``<group>_model_list`` attributes on
    # ``const``. Matching uses tuple membership (``==``) rather than dict
    # hashing, so str-like category values (e.g. str-based enums, if the
    # request schema uses them) still match exactly as a plain ``==`` would.
    _CATEGORY_GROUPS = (
        (("Tops", "Blouse"), "top"),
        (("Trousers", "Skirt", "Bottoms"), "bottom"),
        (("Dress",), "dress"),
        (("Outwear",), "outwear"),
    )

    def __init__(self, const, request_data):
        """Download and preprocess every requested sketch.

        Args:
            const: configuration object exposing ``top_``/``bottom_``/``dress_``/
                ``outwear_`` prefixed ``description_list`` and ``model_list``
                sequences.
            request_data: iterable of request items with ``sketch_img_url``,
                ``category`` and ``colony`` attributes (presumably
                ``AttributeRecognitionModel`` instances — confirm against callers).
        """
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")

    def get_result(self):
        """Run attribute recognition for every sketch.

        Each sketch dict gains an ``'attr_dict'`` entry that maps a task name to
        a list of predicted labels; it is empty for unsupported categories. The
        preprocessed ``'obj'`` array is removed from each dict.
        ``self.request_data`` is mutated in place and returned. A second call
        raises ``KeyError`` because ``'obj'`` is already gone.

        Returns:
            list[dict]: the updated ``self.request_data``.
        """
        for sketch in self.request_data:
            attr_dict = {}
            for model_name, description in self._attribute_specs(sketch['category']):
                present_dict = self.get_attribute(model_name, description, sketch['obj'])
                attr_dict = self.merge(attr_dict, present_dict)
            sketch['attr_dict'] = attr_dict
            del sketch['obj']
        return self.request_data

    def _attribute_specs(self, category):
        """Return the ``(model_name, description_csv)`` pairs configured for ``category``.

        Categories outside ``_CATEGORY_GROUPS`` yield an empty list.
        """
        for categories, group in self._CATEGORY_GROUPS:
            if category in categories:
                descriptions = getattr(self.const, f"{group}_description_list")
                models = getattr(self.const, f"{group}_model_list")
                # Index ``models`` by position (rather than ``zip``) so a model
                # list shorter than its description list raises IndexError
                # instead of silently dropping attributes.
                return [(models[i], description) for i, description in enumerate(descriptions)]
        return []

    def get_attribute(self, model_name, description, image):
        """Classify one attribute of ``image`` with a Triton model.

        Args:
            model_name: name of the Triton model to query.
            description: path to a CSV whose ``labelName`` column lists the
                model's labels in output-index order. Its file name, without the
                extension, becomes the task key.
            image: preprocessed input batch, sent as ``input__0`` with
                datatype FP32.

        Returns:
            dict: ``{task_name: [label, ...]}`` holding every label whose score
            equals the highest of the first five scores.
        """
        labels = list(pd.read_csv(description)['labelName'])
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=inputs)
        scores = results.as_numpy("output__0")
        # NOTE(review): the maximum is taken over only the first five scores of
        # the first row, but the match below scans every score. Confirm the
        # ``[:5]`` window is intentional.
        best_score = np.max(scores[0][:5])
        best_indices = np.argwhere(scores == best_score)[:, 1]
        return {self._task_name(description): [labels[i] for i in best_indices]}

    @staticmethod
    def _task_name(description):
        """Return the file name of the ``description`` path without its extension."""
        import os
        return os.path.splitext(os.path.basename(description))[0]

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with ``dict2``'s entries layered over ``dict1``'s."""
        return {**dict1, **dict2}

    @staticmethod
    def get_image(url):
        """Fetch a sketch image from MinIO and return it in RGB channel order.

        Args:
            url: ``"<bucket>/<object name>"``. The first ``/`` separates the
                bucket from the object key.

        Raises:
            ValueError: if ``url`` lacks a bucket or object part, or no image
                was returned.
        """
        bucket, sep, object_name = url.partition("/")
        if not (sep and bucket and object_name):
            raise ValueError(f"expected '<bucket>/<object>' image url, got {url!r}")
        img = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="cv2")
        # Fail with a clear message before cv2 reports an opaque error on a
        # missing image.
        if img is None:
            raise ValueError(f"could not load image {url!r}")
        # The cv2-decoded image is presumably BGR (as with cv2.imdecode).
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img):
        """Resize to 224x224, normalise, and return a contiguous FP32 1xCxHxW batch.

        The FP32 cast matches the ``datatype="FP32"`` declared for the Triton
        input in ``get_attribute``.
        """
        img = cv2.resize(img, (224, 224))
        # These appear to be the common ImageNet RGB mean/std statistics.
        # NOTE(review): ``get_image`` already converts BGR->RGB. If
        # ``my_imnormalize`` mirrors mmcv's ``imnormalize``, ``to_rgb=True``
        # swaps the channels back to BGR before normalising. Confirm against how
        # the models were trained before changing it.
        img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        # HWC -> CHW, then prepend the batch axis.
        return np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)
if __name__ == '__main__':
    # Local smoke run: recognise attributes for a few stored sketches and
    # pretty-print the result. It needs the MinIO and Triton endpoints
    # configured in app.core.config to be reachable.
    from app.service.attribute.config import local_debug_const

    debug_colony = "Female"
    debug_sketches = [
        ("Dress", "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"),
        ("Blouse", "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"),
    ]
    debug_requests = [
        AttributeRecognitionModel(category=category, colony=debug_colony, sketch_img_url=img_url)
        for category, img_url in debug_sketches
    ]
    recognizer = AttributeRecognition(local_debug_const, debug_requests)
    pprint(recognizer.get_result())