feat 新增 生成sketch时对图片清理背景,剔除带有污点的结果图

This commit is contained in:
zchen
2024-04-23 14:59:47 +08:00
parent ae52608951
commit 528b332677
4 changed files with 66 additions and 20 deletions

View File

@@ -19,7 +19,7 @@ class Settings(BaseSettings):
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
DEBUG = False DEBUG = True
if DEBUG: if DEBUG:
LOGS_PATH = "logs/" LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
@@ -66,6 +66,7 @@ GI_MODEL_NAME = 'stable_diffusion_xl_lcm'
GI_MODEL_URL = '10.1.1.150:8001' GI_MODEL_URL = '10.1.1.150:8001'
GI_MINIO_BUCKET = "aida-users" GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SEG service config # SEG service config
SEG_MODEL_URL = '10.1.1.240:10000' SEG_MODEL_URL = '10.1.1.240:10000'

View File

@@ -22,7 +22,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.remove_background import remove_background from app.service.generate_image.utils.image_processing import remove_background, stain_detection
from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.upload_sd_image import upload_png_sd
logger = logging.getLogger() logger = logging.getLogger()
@@ -30,11 +30,14 @@ logger = logging.getLogger()
class GenerateImage: class GenerateImage:
def __init__(self, request_data): def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
if request_data.mode == "img2img": if request_data.mode == "img2img":
self.image = self.get_image(request_data.image_url) self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt self.prompt = request_data.prompt
@@ -60,9 +63,10 @@ class GenerateImage:
image_file = BytesIO(response.data) image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image = cv2.resize(image_cv2, (1024, 1024))
except minio.error.S3Error: except minio.error.S3Error:
image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image_cv2 return image
def callback(self, result, error): def callback(self, result, error):
if error: if error:
@@ -73,12 +77,22 @@ class GenerateImage:
else: else:
image_result = result.as_numpy("generated_image")[0] image_result = result.as_numpy("generated_image")[0]
if self.category == "sketch": if self.category == "sketch":
image_result = remove_background(np.asarray(image_result)) # 去背景
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") remove_bg_image = remove_background(np.asarray(image_result))
self.generate_data['status'] = "SUCCESS" # 污点检测
self.generate_data['message'] = "success" is_smudge, not_smudge_image = stain_detection(remove_bg_image)
self.generate_data['data'] = str(image_url) if is_smudge is False:
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['data'] = GI_SYS_IMAGE_URL
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
image_result = not_smudge_image
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['data'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
def read_tasks_status(self): def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id) status_data = self.redis_client.get(self.tasks_id)
@@ -131,7 +145,8 @@ class GenerateImage:
raise Exception(str(e)) raise Exception(str(e))
finally: finally:
dict_generate_data, str_generate_data = self.read_tasks_status() dict_generate_data, str_generate_data = self.read_tasks_status()
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")

View File

@@ -22,7 +22,7 @@ from PIL import Image
import time import time
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.remove_background import remove_background from app.service.generate_image.utils.image_processing import remove_background
from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.decorator import RunTime from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid from app.service.utils.generate_uuid import generate_uuid

View File

@@ -1,11 +1,15 @@
import logging
import cv2 import cv2
import mmcv import mmcv
import numpy as np import numpy as np
import torch import torch
from PIL import Image
import tritonclient.http as httpclient import tritonclient.http as httpclient
import torch.nn.functional as F import torch.nn.functional as F
from app.core.config import * from app.core.config import *
import cv2
logger = logging.getLogger()
def seg_preprocess(img_path): def seg_preprocess(img_path):
@@ -107,11 +111,15 @@ def remove_background(image):
result_mask = front_mask + back_mask result_mask = front_mask + back_mask
white_background = np.ones_like(image_obj) * 255 white_background = np.ones_like(image_obj) * 255
result_image = np.where(result_mask[:, :, None].astype(bool), image_obj, white_background) remove_bg_image = np.where(result_mask[:, :, None].astype(bool), image_obj, white_background)
# cv2.imwrite("source_image", image)
# cv2.imwrite("remove_bg_image", remove_bg_image)
import cv2 return remove_bg_image
edges = cv2.Canny(result_image, 50, 150)
def bounding_box(image):
edges = cv2.Canny(image, 50, 150)
# 查找轮廓 # 查找轮廓
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 初始化包围所有外接矩形的大矩形的坐标 # 初始化包围所有外接矩形的大矩形的坐标
@@ -126,7 +134,29 @@ def remove_background(image):
# 根据大矩形的坐标来裁剪原始图像 # 根据大矩形的坐标来裁剪原始图像
result_image = image[y_min:y_max, x_min:x_max] result_image = image[y_min:y_max, x_min:x_max]
# cv2.imshow("", cropped_image) # cv2.imshow("result_image", result_image)
# cv2.waitKey(0) # cv2.waitKey(0)
return result_image return result_image
def stain_detection(image, spot_size=200):
height, width, _ = image.shape
corners = [
image[0:spot_size, 0:spot_size], # top left
image[0:spot_size, width - spot_size:width], # top right
image[height - spot_size:height, 0:spot_size], # bottom left
image[height - spot_size:height, width - spot_size:width] # bottom right
]
for index, corner in enumerate(corners):
num_white_pixels = (corner == [255, 255, 255]).all(axis=2).sum()
if num_white_pixels != spot_size * spot_size:
logger.info(f"{index + 1}发现了污点")
return False, None
if DEBUG:
for corner_coords in [(0, 0), (0, width - spot_size), (height - spot_size, 0), (height - spot_size, width - spot_size)]:
cv2.rectangle(image, corner_coords, (corner_coords[0] + spot_size, corner_coords[1] + spot_size), (0, 0, 255), 2)
return True, image