2024-03-28 10:30:18 +08:00
|
|
|
import io
|
|
|
|
|
import json
|
|
|
|
|
from pprint import pprint
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import tritonclient.http as httpclient
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from minio import Minio
|
|
|
|
|
from pymilvus import MilvusClient
|
|
|
|
|
|
|
|
|
|
from app.core.config import *
|
|
|
|
|
from torchvision import transforms
|
|
|
|
|
|
|
|
|
|
from app.schemas.similar_match import SimilarMatchMItem
|
|
|
|
|
from app.service.utils.decorator import RunTime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimilarMatch:
    """Find catalogue items that look similar to a query image.

    On construction the query image is downloaded from MinIO, preprocessed,
    and embedded by the ``outfit_matcher_backbone`` model served by Triton.
    The resulting feature array is stored on ``self.features``.
    ``match_features`` then searches the ``mixi_outfit`` collection in Milvus
    for the nearest items.
    """

    # Preprocessing pipeline for every query image. The transforms hold no
    # per-call state, so the pipeline is built once instead of on every call.
    _IMAGE_TRANSFORMS = transforms.Compose([
        transforms.Resize(112),
        transforms.CenterCrop(112),
        transforms.ToTensor(),
        # These mean/std values are the commonly used ImageNet statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, request_data):
        """Create the MinIO and Triton clients and embed the query image.

        Args:
            request_data: request object exposing ``image_path`` (a
                ``"<bucket>/<object name>"`` path in MinIO) and
                ``result_number`` (maximum number of matches to return).
        """
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        self.triton_client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")
        self.image_path = request_data.image_path
        self.result_number = request_data.result_number
        # Inference runs eagerly here, so constructing the object performs
        # network I/O against both MinIO and Triton.
        self.features = self.get_features()

    @staticmethod
    def resize_image(img):
        """Resize, center-crop, and normalise an image for the backbone model.

        Args:
            img: PIL image (as returned by ``load_image``). torchvision's
                ``Resize`` also accepts tensors, but not numpy arrays.

        Returns:
            numpy.ndarray of shape (C, 112, 112), float32, normalised with the
            pipeline's mean/std.
        """
        return SimilarMatch._IMAGE_TRANSFORMS(img).numpy()

    def load_image(self, img_path):
        """Download an image from MinIO and decode it as an RGB PIL image.

        Args:
            img_path: ``"<bucket>/<object name>"``; everything before the first
                ``/`` is the bucket name, the rest is the object name.

        Returns:
            PIL.Image.Image in RGB mode.

        Raises:
            ValueError: if ``img_path`` lacks a bucket or an object name.
        """
        bucket_name, _, object_name = img_path.partition("/")
        if not bucket_name or not object_name:
            raise ValueError(
                f"image path must look like '<bucket>/<object>', got {img_path!r}")

        # Fetch the object (the image file) from MinIO.
        response = self.minio_client.get_object(bucket_name, object_name)
        try:
            payload = response.data
        finally:
            # Return the HTTP connection to the pool; otherwise every call
            # leaks a connection.
            response.close()
            response.release_conn()

        # Decode the bytes; convert("RGB") forces a full load and 3 channels.
        return Image.open(io.BytesIO(payload)).convert("RGB")

    def preprocess(self, img_path):
        """Load an image from MinIO and turn it into a one-image batch.

        Args:
            img_path: ``"<bucket>/<object name>"`` path in MinIO.

        Returns:
            numpy.ndarray of shape (1, C, 112, 112).
        """
        image = self.resize_image(self.load_image(img_path))
        # Prepend a batch axis of size 1.
        return np.expand_dims(image, axis=0)

    def get_features(self):
        """Run the query image through the Triton backbone model.

        Returns:
            numpy.ndarray read from the model's ``output__0`` tensor. The
            original author noted its shape as (N, 64) — TODO confirm against
            the model config.
        """
        image = self.preprocess(self.image_path)

        # Input tensors.
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32"),
        ]
        inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)

        # Requested output tensors.
        outputs = [
            httpclient.InferRequestedOutput("output__0", binary_data=True),
        ]

        # Run inference and extract the result.
        results = self.triton_client.infer(model_name="outfit_matcher_backbone", inputs=inputs, outputs=outputs)
        return results.as_numpy("output__0")

    @RunTime
    def match_features(self):
        """Search Milvus for the items closest to the query image's features.

        Only the first row of ``self.features`` is used as the query vector.

        Returns:
            The raw ``MilvusClient.search`` response. It has one entry per
            query vector (one here), each holding up to ``self.result_number``
            hits that carry the ``item_name`` and ``image_path`` fields.
        """
        # Connect to Milvus; the client is closed whether or not search succeeds.
        client = MilvusClient(uri=MILVUS_URL, db_name="mixi")
        try:
            return client.search(
                collection_name="mixi_outfit",
                data=[self.features[0]],
                limit=self.result_number,  # maximum number of hits to return
                output_fields=["item_name", "image_path"],  # fields returned per hit
            )
        finally:
            client.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: embed a sample image and print its closest matches.
    demo_request = SimilarMatchMItem(image_path="test/top/test_top1.jpg", result_number=10)
    matcher = SimilarMatch(demo_request)
    hits_per_query = matcher.match_features()

    # A single query vector was sent, so its hits live in the first entry;
    # keep only the entity payload of each hit.
    matched_entities = [hit['entity'] for hit in hits_per_query[0]]
    pprint(matched_entities)
|