import io
import json
from pprint import pprint

import numpy as np
import tritonclient.http as httpclient
from PIL import Image
from minio import Minio
from pymilvus import MilvusClient
# Star import presumably supplies MINIO_*, OM_TRITON_* and MILVUS_URL used below.
# Kept in its original position so name shadowing is unchanged.
from app.core.config import *
from torchvision import transforms

from app.schemas.similar_match import SimilarMatchMItem
from app.service.utils.decorator import RunTime


class SimilarMatch:
    """Embed one image stored in MinIO via Triton and look up similar items in Milvus.

    The feature vector is computed eagerly in ``__init__`` (one MinIO download
    plus one Triton inference), and ``match_features`` then queries Milvus.
    """

    # Built once per class instead of on every resize_image() call.
    # Resize(112) scales the shorter side to 112, CenterCrop(112) yields 112x112,
    # ToTensor produces a float CHW tensor, and Normalize applies the commonly
    # used ImageNet per-channel mean/std.
    _IMAGE_TRANSFORMS = transforms.Compose([
        transforms.Resize(112),
        transforms.CenterCrop(112),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, request_data):
        """Create service clients and compute the query image's features.

        Args:
            request_data: object exposing ``image_path`` ("<bucket>/<object>")
                and ``result_number`` (max hits to return), e.g. a
                ``SimilarMatchMItem``.
        """
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE,
        )
        self.triton_client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")
        self.image_path = request_data.image_path
        self.result_number = request_data.result_number
        # Network-bound: downloads the image and runs inference right away.
        self.features = self.get_features()

    @staticmethod
    def resize_image(img):
        """Apply the model's preprocessing pipeline to an image.

        Args:
            img: PIL image (as returned by ``load_image``). A raw ndarray is not
                accepted by ``transforms.Resize``.

        Returns:
            numpy float32 array of shape (C, 112, 112); C is 3 for the RGB
            images produced by ``load_image``.
        """
        return SimilarMatch._IMAGE_TRANSFORMS(img).numpy()

    def load_image(self, img_path):
        """Download an object from MinIO and decode it as an RGB PIL image.

        Args:
            img_path: "<bucket>/<object_name>"; everything after the first "/"
                is the object name.

        Returns:
            PIL image converted to RGB mode.

        Raises:
            ValueError: if ``img_path`` does not contain both a bucket and an
                object name separated by "/".
        """
        bucket_name, _, object_name = img_path.partition("/")
        if not bucket_name or not object_name:
            raise ValueError(f"Expected '<bucket>/<object_name>', got {img_path!r}")

        response = self.minio_client.get_object(bucket_name, object_name)
        try:
            payload = response.data
        finally:
            # MinIO hands back a raw HTTP response; the caller must close it and
            # return the connection to the pool, otherwise connections leak.
            response.close()
            response.release_conn()

        return Image.open(io.BytesIO(payload)).convert("RGB")

    def preprocess(self, img_path):
        """Load an image and return a batched model input of shape (1, C, 112, 112)."""
        image = self.load_image(img_path)
        tensor = self.resize_image(image)
        # Add the leading batch dimension expected by the model input.
        return tensor[np.newaxis, ...]

    def get_features(self):
        """Run the Triton backbone on ``self.image_path`` and return its output.

        Returns:
            numpy array from output tensor "output__0"; per the original author
            its shape is (N, 64) — TODO confirm against the model config.
        """
        batch = self.preprocess(self.image_path).astype(np.float32)

        infer_input = httpclient.InferInput("input__0", list(batch.shape), datatype="FP32")
        infer_input.set_data_from_numpy(batch, binary_data=True)
        requested_output = httpclient.InferRequestedOutput("output__0", binary_data=True)

        results = self.triton_client.infer(
            model_name="outfit_matcher_backbone",
            inputs=[infer_input],
            outputs=[requested_output],
        )
        return results.as_numpy("output__0")

    @RunTime
    def match_features(self):
        """Search the "mixi_outfit" Milvus collection with the first feature row.

        Returns:
            The raw ``MilvusClient.search`` result: one hit list per query
            vector (a single query is sent here), each hit carrying the
            "item_name" and "image_path" output fields.
        """
        client = MilvusClient(uri=MILVUS_URL, db_name="mixi")
        try:
            return client.search(
                collection_name="mixi_outfit",
                # Send plain Python floats rather than a float32 ndarray row so the
                # request does not depend on pymilvus' numpy handling.
                data=[self.features[0].tolist()],
                limit=self.result_number,  # maximum number of hits to return
                output_fields=["item_name", "image_path"],
            )
        finally:
            client.close()


if __name__ == '__main__':
    request_data = SimilarMatchMItem(image_path="test/top/test_top1.jpg", result_number=10)
    service = SimilarMatch(request_data)
    search_response = service.match_features()
    # Only one query vector was sent, so only the first hit list is relevant.
    response_data = [hit['entity'] for hit in search_response[0]]
    pprint(response_data)