From 39dae92ea05137e33a2124c42033801a921c8d5c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 28 Mar 2024 14:14:35 +0800 Subject: [PATCH] =?UTF-8?q?=E6=90=AD=E9=85=8D=E6=9C=8D=E5=8A=A1=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../outfit_matcher/outfit_evaluator.py | 287 +++++++++++------- 1 file changed, 176 insertions(+), 111 deletions(-) diff --git a/app/service/outfit_matcher/outfit_evaluator.py b/app/service/outfit_matcher/outfit_evaluator.py index 0392414..564c350 100644 --- a/app/service/outfit_matcher/outfit_evaluator.py +++ b/app/service/outfit_matcher/outfit_evaluator.py @@ -13,6 +13,105 @@ from app.service.outfit_matcher.foco import extract_main_colors from app.service.utils.decorator import RunTime +class Backbone(object): + def __init__(self): + self.tritonclient = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}") + self.minio_client = Minio( + f"{MINIO_IP}:{MINIO_PORT}", + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + + @RunTime + # TODO 用多线程读图片 + def load_image(self, img_path): + try: + # 从 MinIO 中获取对象(图像文件) + image_data = self.minio_client.get_object(img_path.split("/", 1)[0], img_path.split("/", 1)[1]) + + # 读取图像数据并转换为 PIL 图像对象 + pil_image = Image.open(io.BytesIO(image_data.data)).convert("RGB") + + # 将 PIL 图像转换为 NumPy 数组 + # image_array = np.array(pil_image) + + return pil_image + except Exception as e: + print(f"An error occurred: {e}") + return None + + @staticmethod + def resize_image(img): + """ + Args: + img: ndarray (height, width, channel) + """ + image_transforms = transforms.Compose([ + transforms.Resize(112), + transforms.CenterCrop(112), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + resized_img = image_transforms(img).numpy() + return resized_img + + def preprocess(self, items): + images = [] + for item in items: + image = self.load_image(item["image_path"]) + image = self.resize_image(image) + images.append(image) + images = np.stack(images, axis=0) + return images + + @RunTime + def get_result(self, items): + """Input items and output features for similiarity. + Args: + items: images of fashion items + Example: + [ + { + "item_name": "MSE_57987", + "semantic_category": "BOTTOM/PANTS", + "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MSE_57987.jpg", + "mapped_cate": "bottoms" + }, + { + "item_name": "MPO_SP7712", + "semantic_category": "TOP/TANK", + "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MPO_SP7712.jpg", + "mapped_cate": "tops" + }, + { + "item_name": "MWSS27195", + "semantic_category": "OUTERWEAR/GILET", + "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MWSS27195.jpg", + "mapped_cate": "outerwear" + } + ] + Returns: + scores: List of image features + """ + image = self.preprocess(items) + client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}") + # 输入集 + inputs = [ + httpclient.InferInput("input__0", image.shape, datatype="FP32"), + ] + inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True) + # 输出集 + outputs = [ + httpclient.InferRequestedOutput("output__0", binary_data=True), + ] + results = client.infer(model_name="outfit_matcher_backbone", inputs=inputs, outputs=outputs) + # 推理 + # 取结果 + features = results.as_numpy("output__0") # Shape (N, 64) + return features + + class OutfitMatcher(object): def __init__(self): self.tritonclient = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}") @@ -22,6 +121,22 @@ class OutfitMatcher(object): secret_key=MINIO_SECRET, secure=MINIO_SECURE) + def load_image(self, img_path): + try: + # 从 MinIO 中获取对象(图像文件) + image_data = self.minio_client.get_object(img_path.split("/", 1)[0], img_path.split("/", 1)[1]) + + # 读取图像数据并转换为 PIL 图像对象 + pil_image = Image.open(io.BytesIO(image_data.data)).convert("RGB") + + # 将 PIL 图像转换为 NumPy 数组 + # image_array = np.array(pil_image) + + return pil_image + except Exception as e: + print(f"An error occurred: {e}") + return None + @staticmethod def pad_array(input_value, value=0): """pad List of Array into same batch size @@ -77,51 +192,48 @@ class OutfitMatcher(object): @RunTime def visualize(self, outfits, scores, topk=5, best=True, output_path=None): # 将outfits和scores按照scores的值进行排序 - sorted_indices = np.argsort(-scores.flatten() if best else scores.flatten())[:topk] # 使用负号进行降序排序 + # sorted_indices = np.argsort(-scores if best else scores)[:topk] # for HON + sorted_indices = np.argsort(scores if best else -scores)[:topk] # type-aware outfits = [outfits[i] for i in sorted_indices] # 最好或最差的五个 scores = scores[sorted_indices] # 这五个的分数 - # 是否画出来 + # 设置子图的行列数 + num_rows = len(outfits) + num_cols = max([len(x) for x in outfits]) + 1 # 一个是图片,一个是分数 + + # 创建一个新的图像,并指定子图的行列数 + fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 15)) + + title = f"Best {topk} Outfits" if best else f"Worst {topk} Outfits" + fig.suptitle(title, fontsize=16) + + # 遍历每套outfit并将其显示在对应的子图中 + for i, (outfit, score) in enumerate(zip(outfits, scores)): + # 显示分数 + axes[i, 0].text(0.1, 0.5, f"Score: {score:.4f}", fontsize=12) + axes[i, 0].axis("off") + # 显示图片 + for j, item in enumerate(outfit): + img = self.load_image(item['image_path']) # 读取图片 + axes[i, j + 1].imshow(img) # 在对应的子图中显示图片 + axes[i, j + 1].axis('off') # 关闭坐标轴 + axes[i, j + 1].set_title(item["semantic_category"], fontsize=10) + for j in range(len(outfit), num_cols): + axes[i, j].axis("off") + + # 在每一行的底部添加一条横线 + axes[i, 0].axhline(y=0, color='black', linewidth=1) + # 隐藏最后一行的横线 + axes[-1, 0].axhline(y=0, color='white', linewidth=1) + + # 调整布局 + plt.subplots_adjust(wspace=0.1, hspace=0.1) + plt.tight_layout() + if output_path: - # 设置子图的行列数 - num_rows = len(outfits) - num_cols = max([len(x) for x in outfits]) + 1 # 一个是图片,一个是分数 - - # 创建一个新的图像,并指定子图的行列数 - fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 15)) - - title = f"Best {topk} Outfits" if best else f"Worst {topk} Outfits" - fig.suptitle(title, fontsize=16) - - # 遍历每套outfit并将其显示在对应的子图中 - for i, (outfit, score) in enumerate(zip(outfits, scores)): - # 显示分数 - axes[i, 0].text(0.1, 0.5, f"Score: {score[0]:.4f}", fontsize=12) - axes[i, 0].axis("off") - # 显示图片 - for j, item in enumerate(outfit): - img = mpimg.imread(item['image_path']) # 读取图片 - axes[i, j + 1].imshow(img) # 在对应的子图中显示图片 - axes[i, j + 1].axis('off') # 关闭坐标轴 - axes[i, j + 1].set_title(item["semantic_category"], fontsize=10) - for j in range(len(outfit), num_cols): - axes[i, j].axis("off") - - # 在每一行的底部添加一条横线 - axes[i, 0].axhline(y=0, color='black', linewidth=1) - # 隐藏最后一行的横线 - axes[-1, 0].axhline(y=0, color='white', linewidth=1) - - # 调整布局 - plt.subplots_adjust(wspace=0.1, hspace=0.1) - plt.tight_layout() - - if output_path: - plt.savefig(output_path) - else: - plt.show() + plt.savefig(output_path) else: - return outfits, scores.numpy().flatten().tolist() + plt.show() class OutfitMatcherHon(OutfitMatcher): @@ -216,60 +328,17 @@ class OutfitMaterTypeAware(OutfitMatcher): 'outerwear', 'scarves', 'shoes', 'sunglasses', 'tops' ] - @RunTime def __init__(self): super().__init__() - @RunTime - # TODO 用多线程读图片 - def load_image(self, img_path): - try: - # 从 MinIO 中获取对象(图像文件) - image_data = self.minio_client.get_object(img_path.split("/", 1)[0], img_path.split("/", 1)[1]) - - # 读取图像数据并转换为 PIL 图像对象 - pil_image = Image.open(io.BytesIO(image_data.data)).convert("RGB") - - # 将 PIL 图像转换为 NumPy 数组 - # image_array = np.array(pil_image) - - return pil_image - except Exception as e: - print(f"An error occurred: {e}") - return None - # if 'http' in img_path: - # file = requests.get(img_path) - # image = cv2.imdecode(np.fromstring(file.content, np.uint8), 1) - # image = Image.fromarray(image.astype('uint8'), 'RGB') - # else: - # image = Image.open(img_path).convert('RGB') - # return np.array(image) - - @staticmethod - def resize_image(img): - """ - Args: - img: ndarray (height, width, channel) - """ - image_transforms = transforms.Compose([ - transforms.Resize(112), - transforms.CenterCrop(112), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) - resized_img = image_transforms(img).numpy() - return resized_img - - def preprocess(self, outfits): + def preprocess(self, outfits, features): outfit_images = [] outfit_categories = [] for outfit in outfits: images = [] categories = [] for item in outfit: - image = self.load_image(item["image_path"]) - image = self.resize_image(image) + image = features[item["item_name"]] images.append(image) category = self.base_fashion_categories.index(item["mapped_cate"]) categories.append(category) @@ -277,12 +346,10 @@ class OutfitMaterTypeAware(OutfitMatcher): outfit_images.append(images) # List[(items, 3, 224, 224)] categories = np.array(categories) outfit_categories.append(categories) # List[(items)] - outfit_images, mask = self.pad_array(outfit_images, value=0) - outfit_categories, _ = self.pad_array(outfit_categories, value=len(self.base_fashion_categories)) - return outfit_images, outfit_categories, mask + return outfit_images, outfit_categories @RunTime - def get_result(self, outfits): + def get_result(self, outfits, features): """Input outfits structure and output scores. Args: outfits: outfits to be evaluated. @@ -310,29 +377,27 @@ class OutfitMaterTypeAware(OutfitMatcher): ], ... ] + features: dict(item_name = np.Array) image features of those items Returns: scores: List of float """ - image, category, mask = self.preprocess(outfits) - client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}") - # 输入集 - inputs = [ - httpclient.InferInput("input__0", image.shape, datatype="FP32"), - httpclient.InferInput("input__1", category.shape, datatype="INT16"), - httpclient.InferInput("input__2", mask.shape, datatype="FP32"), - ] - inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True) - inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True) - inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True) - # 输出集 - outputs = [ - httpclient.InferRequestedOutput("output__0", binary_data=True), - httpclient.InferRequestedOutput("output__1", binary_data=True) - ] - results = client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs) - # 推理 - # 取结果 - scores = torch.from_numpy(results.as_numpy("output__0")) # Shape (N, 1) - features = torch.from_numpy(results.as_numpy("output__1")) # Shape (N, 64) + outfit_images, outfit_categories = self.preprocess(outfits, features) + scores = [] + for images, categories in zip(outfit_images, outfit_categories): + client = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}") + # 输入集 + inputs = [ + httpclient.InferInput("input__0", images.shape, datatype="FP32"), + httpclient.InferInput("input__1", categories.shape, datatype="INT16") + ] + inputs[0].set_data_from_numpy(images.astype(np.float32), binary_data=True) + inputs[1].set_data_from_numpy(categories.astype(np.int16), binary_data=True) + # 输出集 + outputs = [ + httpclient.InferRequestedOutput("output__0", binary_data=True), + ] + results = client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs) + scores.append(results.as_numpy("output__0")) # Shape (N, 1) - return scores, features + scores = np.stack(scores, axis=0) + return scores.flatten()