diff --git a/app/api/api_clothing_seg.py b/app/api/api_clothing_seg.py new file mode 100644 index 0000000..e09b882 --- /dev/null +++ b/app/api/api_clothing_seg.py @@ -0,0 +1,51 @@ +import json +import logging + +from fastapi import APIRouter, HTTPException + +from app.schemas.response_template import ResponseModel +from app.schemas.clothing_seg import ClothingSegModel +from app.service.clothing_seg.service import ClothingSeg + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/clothing_seg") +def clothing_seg(request_item: ClothingSegModel): + """ + 创建一个具有以下参数的请求体: + - **user_id**: 用户id + - **image_data**: 图片数据 + { + "image_url": "test/clothing_seg/dress.jpg", + "image_type": "product" + } + + 示例参数: + { + "user_id": 89, + "image_data": [ + { + "image_url": "test/clothing_seg/dress.jpg", + "image_type": "sketch" + }, + { + "image_url": "test/clothing_seg/skirt_559.jpg", + "image_type": "sketch" + }, + { + "image_url": "test/clothing_seg/10144613.jpg", + "image_type": "product" + } + ] + } + """ + try: + logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}") + server = ClothingSeg(request_item) + result_url = server.get_result() + except Exception as e: + logger.warning(f"clothing_seg Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=result_url) diff --git a/app/api/api_route.py b/app/api/api_route.py index 9858ba6..47a4caf 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from app.api import api_agent_generate_image, api_recommendation +from app.api import api_agent_generate_image from app.api import api_attribute_retrieve, api_query_image from app.api import api_brand_dna from app.api import api_brighten @@ -12,6 +12,7 @@ from app.api import api_image2sketch from app.api import api_mannequins_edit from app.api import api_pose_transform from app.api import api_prompt_generation +from app.api import api_clothing_seg from app.api import api_super_resolution from app.api import api_test @@ -29,7 +30,8 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") -router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api") +# router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api") router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api") router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api") router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api") +router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api") diff --git a/app/schemas/clothing_seg.py b/app/schemas/clothing_seg.py new file mode 100644 index 0000000..234402c --- /dev/null +++ b/app/schemas/clothing_seg.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class ClothingSegModel(BaseModel): + user_id: str + image_data: list[dict] diff --git a/app/service/clothing_seg/service.py b/app/service/clothing_seg/service.py new file mode 100644 index 0000000..a0f3640 --- /dev/null +++ b/app/service/clothing_seg/service.py @@ -0,0 +1,156 @@ +import io +import time +from pprint import pprint + +import cv2 +import numpy as np +import tritonclient.grpc as grpcclient +from PIL import Image +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.clothing_seg import ClothingSegModel +from app.service.design_fast.utils.design_ensemble import get_seg_result +from app.service.utils.decorator import RunTime +from app.service.utils.generate_uuid import generate_uuid +from app.service.utils.new_oss_client import oss_get_image, oss_upload_image + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +class ClothingSeg: + def __init__(self, request_data): + self.image_data = request_data.image_data + self.user_id = request_data.user_id + self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071") + + @RunTime + def get_result(self): + self.read_image() + self.clothing_seg() + self.upload_image() + for data in self.image_data: + del data["image"] + del data["clothing"] + + return self.image_data + + @RunTime + def upload_image(self): + for data in self.image_data: + data["clothing_url"] = [] + for clothing in data["clothing"]: + object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png" + image_data = io.BytesIO() + clothing.save(image_data, format="PNG") + image_data.seek(0) + image_bytes = image_data.read() + oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes) + data["clothing_url"].append(f"aida-users/{object_name}") + + @RunTime + def read_image(self): + for data in self.image_data: + url = data["image_url"] + image = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2") + data["image"] = image + + @RunTime + def clothing_seg(self): + for data in self.image_data: + image_type = data["image_type"] + image = data["image"] + clothing_result = [] + if image_type == "sketch": + seg_mask = get_seg_result(1, image) + temp = seg_mask != 0.0 + mask = (255 * (temp + 0).astype(np.uint8)) + x_min, y_min, x_max, y_max = get_bounding_box(mask) + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + h, w = cropped_image.shape[:2] + mask_pil = Image.fromarray(cropped_mask).convert("L") + image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)) + transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0)) + transparent_image.paste(image_pil, (0, 0), mask=mask_pil) + clothing_result.append(transparent_image) + else: + input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + input0_data = [input_image.astype(np.float32)] * 1 + input0_data = np.array(input0_data, dtype=np.float32) + inputs = [ + grpcclient.InferInput( + "INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype) + ), + ] + + inputs[0].set_data_from_numpy(input0_data) + + outputs = [ + grpcclient.InferRequestedOutput("OUTPUT0"), + grpcclient.InferRequestedOutput("OUTPUT1"), + ] + response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs) + output0_data = response.as_numpy("OUTPUT0") + cv2.imwrite("output02.png", output0_data * 100) + output1_data = response.as_numpy("OUTPUT1") + for alpha in output1_data: + x_min, y_min, x_max, y_max = get_bounding_box(alpha) + cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + h, w = cropped_image.shape[:2] + mask_pil = Image.fromarray(cropped_mask).convert("L") + image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)) + transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0)) + transparent_image.paste(image_pil, (0, 0), mask=mask_pil) + clothing_result.append(transparent_image) + data["clothing"] = clothing_result + + +@RunTime +def get_bounding_box(mask): + """ + 从仅包含 0 和 1 的掩码图像中获取边界框。 + + :param mask: 输入的掩码图像,二维 numpy 数组,元素为 0 或 1 + :return: 边界框坐标 (x_min, y_min, x_max, y_max) + """ + # 找到所有值不为 0 的像素的坐标 + rows, cols = np.where(mask != 0) + + if len(rows) == 0 or len(cols) == 0: + # 如果没有找到不为 0 的像素,返回全 0 的边界框 + return (0, 0, 0, 0) + + # 计算边界框的坐标 + x_min = np.min(cols) + y_min = np.min(rows) + x_max = np.max(cols) + y_max = np.max(rows) + + return (x_min, y_min, x_max, y_max) + + +if __name__ == "__main__": + request_data = ClothingSegModel( + user_id=89, + image_data=[ + { + "image_url": "test/clothing_seg/dress.jpg", + "image_type": "sketch" + }, + { + "image_url": "test/clothing_seg/skirt_559.jpg", + "image_type": "sketch" + }, + { + "image_url": "test/clothing_seg/10144613.jpg", + "image_type": "product" + } + ] + ) + start_time = time.time() + server = ClothingSeg(request_data) + pprint(server.get_result()) + print(time.time() - start_time)