attribute 字段名规范
This commit is contained in:
0
app/service/similar_match/__init__.py
Normal file
0
app/service/similar_match/__init__.py
Normal file
114
app/service/similar_match/service.py
Normal file
114
app/service/similar_match/service.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import io
|
||||
import json
|
||||
from pprint import pprint
|
||||
|
||||
import numpy as np
|
||||
import tritonclient.http as httpclient
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from torchvision import transforms
|
||||
|
||||
from app.schemas.similar_match import SimilarMatchMItem
|
||||
from app.service.utils.decorator import RunTime
|
||||
|
||||
|
||||
class SimilarMatch:
|
||||
def __init__(self, request_data):
|
||||
self.minio_client = Minio(
|
||||
f"{MINIO_IP}:{MINIO_PORT}",
|
||||
access_key=MINIO_ACCESS,
|
||||
secret_key=MINIO_SECRET,
|
||||
secure=MINIO_SECURE)
|
||||
self.triton_client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")
|
||||
self.image_path = request_data.image_path
|
||||
self.result_number = request_data.result_number
|
||||
self.features = self.get_features()
|
||||
|
||||
@staticmethod
|
||||
def resize_image(img):
|
||||
"""
|
||||
Args:
|
||||
img: ndarray (height, width, channel)
|
||||
"""
|
||||
image_transforms = transforms.Compose([
|
||||
transforms.Resize(112),
|
||||
transforms.CenterCrop(112),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
resized_img = image_transforms(img).numpy()
|
||||
return resized_img
|
||||
|
||||
def load_image(self, img_path):
|
||||
# 从 MinIO 中获取对象(图像文件)
|
||||
image_data = self.minio_client.get_object(img_path.split("/", 1)[0], img_path.split("/", 1)[1])
|
||||
# 读取图像数据并转换为 PIL 图像对象
|
||||
pil_image = Image.open(io.BytesIO(image_data.data)).convert("RGB")
|
||||
# 将 PIL 图像转换为 NumPy 数组
|
||||
# image_array = np.array(pil_image)
|
||||
return pil_image
|
||||
|
||||
def preprocess(self, img_path):
|
||||
image = self.load_image(img_path)
|
||||
image = self.resize_image(image)
|
||||
image = np.stack([[image]], axis=0)
|
||||
|
||||
category = np.stack([[1, 6]], axis=0)
|
||||
|
||||
mask = np.zeros((1, 1), dtype=np.float32)
|
||||
return image, category, mask
|
||||
|
||||
def get_features(self):
|
||||
image, category, mask = self.preprocess(self.image_path)
|
||||
# 输入集
|
||||
inputs = [
|
||||
httpclient.InferInput("input__0", image.shape, datatype="FP32"),
|
||||
httpclient.InferInput("input__1", category.shape, datatype="INT16"),
|
||||
httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
|
||||
]
|
||||
inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
|
||||
inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True)
|
||||
inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)
|
||||
# 输出集
|
||||
outputs = [
|
||||
httpclient.InferRequestedOutput("output__0", binary_data=True),
|
||||
httpclient.InferRequestedOutput("output__1", binary_data=True)
|
||||
]
|
||||
results = self.triton_client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
|
||||
# 推理
|
||||
# 取结果
|
||||
features = results.as_numpy("output__1") # Shape (N, 64)
|
||||
return features
|
||||
|
||||
@RunTime
|
||||
def match_features(self):
|
||||
# 连接milvus
|
||||
# 连接milvus
|
||||
client = MilvusClient(uri="http://10.1.1.240:19530", db_name="mixi")
|
||||
try:
|
||||
search_response = client.search(
|
||||
collection_name="mixi_outfit", # Replace with the actual name of your collection
|
||||
# Replace with your query vector
|
||||
data=[self.features[0]],
|
||||
limit=self.result_number, # Max. number of search results to return
|
||||
output_fields=["id", "image_path"], # Search parameters
|
||||
)
|
||||
return search_response
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = SimilarMatchMItem(image_path="test/top/test_top1.jpg", result_number=10)
|
||||
|
||||
service = SimilarMatch(request_data)
|
||||
search_response = service.match_features()
|
||||
response_data = []
|
||||
for response in search_response[0]:
|
||||
response_data.append(response['entity'])
|
||||
|
||||
pprint(response_data)
|
||||
Reference in New Issue
Block a user