From 4e55275e6e592f242616c1d62b76701bca2e473b Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Mon, 21 Apr 2025 10:04:40 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20generate=20product=20relight=20pose=5Ftransform?= =?UTF-8?q?=20=E5=BC=80=E5=8F=91=EF=BC=8C=E8=AE=BE=E7=BD=AEbatch=20generat?= =?UTF-8?q?e=20=E7=9A=84=E4=BC=98=E5=85=88=E7=BA=A7=E4=B8=BA100=20?= =?UTF-8?q?=EF=BC=8Csingle=20generate=20=E7=9A=84=E4=BC=98=E5=85=88?= =?UTF-8?q?=E7=BA=A7=E4=B8=BA1=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug?= =?UTF-8?q?=EF=BC=89:=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 22 +- app/core/config.py | 3 + app/schemas/generate_image.py | 23 +++ app/schemas/pose_transform.py | 7 + app/service/generate_batch_image/service.py | 24 +++ .../service_batch_generate_product_image.py | 191 ++++++++++++++++++ .../service_batch_generate_relight_image.py | 162 +++++++++++++++ .../service_batch_pose_transform.py | 176 ++++++++++++++++ .../generate_image/service_generate_image.py | 4 +- .../service_generate_product_image.py | 4 +- .../service_generate_relight_image.py | 4 +- app/service/utils/redis_utils.py | 99 +++++++++ 12 files changed, 712 insertions(+), 7 deletions(-) create mode 100644 app/service/generate_batch_image/service.py create mode 100644 app/service/generate_batch_image/service_batch_generate_product_image.py create mode 100644 app/service/generate_batch_image/service_batch_generate_relight_image.py create mode 100644 app/service/generate_batch_image/service_batch_pose_transform.py create mode 100644 app/service/utils/redis_utils.py diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index a37bec3..f151b91 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,8 +3,10 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel +from app.schemas.pose_transform import BatchPoseTransformModel from app.schemas.response_template import ResponseModel +from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel @@ -228,3 +230,21 @@ def generate_relight_image(tasks_id: str): logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data['data']) + + +"""batch generate img""" + + +@router.post("/batch_generate_product_image") +async def design(request_batch_item: BatchGenerateProductImageModel): + return await start_product_batch_generate(request_batch_item) + + +@router.post("/batch_generate_relight_image") +async def design(request_batch_item: BatchGenerateRelightImageModel): + return await start_relight_batch_generate(request_batch_item) + + +@router.post("/batch_generate_pose_transform_image") +async def design(request_batch_item: BatchPoseTransformModel): + return await start_pose_transform_batch_generate(request_batch_item) diff --git a/app/core/config.py b/app/core/config.py index ac9181f..aaf32d7 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -135,12 +135,14 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Product service config 旧版product img 模型 GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") +BATCH_GPI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"BatchToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' GPI_MODEL_URL = '10.1.1.243:10051' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") +BATCH_GRI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"BatchRelight{RABBITMQ_ENV}") GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble' GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' @@ -148,6 +150,7 @@ GRI_MODEL_URL = '10.1.1.240:10051' # Pose Transform service config PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}") +BATCH_PS_RABBITMQ_QUEUES = os.getenv("BATCH_PS_RABBITMQ_QUEUES", f"BatchPoseTransform{RABBITMQ_ENV}") PT_MODEL_URL = '10.1.1.243:10061' # SEG service config diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 7181418..99d1836 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -36,3 +36,26 @@ class GenerateRelightImageModel(BaseModel): image_url: str direction: str product_type: str + + +""" + batch generate image +""" + + +class BatchGenerateProductImageModel(BaseModel): + tasks_id: str + prompt: str + image_url: str + image_strength: float + product_type: str + batch_size: int + + +class BatchGenerateRelightImageModel(BaseModel): + tasks_id: str + prompt: str + image_url: str + direction: str + product_type: str + batch_size: int diff --git a/app/schemas/pose_transform.py b/app/schemas/pose_transform.py index 045d8b9..22526ff 100644 --- a/app/schemas/pose_transform.py +++ b/app/schemas/pose_transform.py @@ -5,3 +5,10 @@ class PoseTransformModel(BaseModel): image_url: str tasks_id: str pose_id: str + + +class BatchPoseTransformModel(BaseModel): + image_url: str + tasks_id: str + pose_id: str + batch_size: int diff --git a/app/service/generate_batch_image/service.py b/app/service/generate_batch_image/service.py new file mode 100644 index 0000000..2279382 --- /dev/null +++ b/app/service/generate_batch_image/service.py @@ -0,0 +1,24 @@ +from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product, publish_status as product_publish_status +from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight, publish_status as relight_publish_status +from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform, publish_status as pose_transform_publish_status + + +async def start_product_batch_generate(data): + generate_clothes_task = batch_generate_product.delay(data.dict()) + print(generate_clothes_task) + product_publish_status(data.tasks_id, f"0/{data.batch_size}", "") + return {"task_id": data.tasks_id, "state": generate_clothes_task.state} + + +async def start_relight_batch_generate(data): + generate_clothes_task = batch_generate_relight.delay(data.dict()) + print(generate_clothes_task) + relight_publish_status(data.tasks_id, f"0/{data.batch_size}", "") + return {"task_id": data.tasks_id, "state": generate_clothes_task.state} + + +async def start_pose_transform_batch_generate(data): + generate_clothes_task = batch_generate_pose_transform.delay(data.dict()) + print(generate_clothes_task) + pose_transform_publish_status(data.tasks_id, f"0/{data.batch_size}", "") + return {"task_id": data.tasks_id, "state": generate_clothes_task.state} diff --git a/app/service/generate_batch_image/service_batch_generate_product_image.py b/app/service/generate_batch_image/service_batch_generate_product_image.py new file mode 100644 index 0000000..438ec99 --- /dev/null +++ b/app/service/generate_batch_image/service_batch_generate_product_image.py @@ -0,0 +1,191 @@ +# 旧版product +# !/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging + +import cv2 +import numpy as np +import tritonclient.grpc as grpcclient +from PIL import Image +from celery import Celery +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.generate_image import BatchGenerateProductImageModel +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image + +celery_app = Celery('product_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True) +celery_app.conf.task_default_queue = 'queue_product' +celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s' +celery_app.conf.worker_hijack_root_logger = False +logger = logging.getLogger() +logging.getLogger('pika').setLevel(logging.WARNING) +grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) +category = "product_image" + + +@celery_app.task +def batch_generate_product(batch_request_data): + logger.info(f"batch_generate_product batch_request_data:{batch_request_data}") + tasks_id = batch_request_data['tasks_id'] + user_id = tasks_id.rsplit('-', 1)[1] + batch_size = batch_request_data['batch_size'] + image = pre_processing_image(batch_request_data['image_url']) + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + images = [image.astype(np.uint8)] * 1 + + prompts = [batch_request_data['prompt']] * 1 + + if batch_request_data['product_type'] == "single": + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_image_strength.set_data_from_numpy(image_strength_obj) + + inputs = [input_text, input_image, input_image_strength] + + image_url_list = [] + for i in range(batch_size): + try: + if batch_request_data['product_type'] == "single": + result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) + image = result.as_numpy("generated_cnet_image") + else: + result = grpc_client.infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, priority=100) + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + except Exception as e: + if 'mask_list' in str(e): + e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1)) + + e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype)) + e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8") + e_input_image_strength = grpcclient.InferInput("image_strength", e_image_strength_obj.shape, np_to_triton_dtype(e_image_strength_obj.dtype)) + + e_input_text.set_data_from_numpy(e_text_obj) + e_input_image.set_data_from_numpy(e_image_obj) + e_input_image_strength.set_data_from_numpy(e_image_strength_obj) + + result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=[e_input_text, e_input_image, e_input_image_strength], priority=100) + image = result.as_numpy("generated_cnet_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + else: + image_result = str(e) + logger.error(image_result) + + if isinstance(image_result, Image.Image): + image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") + image_url_list.append(image_url) + else: + image_url = image_result + if DEBUG is False: + if i + 1 < batch_size: + publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) + logger.info(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}") + print(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}") + else: + publish_status(tasks_id, f"OK", image_url_list) + logger.info(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}") + print(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}") + + +def pre_processing_image(image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + # 目标图片的尺寸 + target_width = 512 + target_height = 768 + + # 原始图片的尺寸 + original_width, original_height = image.size + + # 计算宽度和高度的缩放比例 + width_ratio = target_width / original_width + height_ratio = target_height / original_height + + # 选择较小的缩放比例,确保图片能完整放入目标图片中 + scale_ratio = min(width_ratio, height_ratio) + + # 计算调整后的尺寸 + new_width = int(original_width * scale_ratio) + new_height = int(original_height * scale_ratio) + + # 调整图片大小 + resized_image = image.resize((new_width, new_height)) + + # 创建一个 512x768 的透明图片 + result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0)) + + # 计算需要粘贴的位置,使图片居中 + x_offset = (target_width - new_width) // 2 + y_offset = (target_height - new_height) // 2 + + # 将调整大小后的图片粘贴到透明图片上 + if resized_image.mode == "RGBA": + result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3]) + else: + result_image.paste(resized_image, (x_offset, y_offset)) + + image = np.array(result_image) + + # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + return image + + +def post_processing_image(image, left, top): + resized_image = image.resize((int(image.width * (768 / image.height)), 768)) + # 计算裁剪的坐标 + left = (resized_image.width - 512) // 2 + upper = 0 + right = left + 512 + lower = 768 + + # 进行裁剪 + cropped_image = resized_image.crop((left, upper, right, lower)) + return cropped_image + + +def publish_status(task_id, progress, result): + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + channel = connection.channel() + channel.queue_declare(queue=BATCH_GPI_RABBITMQ_QUEUES, durable=True) + message = {'task_id': task_id, 'progress': progress, "result": result} + channel.basic_publish(exchange='', + routing_key=BATCH_GPI_RABBITMQ_QUEUES, + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2, + )) + connection.close() + + +if __name__ == '__main__': + rd = BatchGenerateProductImageModel( + tasks_id="123-15-51-89", + image_strength=0.7, + prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", + image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + product_type="overall", + batch_size=20 + ) + batch_generate_product(rd.dict()) diff --git a/app/service/generate_batch_image/service_batch_generate_relight_image.py b/app/service/generate_batch_image/service_batch_generate_relight_image.py new file mode 100644 index 0000000..fa53f26 --- /dev/null +++ b/app/service/generate_batch_image/service_batch_generate_relight_image.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging + +import cv2 +import numpy as np +import tritonclient.grpc as grpcclient +from PIL import Image +from celery import Celery +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.generate_image import BatchGenerateRelightImageModel +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() +celery_app = Celery('relight_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True) +celery_app.conf.task_default_queue = 'queue_relight' +celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s' +celery_app.conf.worker_hijack_root_logger = False +logging.getLogger('pika').setLevel(logging.WARNING) +grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) +category = "relight_image" + + +@celery_app.task +def batch_generate_relight(batch_request_data): + logger.info(f"batch_generate_relight batch_request_data: {batch_request_data}") + negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' + direction = batch_request_data['direction'] + seed = "1" + prompt = batch_request_data['prompt'] + product_type = batch_request_data['product_type'] + image_url = batch_request_data['image_url'] + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2") + tasks_id = batch_request_data['tasks_id'] + user_id = tasks_id.rsplit('-', 1)[1] + batch_size = batch_request_data['batch_size'] + + prompts = [prompt] * 1 + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * 1 + seeds = [seed] * 1 + nagetive_prompts = [negative_prompt] * 1 + directions = [direction] * 1 + + if product_type == 'single': + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) + seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) + direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) + input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) + input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_natext.set_data_from_numpy(na_text_obj) + input_seed.set_data_from_numpy(seed_obj) + input_direction.set_data_from_numpy(direction_obj) + + inputs = [input_text, input_natext, input_image, input_seed, input_direction] + image_url_list = [] + for i in range(batch_size): + try: + if batch_request_data['product_type'] == "single": + result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) + image = result.as_numpy("generated_relight_image") + else: + result = grpc_client.infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, priority=100) + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + + except Exception as e: + print(e) + if 'mask_list' in str(e): + e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + e_na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) + e_seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) + e_direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) + + e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype)) + e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8") + e_input_natext = grpcclient.InferInput("negative_prompt", e_na_text_obj.shape, np_to_triton_dtype(e_na_text_obj.dtype)) + e_input_seed = grpcclient.InferInput("seed", e_seed_obj.shape, np_to_triton_dtype(e_seed_obj.dtype)) + e_input_direction = grpcclient.InferInput("direction", e_direction_obj.shape, np_to_triton_dtype(e_direction_obj.dtype)) + + e_input_text.set_data_from_numpy(e_text_obj) + e_input_image.set_data_from_numpy(e_image_obj) + e_input_natext.set_data_from_numpy(e_na_text_obj) + e_input_seed.set_data_from_numpy(e_seed_obj) + e_input_direction.set_data_from_numpy(e_direction_obj) + + e_inputs = [e_input_text, e_input_natext, e_input_image, e_input_seed, e_input_direction] + + result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=e_inputs, priority=100) + image = result.as_numpy("generated_relight_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + else: + image_result = str(e) + logger.error(e) + if isinstance(image_result, Image.Image): + image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") + image_url_list.append(image_url) + else: + image_url = image_result + if DEBUG is False: + if i + 1 < batch_size: + publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) + logger.info(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}") + print(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}") + else: + publish_status(tasks_id, f"OK", image_url_list) + logger.info(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}") + print(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}") + + +def publish_status(task_id, progress, result): + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + channel = connection.channel() + channel.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True) + message = {'task_id': task_id, 'progress': progress, "result": result} + channel.basic_publish(exchange='', + routing_key=BATCH_GRI_RABBITMQ_QUEUES, + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2, + )) + connection.close() + + +if __name__ == '__main__': + rd = BatchGenerateRelightImageModel( + tasks_id="123-89", + # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", + prompt="Colorful black", + image_url='aida-users/89/clothing_seg/283c5c82-1a92-11f0-b72a-0242ac150002.png', + direction="Right Light", + product_type="overall", + batch_size=10 + ) + batch_generate_relight(rd.dict()) diff --git a/app/service/generate_batch_image/service_batch_pose_transform.py b/app/service/generate_batch_image/service_batch_pose_transform.py new file mode 100644 index 0000000..3507a43 --- /dev/null +++ b/app/service/generate_batch_image/service_batch_pose_transform.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import io +import json +import logging +from io import BytesIO + +import imageio +import numpy as np +import tritonclient.grpc as grpcclient +from PIL import Image +from celery import Celery +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.pose_transform import BatchPoseTransformModel +from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video +from app.service.utils.new_oss_client import oss_upload_image +from app.service.utils.oss_client import oss_get_image + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + +logger = logging.getLogger() +celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True) +celery_app.conf.task_default_queue = 'queue_post_transform' +celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s' +celery_app.conf.worker_hijack_root_logger = False +logging.getLogger('pika').setLevel(logging.WARNING) +grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL) +category = "pose_transform" + + +def upload_first_image(image, user_id, category, file_name): + try: + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"aida-users/{object_name}" + return image_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") + + +def pre_processing_image(image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + # 目标图片的尺寸 + target_width = 512 + target_height = 768 + + # 原始图片的尺寸 + original_width, original_height = image.size + + # 计算宽度和高度的缩放比例 + width_ratio = target_width / original_width + height_ratio = target_height / original_height + + # 选择较小的缩放比例,确保图片能完整放入目标图片中 + scale_ratio = min(width_ratio, height_ratio) + + # 计算调整后的尺寸 + new_width = int(original_width * scale_ratio) + new_height = int(original_height * scale_ratio) + + # 调整图片大小 + resized_image = image.resize((new_width, new_height)) + + # 创建一个 512x768 的透明图片 + result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0)) + + # 计算需要粘贴的位置,使图片居中 + x_offset = (target_width - new_width) // 2 + y_offset = (target_height - new_height) // 2 + + # 将调整大小后的图片粘贴到透明图片上 + if resized_image.mode == "RGBA": + result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3]) + else: + result_image.paste(resized_image, (x_offset, y_offset)) + result_image = result_image.convert("RGB") + image = np.array(result_image) + + # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + + return image + + +@celery_app.task +def batch_generate_pose_transform(batch_request_data): + logger.info(f"batch_generate_pose_transform batch_request_data: {batch_request_data}") + batch_size = batch_request_data['batch_size'] + image_url = batch_request_data['image_url'] + image = pre_processing_image(image_url) + pose_num = batch_request_data['pose_id'] + tasks_id = batch_request_data['tasks_id'] + user_id = tasks_id.rsplit('-', 1)[1] + + pose_num = [pose_num] * 1 + pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1)) + input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype)) + input_pose_num.set_data_from_numpy(pose_num_obj) + + image_files = [image.astype(np.uint8)] * 1 + image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3)) + input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8") + input_image_files.set_data_from_numpy(image_files_obj) + + result_url_list = [] + for i in range(batch_size): + try: + result = grpc_client.infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], client_timeout=60000, priority=100) + result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1] + # 第一帧图像 + first_image = Image.fromarray(result_data[0]) + first_image_url = upload_first_image(first_image, user_id=user_id, category=f"{category}_first_img", file_name=f"{tasks_id}_batch_{i}.png") + + # 上传GIF + gif_buffer = BytesIO() + imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5) + gif_buffer.seek(0) + gif_url = upload_gif(gif_buffer=gif_buffer, user_id=user_id, category=f"{category}_gif", file_name=f"{tasks_id}_batch_{i}.gif") + + # 上传video + video_url = upload_video(frames=result_data, user_id=user_id, category=f"{category}_video", file_name=f"{tasks_id}_batch_{i}.mp4") + data = { + "gif_url": gif_url, + "video_url": video_url, + "first_image_url": first_image_url, + } + except Exception as e: + print(e) + data = {} + result_url_list.append(data) + if DEBUG is False: + if i + 1 < batch_size: + publish_status(tasks_id, f"{i + 1}/{batch_size}", data) + logger.info(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{data}") + print(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{data}") + else: + publish_status(tasks_id, f"OK", result_url_list) + logger.info(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{result_url_list}") + print(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{result_url_list}") + + +def publish_status(task_id, progress, result): + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + channel = connection.channel() + channel.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True) + message = {'task_id': task_id, 'progress': progress, "result": result} + channel.basic_publish(exchange='', + routing_key=BATCH_GRI_RABBITMQ_QUEUES, + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2, + )) + connection.close() + + +if __name__ == '__main__': + rd = BatchPoseTransformModel( + tasks_id="123-89", + image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png', + pose_id="1", + batch_size=10 + ) + batch_generate_pose_transform(rd.dict()) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 86912f8..4ed8fd4 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -153,9 +153,9 @@ class GenerateImage: inputs = [input_text, input_image, input_mode] if self.version == "fast": - ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1) else: - ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1) time_out = 600 generate_data = None diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 235f366..d0fbe74 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -295,9 +295,9 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1) else: - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1) time_out = 600 while time_out > 0: diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 2e48ae2..668e7fd 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -114,9 +114,9 @@ class GenerateRelightImage: inputs = [input_text, input_natext, input_image, input_seed, input_direction] if self.product_type == 'single': - ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1) else: - ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1) time_out = 600 while time_out > 0: diff --git a/app/service/utils/redis_utils.py b/app/service/utils/redis_utils.py new file mode 100644 index 0000000..012fbe0 --- /dev/null +++ b/app/service/utils/redis_utils.py @@ -0,0 +1,99 @@ +import redis + +from app.core.config import REDIS_HOST, REDIS_PORT + + +class Redis(object): + """ + redis数据库操作 + """ + + @staticmethod + def _get_r(): + host = REDIS_HOST + port = REDIS_PORT + db = 0 + r = redis.StrictRedis(host, port, db) + return r + + @classmethod + def write(cls, key, value, expire=None): + """ + 写入键值对 + """ + # 判断是否有过期时间,没有就设置默认值 + if expire: + expire_in_seconds = expire + else: + expire_in_seconds = 100 + r = cls._get_r() + r.set(key, value, ex=expire_in_seconds) + + @classmethod + def read(cls, key): + """ + 读取键值对内容 + """ + r = cls._get_r() + value = r.get(key) + return value.decode('utf-8') if value else value + + @classmethod + def hset(cls, name, key, value): + """ + 写入hash表 + """ + r = cls._get_r() + r.hset(name, key, value) + + @classmethod + def hget(cls, name, key): + """ + 读取指定hash表的键值 + """ + r = cls._get_r() + value = r.hget(name, key) + return value.decode('utf-8') if value else value + + @classmethod + def hgetall(cls, name): + """ + 获取指定hash表所有的值 + """ + r = cls._get_r() + return r.hgetall(name) + + @classmethod + def delete(cls, *names): + """ + 删除一个或者多个 + """ + r = cls._get_r() + r.delete(*names) + + @classmethod + def hdel(cls, name, key): + """ + 删除指定hash表的键值 + """ + r = cls._get_r() + r.hdel(name, key) + + @classmethod + def expire(cls, name, expire=None): + """ + 设置过期时间 + """ + if expire: + expire_in_seconds = expire + else: + expire_in_seconds = 100 + r = cls._get_r() + r.expire(name, expire_in_seconds) + + +if __name__ == '__main__': + redis_client = Redis() + # print(redis_client.write(key="1230", value=0)) + redis_client.write(key="1230", value=10) + # print(redis_client.read(key="1230"))