Files
AiDA_Python/app/service/attribute/service_att_recognition.py
zcr 18024a2d70
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
2025-12-30 16:49:08 +08:00

168 lines
6.9 KiB
Python

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import os
from pprint import pprint

import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
from minio import Minio

from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.new_oss_client import oss_get_image
# Module-level MinIO client used by AttributeRecognition.get_image to fetch
# sketch images; endpoint, credentials and TLS flag come from app settings.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class AttributeRecognition:
    """Recognise garment attributes for a batch of sketch images.

    Every sketch image is downloaded and preprocessed once in ``__init__``.
    ``get_result`` then queries, on a Triton inference server, every attribute
    model configured for the sketch's category and attaches the predicted
    labels to the sketch record.
    """

    # Maps a request ``category`` to the attribute-group prefix used by the
    # ``<prefix>_description_list`` / ``<prefix>_model_list`` fields of
    # ``const``. Categories not listed here get an empty attribute dict.
    _CATEGORY_GROUPS = {
        "Tops": "top",
        "Blouse": "top",
        "Trousers": "bottom",
        "Skirt": "bottom",
        "Bottoms": "bottom",
        "Dress": "dress",
        "Outwear": "outwear",
    }

    # Only the first this-many model scores are considered when picking the
    # predicted label(s).
    # NOTE(review): this ignores every label past the fifth row of an
    # attribute CSV — confirm that limit is intended.
    _SCORE_WINDOW = 5

    def __init__(self, const, request_data):
        """Download and preprocess every sketch, then connect to Triton.

        :param const: configuration object exposing, for the ``top``,
            ``bottom``, ``dress`` and ``outwear`` groups,
            ``<group>_description_list`` (attribute label CSV paths) and
            ``<group>_model_list`` (matching Triton model names).
        :param request_data: iterable of request items with
            ``sketch_img_url``, ``category`` and ``colony`` attributes.
        """
        # NOTE(review): get_image() converts BGR->RGB, and preprocess() then
        # calls mmcv.imnormalize(..., to_rgb=True), which (per mmcv docs) swaps
        # the channels again, so the models may be receiving BGR input.
        # Confirm which order the deployed models expect before changing
        # either step.
        self.request_data = [
            {
                'obj': self.preprocess(self.get_image(sketch.sketch_img_url)),
                'category': sketch.category,
                'colony': sketch.colony,
                'sketch_img_url': sketch.sketch_img_url,
            }
            for sketch in request_data
        ]
        self.const = const
        self.triton_client = httpclient.InferenceServerClient(url=str(DESIGN_MODEL_URL))

    def get_result(self):
        """Run attribute recognition for every sketch.

        Adds an ``attr_dict`` entry (``{task_name: [label, ...]}``) to each
        sketch record and removes its preprocessed image (``obj``). Sketches
        already processed by an earlier call are left unchanged, so calling
        this method again does not fail.

        :return: ``self.request_data``, the list of sketch records.
        """
        for sketch in self.request_data:
            if 'obj' not in sketch:
                # Processed by a previous call; its image was already released.
                continue
            attr_dict = {}
            for description, model_name in self._attribute_models(sketch['category']):
                present_dict = self.get_attribute(model_name, description, sketch['obj'])
                attr_dict = self.merge(attr_dict, present_dict)
            sketch['attr_dict'] = attr_dict
            # Drop the image array so the returned records hold only plain data.
            del sketch['obj']
        return self.request_data

    def _attribute_models(self, category):
        """Return the ``(description_csv, model_name)`` pairs configured for *category*.

        :raises IndexError: if the group's model list is shorter than its
            description list.
        """
        group = self._CATEGORY_GROUPS.get(category)
        if group is None:
            return []
        descriptions = getattr(self.const, f"{group}_description_list")
        model_names = getattr(self.const, f"{group}_model_list")
        if len(model_names) < len(descriptions):
            raise IndexError(
                f"const.{group}_model_list has {len(model_names)} entries but "
                f"const.{group}_description_list has {len(descriptions)}"
            )
        return list(zip(descriptions, model_names))

    def get_attribute(self, model_name, description, image):
        """Predict one attribute of a preprocessed image.

        :param model_name: name of the Triton model to query.
        :param description: path of a CSV whose ``labelName`` column lists the
            labels in model-output order; its file name without extension is
            used as the task name.
        :param image: preprocessed input as returned by :meth:`preprocess`.
        :return: ``{task_name: [label, ...]}``; several labels when scores tie.
        :raises RuntimeError: if the model response has no ``output__0`` tensor.
        """
        labels = list(pd.read_csv(description)['labelName'])
        infer_input = httpclient.InferInput("input__0", image.shape, datatype="FP32")
        infer_input.set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name=model_name, inputs=[infer_input])
        scores = results.as_numpy("output__0")
        if scores is None:
            raise RuntimeError(f"Triton model {model_name!r} returned no 'output__0' tensor")
        return {self._task_name(description): self._decode_labels(scores, labels, self._SCORE_WINDOW)}

    @staticmethod
    def _decode_labels(scores, labels, window):
        """Return the label(s) whose score equals the maximum of the first *window* scores.

        Only the first row of *scores* is used; tied positions are returned
        in index order.
        """
        candidate_scores = np.asarray(scores)[0][:window]
        best = np.max(candidate_scores)
        # Look for ties only inside the window the maximum was taken from;
        # matching against the whole row could report labels the window
        # excluded, or indices past the end of ``labels``.
        return [labels[index] for index in np.flatnonzero(candidate_scores == best)]

    @staticmethod
    def _task_name(description):
        """Return the description file name without directory or extension."""
        return os.path.splitext(os.path.basename(os.fspath(description)))[0]

    @staticmethod
    def merge(dict1, dict2):
        """Return a new dict with *dict1*'s entries overridden/extended by *dict2*."""
        return {**dict1, **dict2}

    @staticmethod
    def get_image(url):
        """Fetch ``<bucket>/<object_name>`` via ``oss_get_image`` and swap BGR->RGB.

        :param url: storage path whose first ``/``-separated segment is the
            bucket and the remainder the object name.
        """
        parts = url.split("/", 1)
        bucket, object_name = parts[0], parts[1]
        img = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="cv2")
        # Assumes oss_get_image decodes to BGR like cv2.imdecode — TODO confirm.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def preprocess(img):
        """Resize to 224x224, normalise, and return a batched ``(1, C, H, W)`` array.

        :param img: image array, or anything ``mmcv.imread`` accepts.
        """
        img = mmcv.imread(img)
        img = cv2.resize(img, (224, 224))
        # These values appear to be the standard ImageNet RGB mean/std.
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        # HWC -> CHW, then add a leading batch axis.
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)
if __name__ == '__main__':
    # Local debugging entry point: recognise attributes of a few sample
    # sketches with the debug configuration and print the records.
    from app.service.attribute.config import local_debug_const

    # (category, sketch_img_url) pairs; every sample uses the "Female" colony.
    sample_sketches = [
        ("Dress", "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/6d7d97a7-5a7d-48bd-9e14-b51119b48620.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/f2437141-1104-40a5-bcb9-f436088698bb.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/07af8613-eb2e-44fd-97cb-a97249a5754c.jpg"),
        ("Blouse", "aida-users/89/sketchboard/female/Dress/bac9fb15-6860-4112-ac97-f0dea079da75.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/11d59844-effa-4590-82f9-9ea382c76126.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/849bf94c-66b8-42f5-8c2e-c1c1f4c8d0e0.jpg"),
        ("Dress", "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"),
    ]
    requests = [
        AttributeRecognitionModel(category=category, colony="Female", sketch_img_url=url)
        for category, url in sample_sketches
    ]
    recognizer = AttributeRecognition(local_debug_const, requests)
    pprint(recognizer.get_result())