import io
import time
from pprint import pprint

import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import settings
from app.schemas.clothing_seg import ClothingSegModel
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image

minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)

# Bucket that receives every generated cut-out.
_UPLOAD_BUCKET = "aida-users"


class ClothingSeg:
    """Cut clothing out of stored images and upload the RGBA results.

    Each entry of ``image_data`` is a dict carrying at least ``"image_url"``
    (``"<bucket>/<object name>"``) and ``"image_type"``.  After
    :meth:`get_result` every entry additionally holds ``"clothing_url"``, a
    list of ``"<bucket>/<object name>"`` strings for the uploaded cut-outs.
    """

    def __init__(self, request_data):
        """Store the request payload and open the Triton gRPC client.

        :param request_data: object exposing ``image_data`` (list of dicts)
            and ``user_id`` (used as the upload path prefix).
        """
        self.image_data = request_data.image_data
        self.user_id = request_data.user_id
        self.triton_client = grpcclient.InferenceServerClient(
            url=f"{settings.B_4_X_4090_SERVICE_HOST}:10071"
        )

    @RunTime
    def get_result(self):
        """Run download, segmentation and upload, then return ``image_data``.

        The intermediate ``"image"`` and ``"clothing"`` entries are removed
        from every dict before returning, so only the original request keys
        and ``"clothing_url"`` remain.

        :return: the (mutated) ``image_data`` list.
        """
        self.read_image()
        self.clothing_seg()
        self.upload_image()
        for data in self.image_data:
            del data["image"]
            del data["clothing"]
        return self.image_data

    @RunTime
    def upload_image(self):
        """PNG-encode every cut-out and upload it to the users bucket.

        Fills ``data["clothing_url"]`` with one ``"aida-users/<object>"``
        entry per image in ``data["clothing"]``.
        """
        for data in self.image_data:
            data["clothing_url"] = []
            for clothing in data["clothing"]:
                object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
                buffer = io.BytesIO()
                clothing.save(buffer, format="PNG")
                oss_upload_image(
                    oss_client=minio_client,
                    bucket=_UPLOAD_BUCKET,
                    object_name=object_name,
                    # getvalue() returns the whole buffer regardless of the
                    # current stream position, so no seek/read is needed.
                    image_bytes=buffer.getvalue(),
                )
                data["clothing_url"].append(f"{_UPLOAD_BUCKET}/{object_name}")

    @RunTime
    def read_image(self):
        """Download every ``"image_url"`` and store the result as ``"image"``.

        The URL is split at its first ``/`` into bucket and object name.

        :raises IndexError: if an ``"image_url"`` contains no ``/``.
        """
        for data in self.image_data:
            parts = data["image_url"].split("/", 1)
            data["image"] = oss_get_image(
                oss_client=minio_client,
                bucket=parts[0],
                object_name=parts[1],
                data_type="cv2",
            )

    @RunTime
    def clothing_seg(self):
        """Segment every loaded image and store the cut-outs as ``"clothing"``.

        Images whose ``"image_type"`` is ``"sketch"`` are masked with
        ``get_seg_result``; every other type is sent to the Triton model
        ``seg_clothing``.
        """
        for data in self.image_data:
            image = data["image"]
            if len(image.shape) == 2:
                # Single-channel input: promote to three channels so both
                # segmentation paths (and the later BGR->RGB conversion) can
                # treat the image uniformly.
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

            if data["image_type"] == "sketch":
                data["clothing"] = self._segment_sketch(image)
            else:
                data["clothing"] = self._segment_with_triton(image)

    def _segment_sketch(self, image):
        """Return a one-element list with the cut-out of a sketch image."""
        seg_mask = get_seg_result(image[:, :, :3])
        # Binarise: every non-zero label becomes fully opaque (255).
        mask = 255 * (seg_mask != 0.0).astype(np.uint8)
        return [_cut_out_clothing(image, mask)]

    def _segment_with_triton(self, image):
        """Return one cut-out per mask produced by the Triton model."""
        input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Add a leading axis of size 1 (a single image per request).
        input0_data = np.expand_dims(input_image.astype(np.float32), axis=0)

        inputs = [
            grpcclient.InferInput(
                "INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
            ),
        ]
        inputs[0].set_data_from_numpy(input0_data)
        # Only OUTPUT1 is consumed below; OUTPUT0 is deliberately not requested.
        outputs = [grpcclient.InferRequestedOutput("OUTPUT1")]

        response = self.triton_client.infer(
            "seg_clothing", inputs, request_id=str(1), outputs=outputs
        )
        output1_data = response.as_numpy("OUTPUT1")

        clothing_result = []
        # Each item of OUTPUT1 is used as the mask of one cut-out
        # (presumably one matte per garment -- confirm against the model config).
        for alpha in output1_data:
            # Bring the mask back to the source resolution before cropping.
            alpha = cv2.resize(
                alpha,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_CUBIC,
            )
            clothing_result.append(_cut_out_clothing(image, alpha))
        return clothing_result


@RunTime
def get_bounding_box(mask):
    """Return the bounding box of the non-zero region of a mask.

    :param mask: two-dimensional numpy array; every element different from 0
        counts as foreground.
    :return: ``(x_min, y_min, x_max, y_max)`` with inclusive bounds, or
        ``(0, 0, 0, 0)`` when the mask has no non-zero element.
    """
    # Coordinates of every non-zero pixel.
    rows, cols = np.where(mask != 0)
    if rows.size == 0:
        # Nothing found: fall back to the all-zero box.
        return 0, 0, 0, 0

    x_min = np.min(cols)
    y_min = np.min(rows)
    x_max = np.max(cols)
    y_max = np.max(rows)
    return x_min, y_min, x_max, y_max


def _cut_out_clothing(image, mask):
    """Crop ``image`` to the mask's bounding box and apply the mask as alpha.

    :param image: image array in OpenCV channel order (BGR).
    :param mask: two-dimensional array aligned with ``image``; it is converted
        to an 8-bit ``"L"`` image and used as the paste mask.
    :return: RGBA ``PIL.Image`` of the cropped region on a fully transparent
        background.
    """
    x_min, y_min, x_max, y_max = get_bounding_box(mask)
    # Bounds are inclusive, hence the +1 on the slice ends.
    cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
    cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
    height, width = cropped_image.shape[:2]

    mask_pil = Image.fromarray(cropped_mask).convert("L")
    image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))

    transparent_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
    return transparent_image


if __name__ == "__main__":
    test_data = ClothingSegModel(
        user_id="89",
        image_data=[
            # {
            #     "image_url": "test/clothing_seg/dress.jpg",
            #     "image_type": "sketch"
            # },
            # {
            #     "image_url": "test/clothing_seg/skirt_559.jpg",
            #     "image_type": "sketch"
            # },
            {
                "image_url": "aida-collection-element/87/Sketchboard/ab40e035-547a-48c5-9f97-1db7bf56ad77.jpg",
                "image_type": "sketch"
            }
        ]
    )
    start_time = time.time()
    server = ClothingSeg(test_data)
    pprint(server.get_result())
    print(time.time() - start_time)