Files
sora_python/app/service/similar_match/service.py
2024-03-28 10:30:18 +08:00

115 lines
4.1 KiB
Python

import io
import json
from pprint import pprint
import numpy as np
import tritonclient.http as httpclient
from PIL import Image
from minio import Minio
from pymilvus import MilvusClient
from app.core.config import *
from torchvision import transforms
from app.schemas.similar_match import SimilarMatchMItem
from app.service.utils.decorator import RunTime
class SimilarMatch:
    """Find outfit items visually similar to a query image.

    Pipeline: fetch the image from MinIO, preprocess it, extract a
    feature vector via a Triton-served model, then run a vector search
    against a Milvus collection.
    """

    def __init__(self, request_data):
        """
        Args:
            request_data: request object carrying ``image_path``
                (str, "<bucket>/<object-key>" inside MinIO) and
                ``result_number`` (int, max number of matches to return).
        """
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        self.triton_client = httpclient.InferenceServerClient(
            url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")
        self.image_path = request_data.image_path
        self.result_number = request_data.result_number
        # Inference runs eagerly at construction time so match_features()
        # can be called repeatedly without re-extracting features.
        self.features = self.get_features()

    @staticmethod
    def resize_image(img):
        """Resize/center-crop to 112x112 and normalize with ImageNet stats.

        Args:
            img: PIL RGB image (transforms.Resize expects a PIL image
                or tensor, not a bare ndarray).

        Returns:
            np.ndarray of shape (3, 112, 112), float32.
        """
        image_transforms = transforms.Compose([
            transforms.Resize(112),
            transforms.CenterCrop(112),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        return image_transforms(img).numpy()

    def load_image(self, img_path):
        """Fetch an object from MinIO and decode it as an RGB image.

        Args:
            img_path: "<bucket>/<object-key>" path.

        Returns:
            PIL.Image in RGB mode.
        """
        bucket, _, object_name = img_path.partition("/")
        # Fetch the object (image file) from MinIO.
        response = self.minio_client.get_object(bucket, object_name)
        try:
            # Decode into a PIL image; the bytes are copied into the
            # BytesIO buffer before the connection is released.
            pil_image = Image.open(io.BytesIO(response.data)).convert("RGB")
        finally:
            # Per the MinIO SDK contract, get_object() responses must be
            # closed and their connection released, otherwise the urllib3
            # connection pool leaks a connection per call.
            response.close()
            response.release_conn()
        return pil_image

    def preprocess(self, img_path):
        """Build the three model inputs for a single query image.

        Returns:
            Tuple (image, category, mask):
              image:    float array of shape (1, 1, 3, 112, 112)
                        (outer dims: batch, item-slot)
              category: int array of shape (1, 2); the ids (1, 6) are
                        hard-coded — TODO(review): confirm what these
                        category codes mean and whether they should vary.
              mask:     float32 zeros of shape (1, 1)
        """
        image = self.load_image(img_path)
        image = self.resize_image(image)
        image = np.stack([[image]], axis=0)
        category = np.stack([[1, 6]], axis=0)
        mask = np.zeros((1, 1), dtype=np.float32)
        return image, category, mask

    def get_features(self):
        """Run Triton inference and return the embedding for the image.

        Returns:
            np.ndarray feature matrix from "output__1"; shape (N, 64).
        """
        image, category, mask = self.preprocess(self.image_path)
        # Input set.
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32"),
            httpclient.InferInput("input__1", category.shape, datatype="INT16"),
            httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
        ]
        inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
        inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True)
        inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)
        # Output set.
        outputs = [
            httpclient.InferRequestedOutput("output__0", binary_data=True),
            httpclient.InferRequestedOutput("output__1", binary_data=True)
        ]
        # Run inference.
        results = self.triton_client.infer(
            model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
        # Extract the embedding output.
        features = results.as_numpy("output__1")  # Shape (N, 64)
        return features

    @RunTime
    def match_features(self):
        """Search Milvus for the nearest vectors to the extracted features.

        Returns:
            The raw pymilvus search response (list with one hit-list per
            query vector).
        """
        # Connect to Milvus.
        # TODO(review): the URI is hard-coded; move it to app.core.config
        # alongside the MinIO/Triton settings.
        client = MilvusClient(uri="http://10.1.1.240:19530", db_name="mixi")
        try:
            search_response = client.search(
                collection_name="mixi_outfit",
                # Query with the first (only) extracted feature vector.
                data=[self.features[0]],
                # Max number of search results to return.
                limit=self.result_number,
                output_fields=["id", "image_path"],
            )
            return search_response
        finally:
            # Always release the Milvus connection, even on search failure.
            client.close()
if __name__ == '__main__':
    # Manual smoke test: match one sample image and print the hit entities.
    request = SimilarMatchMItem(image_path="test/top/test_top1.jpg", result_number=10)
    matcher = SimilarMatch(request)
    hits = matcher.match_features()
    # One hit-list per query vector; we issued a single query.
    entities = [hit['entity'] for hit in hits[0]]
    pprint(entities)