import io
import json
import numpy as np
import tritonclient.http as httpclient
from PIL import Image
from minio import Minio
from pymilvus import MilvusClient
# NOTE(review): this wildcard import is what provides MINIO_IP, MINIO_PORT,
# MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, OM_TRITON_IP and OM_TRITON_PORT
# used below; consider importing those names explicitly.
from app.core.config import *
from torchvision import transforms


class SimilarMatch:
    """Extract image features with a Triton model and look up similar items in Milvus.

    Images are read from MinIO, preprocessed with a torchvision pipeline, sent
    to the ``outfit_matcher_type_aware`` Triton model (by default), and the
    resulting feature vector is used as the query for a Milvus vector search.
    """

    # Preprocessing pipeline: resize the shorter side to 112, center-crop to
    # 112x112, convert to a float tensor, then normalize with the standard
    # ImageNet channel mean/std. Built once instead of on every call.
    _IMAGE_TRANSFORMS = transforms.Compose([
        transforms.Resize(112),
        transforms.CenterCrop(112),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self):
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE,
        )
        self.triton_client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")

    @staticmethod
    def resize_image(img):
        """Apply the resize / crop / normalize pipeline to an image.

        Args:
            img: a PIL image (``load_image`` passes an RGB ``PIL.Image``).

        Returns:
            numpy.ndarray: the transformed image as produced by
            ``ToTensor`` + ``Normalize`` (channel-first, 112x112 spatially).
        """
        return SimilarMatch._IMAGE_TRANSFORMS(img).numpy()

    def load_image(self, img_path):
        """Download an image from MinIO and decode it as an RGB PIL image.

        Args:
            img_path: ``"<bucket>/<object name>"``; everything after the first
                ``/`` is the object name (it may itself contain ``/``).

        Returns:
            PIL.Image.Image: the decoded image converted to RGB.

        Raises:
            ValueError: if ``img_path`` does not contain a bucket and an
                object name separated by ``/``.
        """
        bucket, sep, object_name = img_path.partition("/")
        if not sep or not bucket or not object_name:
            raise ValueError(f"img_path must look like '<bucket>/<object>', got {img_path!r}")

        response = self.minio_client.get_object(bucket, object_name)
        try:
            payload = response.data
        finally:
            # The MinIO SDK requires the caller to close the response and
            # return the connection to the pool; otherwise it leaks.
            response.close()
            response.release_conn()

        # convert() forces a full decode, so the in-memory buffer may be
        # discarded afterwards.
        return Image.open(io.BytesIO(payload)).convert("RGB")

    def preprocess(self, img_path):
        """Build the three model inputs for a single image.

        Returns:
            tuple: ``(image, category, mask)`` where
            - ``image`` is the transformed image with two leading singleton
              axes added (presumably (batch, item) — confirm against the
              Triton model config);
            - ``category`` is ``[[1, 6]]`` (shape (1, 2));
            - ``mask`` is a float32 zeros array of shape (1, 1).
        """
        image = self.resize_image(self.load_image(img_path))
        image = np.stack([[image]], axis=0)
        # NOTE(review): fixed category ids [1, 6]; their meaning is defined by
        # the model, not by this file.
        category = np.stack([[1, 6]], axis=0)
        mask = np.zeros((1, 1), dtype=np.float32)
        return image, category, mask

    def get_features(self, img_path, *, model_name="outfit_matcher_type_aware"):
        """Run the Triton model on one image and return its ``output__1`` tensor.

        Args:
            img_path: ``"<bucket>/<object name>"`` in MinIO.
            model_name: Triton model to query.

        Returns:
            numpy.ndarray: the ``output__1`` tensor (the original author noted
            shape (N, 64) — confirm against the model config).

        Raises:
            RuntimeError: if the inference response contains no ``output__1``.
        """
        image, category, mask = self.preprocess(img_path)

        # Input tensors; names and datatypes must match the model's config.
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32"),
            httpclient.InferInput("input__1", category.shape, datatype="INT16"),
            httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
        ]
        inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
        inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True)
        inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)

        # Requested outputs; only output__1 is used below.
        outputs = [
            httpclient.InferRequestedOutput("output__0", binary_data=True),
            httpclient.InferRequestedOutput("output__1", binary_data=True),
        ]

        results = self.triton_client.infer(model_name=model_name, inputs=inputs, outputs=outputs)

        features = results.as_numpy("output__1")
        if features is None:
            # as_numpy returns None when the output is absent; fail loudly
            # instead of handing None to match_features.
            raise RuntimeError(f"Triton model {model_name!r} returned no 'output__1'")
        return features

    def match_features(
        self,
        features,
        *,
        limit=5,
        uri="http://10.1.1.240:19530",
        db_name="mixi",
        collection_name="mixi_outfit",
    ):
        """Search Milvus for the items closest to the first feature vector.

        Args:
            features: features as returned by ``get_features``; only
                ``features[0]`` is used as the query vector.
            limit: maximum number of hits to return.
            uri: Milvus server URI.
            db_name: Milvus database name.
            collection_name: collection to search.

        Returns:
            The ``MilvusClient.search`` result for the single query, with the
            ``id`` and ``image_path`` fields included in each hit.
        """
        client = MilvusClient(uri=uri, db_name=db_name)
        try:
            return client.search(
                collection_name=collection_name,
                data=[features[0]],
                limit=limit,
                output_fields=["id", "image_path"],
            )
        finally:
            # Always release the connection, including when search raises.
            client.close()


if __name__ == '__main__':
    service = SimilarMatch()
    features = service.get_features(img_path="test/2024 SS/MKTS27000.jpg")
    res = service.match_features(features)
    print(json.dumps(res, indent=4))