feat generate sketch 新增 服装类别识别
This commit is contained in:
@@ -160,3 +160,41 @@ def stain_detection(image, spot_size=200):
|
||||
cv2.rectangle(image, corner_coords, (corner_coords[0] + spot_size, corner_coords[1] + spot_size), (0, 0, 255), 2)
|
||||
|
||||
return True, image
|
||||
|
||||
|
||||
def generate_category_recognition(image):
|
||||
def preprocess(img):
|
||||
img = mmcv.imread(img)
|
||||
# ori_shape = img.shape[:2]
|
||||
img_scale = (224, 224)
|
||||
scale_factor = []
|
||||
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
|
||||
scale_factor.append(x)
|
||||
scale_factor.append(y)
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img
|
||||
|
||||
preprocessed_img = preprocess(image)
|
||||
triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
|
||||
|
||||
inputs = [
|
||||
httpclient.InferInput("input__0", preprocessed_img.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(preprocessed_img, binary_data=True)
|
||||
results = triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
|
||||
|
||||
scores = inference_output.detach().numpy()
|
||||
import pandas as pd
|
||||
|
||||
attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
colattr = list(attr_type['labelName'])
|
||||
|
||||
task = attr_type['taskName'][0]
|
||||
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
category = colattr[indexs[0]]
|
||||
return category, scores, image
|
||||
|
||||
|
||||
Reference in New Issue
Block a user