关于保存特征的一些代码

This commit is contained in:
zchen
2024-03-26 18:00:35 +08:00
parent 7dce70027a
commit 88c815dd9d
16 changed files with 378 additions and 10 deletions

View File

@@ -327,9 +327,12 @@ class OutfitMaterTypeAware(OutfitMatcher):
# 输出集
outputs = [
httpclient.InferRequestedOutput("output__0", binary_data=True),
httpclient.InferRequestedOutput("output__1", binary_data=True)
]
results = client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
# 推理
# 取结果
scores = torch.from_numpy(results.as_numpy("output__0"))
return scores # Shape (N, 1)
scores = torch.from_numpy(results.as_numpy("output__0")) # Shape (N, 1)
features = torch.from_numpy(results.as_numpy("output__1")) # Shape (N, 64)
return scores, features

View File

@@ -14,7 +14,14 @@ if __name__ == '__main__':
bad_list = []
for item in param["query"]:
outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"])
scores = service.get_result(outfits)
scores, features = service.get_result(outfits)
# save features
# 链接milvus
# 存入数据库
# 关闭链接
# print(scores)
# print(len(scores))
best_outfits, best_scores = service.visualize(outfits, scores, param["topk"], best=True,

View File

View File

@@ -0,0 +1,102 @@
import io
import json
import numpy as np
import tritonclient.http as httpclient
from PIL import Image
from minio import Minio
from pymilvus import MilvusClient
from app.core.config import *
from torchvision import transforms
class SimilarMatch:
def __init__(self):
self.minio_client = Minio(
f"{MINIO_IP}:{MINIO_PORT}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
self.triton_client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")
@staticmethod
def resize_image(img):
"""
Args:
img: ndarray (height, width, channel)
"""
image_transforms = transforms.Compose([
transforms.Resize(112),
transforms.CenterCrop(112),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
resized_img = image_transforms(img).numpy()
return resized_img
def load_image(self, img_path):
# 从 MinIO 中获取对象(图像文件)
image_data = self.minio_client.get_object(img_path.split("/", 1)[0], img_path.split("/", 1)[1])
# 读取图像数据并转换为 PIL 图像对象
pil_image = Image.open(io.BytesIO(image_data.data)).convert("RGB")
# 将 PIL 图像转换为 NumPy 数组
# image_array = np.array(pil_image)
return pil_image
def preprocess(self, img_path):
image = self.load_image(img_path)
image = self.resize_image(image)
image = np.stack([[image]], axis=0)
category = np.stack([[1, 6]], axis=0)
mask = np.zeros((1, 1), dtype=np.float32)
return image, category, mask
def get_features(self, img_path):
image, category, mask = self.preprocess(img_path)
# 输入集
inputs = [
httpclient.InferInput("input__0", image.shape, datatype="FP32"),
httpclient.InferInput("input__1", category.shape, datatype="INT16"),
httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
]
inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True)
inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)
# 输出集
outputs = [
httpclient.InferRequestedOutput("output__0", binary_data=True),
httpclient.InferRequestedOutput("output__1", binary_data=True)
]
results = self.triton_client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
# 推理
# 取结果
features = results.as_numpy("output__1") # Shape (N, 64)
return features
def match_features(self, features):
# 连接milvus
# 连接milvus
client = MilvusClient(uri="http://10.1.1.240:19530", db_name="mixi")
try:
res = client.search(
collection_name="mixi_outfit", # Replace with the actual name of your collection
# Replace with your query vector
data=[features[0]],
limit=5, # Max. number of search results to return
output_fields=["id", "image_path"], # Search parameters
)
return res
finally:
client.close()
if __name__ == '__main__':
service = SimilarMatch()
features = service.get_features(img_path="test/2024 SS/MKTS27000.jpg")
res = service.match_features(features)
print(json.dumps(res, indent=4))