Files
sora_python/app/service/similar_match/service.py
2024-03-27 13:17:41 +08:00

106 lines
3.7 KiB
Python

import io
import json
import numpy as np
import tritonclient.http as httpclient
from PIL import Image
from minio import Minio
from pymilvus import MilvusClient
from app.core.config import *
from torchvision import transforms
from app.service.utils.decorator import RunTime
class SimilarMatch:
    """Look up items that are visually similar to a query image.

    The pipeline is: download the image from MinIO, preprocess it with
    torchvision, ask a Triton-served model for an embedding, then run a
    vector search in Milvus with that embedding.

    Attributes:
        minio_client: MinIO client built from the ``MINIO_*`` settings.
        triton_client: Triton HTTP client built from the ``OM_TRITON_*``
            settings.
    """

    # Name of the model queried on the Triton server.
    TRITON_MODEL_NAME = "outfit_matcher_type_aware"

    # Milvus connection and search settings used by ``match_features``.
    # NOTE(review): the URI is hard-coded here rather than read from
    # app.core.config like the MinIO/Triton settings -- consider moving it.
    MILVUS_URI = "http://10.1.1.240:19530"
    MILVUS_DB_NAME = "mixi"
    MILVUS_COLLECTION = "mixi_outfit"
    MILVUS_SEARCH_LIMIT = 5
    MILVUS_OUTPUT_FIELDS = ("id", "image_path")

    # Side length (pixels) of the square image fed to the model.
    IMAGE_SIZE = 112

    # Preprocessing pipeline, built once instead of on every call.
    _IMAGE_TRANSFORMS = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self):
        """Create the MinIO and Triton clients from the app configuration."""
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        self.triton_client = httpclient.InferenceServerClient(
            url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")

    @staticmethod
    def resize_image(img):
        """Resize, center-crop and normalize an image for the model.

        Args:
            img: The image to transform. ``load_image`` passes a PIL image
                in RGB mode.

        Returns:
            The transformed image as a NumPy array (the output of
            ``ToTensor`` + ``Normalize`` converted with ``.numpy()``).
        """
        return SimilarMatch._IMAGE_TRANSFORMS(img).numpy()

    def load_image(self, img_path):
        """Download an image from MinIO and decode it as an RGB PIL image.

        Args:
            img_path: Object location in the form ``"<bucket>/<object key>"``.
                Everything before the first ``/`` is the bucket name, the
                rest is the object key.

        Returns:
            The decoded image converted to RGB.

        Raises:
            IndexError: If ``img_path`` contains no ``/``.
        """
        parts = img_path.split("/", 1)
        bucket_name, object_name = parts[0], parts[1]

        # Fetch the object (image file) from MinIO.
        response = self.minio_client.get_object(bucket_name, object_name)
        try:
            raw_bytes = response.data
        finally:
            # The response holds a pooled HTTP connection; it must be closed
            # and released explicitly or the connection leaks.
            response.close()
            response.release_conn()

        # Decode the downloaded bytes into a PIL image.
        return Image.open(io.BytesIO(raw_bytes)).convert("RGB")

    def preprocess(self, img_path):
        """Build the three model inputs for the image at ``img_path``.

        Args:
            img_path: Object location understood by ``load_image``.

        Returns:
            A tuple ``(image, category, mask)`` where ``image`` is the
            preprocessed image with two leading axes of size 1 added,
            ``category`` is the fixed array ``[[1, 6]]`` and ``mask`` is a
            ``(1, 1)`` float32 array of zeros.
        """
        image = self.resize_image(self.load_image(img_path))
        # Prepend two singleton axes in front of the image dimensions.
        image = image[np.newaxis, np.newaxis]
        # NOTE(review): the meaning of the fixed category pair (1, 6) is
        # defined by the model and is not visible here -- confirm.
        category = np.array([[1, 6]])
        mask = np.zeros((1, 1), dtype=np.float32)
        return image, category, mask

    def get_features(self, img_path):
        """Run Triton inference and return the feature output for an image.

        Args:
            img_path: Object location understood by ``load_image``.

        Returns:
            The ``output__1`` tensor of the model as a NumPy array
            (the original author's note gives its shape as ``(N, 64)``).
        """
        image, category, mask = self.preprocess(img_path)

        # Model inputs: (tensor name, Triton datatype, data cast to match).
        input_specs = [
            ("input__0", "FP32", image.astype(np.float32)),
            ("input__1", "INT16", category.astype(np.int16)),
            ("input__2", "FP32", mask.astype(np.float32)),
        ]
        inputs = []
        for tensor_name, datatype, array in input_specs:
            infer_input = httpclient.InferInput(
                tensor_name, array.shape, datatype=datatype)
            infer_input.set_data_from_numpy(array, binary_data=True)
            inputs.append(infer_input)

        # Requested outputs.
        outputs = [
            httpclient.InferRequestedOutput("output__0", binary_data=True),
            httpclient.InferRequestedOutput("output__1", binary_data=True),
        ]

        # Run inference and extract the feature tensor.
        results = self.triton_client.infer(
            model_name=self.TRITON_MODEL_NAME, inputs=inputs, outputs=outputs)
        return results.as_numpy("output__1")

    @RunTime
    def match_features(self, features):
        """Search Milvus for the entries closest to a feature vector.

        Args:
            features: Feature array as returned by ``get_features``; only the
                first row (``features[0]``) is used as the query vector.

        Returns:
            The response of ``MilvusClient.search`` limited to
            ``MILVUS_SEARCH_LIMIT`` hits, with the ``id`` and ``image_path``
            fields requested for each hit.
        """
        # Connect to Milvus; the client is always closed, even on failure.
        client = MilvusClient(uri=self.MILVUS_URI, db_name=self.MILVUS_DB_NAME)
        try:
            return client.search(
                collection_name=self.MILVUS_COLLECTION,
                data=[features[0]],  # single query vector
                limit=self.MILVUS_SEARCH_LIMIT,  # max number of results
                output_fields=list(self.MILVUS_OUTPUT_FIELDS),
            )
        finally:
            client.close()
if __name__ == '__main__':
    # Manual smoke test: embed one sample image stored in MinIO and print
    # its nearest neighbours from Milvus as pretty-printed JSON.
    sample_image_path = "test/2024 SS/MKTS27000.jpg"

    service = SimilarMatch()
    features = service.get_features(img_path=sample_image_path)
    search_response = service.match_features(features)

    pretty_response = json.dumps(search_response, indent=4)
    print(pretty_response)