From ae5260895135cf8afa3fc08ce01d2a611c44182a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 23 Apr 2024 08:32:14 +0800 Subject: [PATCH] =?UTF-8?q?feat=20generate=20=E6=96=B0=E5=A2=9E=E8=83=8C?= =?UTF-8?q?=E6=99=AF=E5=8E=BB=E9=99=A4=20bounding=20box?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service.py | 3 +++ .../generate_image/utils/remove_background.py | 22 ++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index d0bda9a..db1868e 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel +from app.service.generate_image.utils.remove_background import remove_background from app.service.generate_image.utils.upload_sd_image import upload_png_sd logger = logging.getLogger() @@ -71,6 +72,8 @@ class GenerateImage: self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) else: image_result = result.as_numpy("generated_image")[0] + if self.category == "sketch": + image_result = remove_background(np.asarray(image_result)) image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") self.generate_data['status'] = "SUCCESS" self.generate_data['message'] = "success" diff --git a/app/service/generate_image/utils/remove_background.py b/app/service/generate_image/utils/remove_background.py index 138bf2f..d27cf97 100644 --- a/app/service/generate_image/utils/remove_background.py +++ b/app/service/generate_image/utils/remove_background.py @@ -109,4 +109,24 @@ def remove_background(image): white_background = np.ones_like(image_obj) * 255 result_image = np.where(result_mask[:, :, None].astype(bool), image_obj, white_background) - return Image.fromarray(result_image) + import cv2 + + edges = cv2.Canny(result_image, 50, 150) + # 查找轮廓 + contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + # 初始化包围所有外接矩形的大矩形的坐标 + x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1 + # 遍历所有外接矩形,更新大矩形的坐标 + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + x_min = min(x_min, x) + y_min = min(y_min, y) + x_max = max(x_max, x + w) + y_max = max(y_max, y + h) + + # 根据大矩形的坐标来裁剪原始图像 + result_image = image[y_min:y_max, x_min:x_max] + # cv2.imshow("", cropped_image) + # cv2.waitKey(0) + + return result_image