attribute 模型名称错误
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
from pprint import pprint
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import tritonclient.http as httpclient
|
||||
import torch
|
||||
|
||||
from app.core.config import ATT_TRITON_PORT, ATT_TRITON_IP
|
||||
|
||||
model_list = ['bottom_design', 'bottom_length', 'bottom_material', 'bottom_OPType_B', 'bottom_print', 'bottom_Silhouette_B', 'bottom_Softness_B', 'bottom_sub-Type', 'category', 'dress_collar', 'dress_design', 'dress_length', 'dress_material', 'dress_neckline', 'dress_print', 'dress_silohouette12', 'dress_sleeve_length', 'dress_sleeve_shape', 'dress_sleeve_shoulder', 'dress_softness', 'dress_type', 'jumpsuit_collar', 'jumpsuit_design', 'jumpsuit_length', 'jumpsuit_material', 'jumpsuit_optype',
|
||||
'jumpsuit_print', 'jumpsuit_sleeve_length', 'jumpsuit_sleeve_shape', 'jumpsuit_sleeve_shoulder', 'jumpsuit_softness', 'jumpsuit_subtype', 'outwear_material', 'outwear_outear_length', 'outwear_outer_collar', 'outwear_outer_design', 'outwear_outer_optype', 'outwear_outer_silhouette', 'outwear_outer_sleeve_length', 'outwear_outer_sleeve_shape', 'outwear_outer_sleeve_shoulder', 'outwear_outer_softness', 'outwear_print', 'top_Collar', 'top_Design', 'top_length', 'top_material',
|
||||
'top_Neckline', 'top_optype', 'top_print', 'top_Silhouette', 'top_Sleeve_length', 'top_Sleeve_shape', 'top_Sleeve_shoulder', 'top_Softness', 'top_type']
|
||||
|
||||
|
||||
def preprocess(img):
|
||||
img = mmcv.imread(img)
|
||||
ori_shape = img.shape[:2]
|
||||
img_scale = (224, 224)
|
||||
scale_factor = []
|
||||
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
|
||||
scale_factor.append(x)
|
||||
scale_factor.append(y)
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img, ori_shape
|
||||
|
||||
|
||||
def get_attribute(model_save_name, sample):
|
||||
triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_IP}:{ATT_TRITON_PORT}")
|
||||
inputs = [
|
||||
httpclient.InferInput("input__0", sample.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(sample, binary_data=True)
|
||||
results = triton_client.infer(model_name=model_save_name, inputs=inputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f"output__0"))
|
||||
scores = inference_output.detach().numpy()
|
||||
pprint(scores)
|
||||
print(f"{model_save_name} is ok")
|
||||
|
||||
|
||||
image, shape = preprocess(cv2.imread("test_top1.jpg"))
|
||||
except_model = []
|
||||
|
||||
for model in model_list:
|
||||
try:
|
||||
get_attribute(model, image)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
except_model.append(model)
|
||||
print(except_model)
|
||||
Reference in New Issue
Block a user