diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index 0c8bca3..45346bf 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -22,25 +22,19 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.adjust_contrast import adjust_contrast -from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd - logger = logging.getLogger() class GenerateImage: def __init__(self, request_data): - if DEBUG is False: - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() - # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() if request_data.mode == "img2img": - # cv2 读图片是BGR PIL读图片是RGB self.image = self.get_image(request_data.image_url) self.prompt = request_data.prompt else: @@ -53,69 +47,36 @@ class GenerateImage: self.batch_size = 1 self.category = request_data.category self.index = 0 - self.gender = request_data.gender - self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''} + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''} self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) self.redis_client.expire(self.tasks_id, 600) def get_image(self, image_url): # Get data of an object. # Read data from response. - # read image use cv2 try: response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) image_file = BytesIO(response.data) image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: - image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - return image + image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + return image_cv2 def callback(self, result, error): if error: self.generate_data['status'] = "FAILURE" self.generate_data['message'] = str(error) - # self.generate_data['data'] = str(error) + self.generate_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) else: - # pil图像转成numpy数组 - image = result.as_numpy("generated_image") - image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) - is_smudge = True - if self.category == "sketch": - # 色阶调整 - cutoff = 1 - levels_img = autoLevels(image_result, cutoff) - # 亮度调整 - luminance = luminance_adjust(0.3, levels_img) - # 去背景 - remove_bg_image = remove_background(luminance) - # 人脸检测 - if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: - is_smudge = False - else: - # 污点/ - is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) - # 类型识别 - category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) - self.generate_data['category'] = str(category) - image_result = not_smudge_image - if is_smudge: # 无污点 - # image_result = adjust_contrast(image_result) - image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") - self.generate_data['status'] = "SUCCESS" - self.generate_data['message'] = "success" - self.generate_data['image_url'] = str(image_url) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) - else: # 有污点 保存图片到本地 测试用 - self.generate_data['status'] = "SUCCESS" - self.generate_data['message'] = "success" - self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) - # logger.info(f"stain_detection result : {self.generate_data}") + image_result = result.as_numpy("generated_image")[0] + image_result = adjust_contrast(image_result) + image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['data'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) def read_tasks_status(self): status_data = self.redis_client.get(self.tasks_id) @@ -148,11 +109,10 @@ class GenerateImage: inputs = [input_text, input_image, input_mode] ctx = self.infer(inputs) - time_out = 600 + time_out = 60 generate_data = None while time_out > 0: generate_data, _ = self.read_tasks_status() - # logger.info(generate_data) if generate_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -160,18 +120,16 @@ class GenerateImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, generate_data) return generate_data except Exception as e: self.generate_data['status'] = "FAILURE" - self.generate_data['message'] = str(e) + self.generate_data['message'] = "failure" + self.generate_data['data'] = str(e) self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) raise Exception(str(e)) finally: dict_generate_data, str_generate_data = self.read_tasks_status() - if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) - # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") diff --git a/app/service/generate_image/utils/adjust_contrast.py b/app/service/generate_image/utils/adjust_contrast.py index 2af8969..4ed110e 100644 --- a/app/service/generate_image/utils/adjust_contrast.py +++ b/app/service/generate_image/utils/adjust_contrast.py @@ -1,13 +1,14 @@ import cv2 - def adjust_contrast(image, alpha=1.5, beta=-60): """ 调整图像的对比度和亮度。 + 参数: image_path (numpy): 图像的路径。 alpha (float): 控制对比度的系数。alpha > 1 增加对比度,alpha < 1 减少对比度。 beta (int): 用于调整亮度的值,可以是正或负。 + 返回: adjusted_image (ndarray): 调整对比度后的图像。 """ @@ -18,13 +19,13 @@ def adjust_contrast(image, alpha=1.5, beta=-60): # 使用示例 if __name__ == "__main__": - image = cv2.imread('output_6.png') # 替换为你的图片路径 + image = cv2.imread('output_6.png') # 替换为你的图片路径 img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - alpha = 1.5 # 对比度系数,大于1增加对比度 + alpha = 1.5 # 对比度系数,大于1增加对比度 beta = -60 # 亮度调整,这里设置为0,不改变亮度 # 调整图像对比度 result_image = adjust_contrast(image, alpha, beta) - # 可以选择保存调整后的图像 - cv2.imwrite('adjusted_image.jpg', result_image) # 保存调整后的图片 + # 可以选择保存调整后的图像 + cv2.imwrite('adjusted_image.jpg', result_image) # 保存调整后的图片 \ No newline at end of file