Files
AiDA_Python/app/service/generate_image/service_generate_relight_image.py
zcr 18024a2d70
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
2025-12-30 16:49:08 +08:00

202 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_generate_relight_image.py
@Author :周成融
@Date 2023/7/26 12:01:05
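@detail Relight-image generation service: submits the request to a Triton model over gRPC and tracks task status in Redis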
"""
import json
import logging
import time
import cv2
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import settings, GRI_MODEL_URL, GRI_MODEL_NAME_SINGLE, GRI_MODEL_NAME_OVERALL, GRI_RABBITMQ_QUEUES
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.new_oss_client import oss_get_image
# Module-wide logger; getLogger() is called without a name, so this is the root logger.
logger = logging.getLogger()
class GenerateRelightImage:
    """Run a single relight-image task against a Triton model.

    The task status is kept in Redis under ``tasks_id`` as a JSON document with
    the keys ``tasks_id``, ``status``, ``message`` and ``image_url``. The
    statuses written by this class are ``PENDING``, ``SUCCESS``, ``FAILURE``
    and ``NO_FACE``; ``REVOKED`` is written externally (see ``infer_cancel``).
    """

    # Lifetime of the status key. It is re-applied on every write because a
    # plain Redis SET discards any TTL previously attached to the key.
    STATUS_TTL_SECONDS = 600
    # Polling budget of get_result(): POLL_ATTEMPTS * POLL_INTERVAL_SECONDS.
    # NOTE(review): this is ~60 s, not 600 s -- confirm which one is intended.
    POLL_ATTEMPTS = 600
    POLL_INTERVAL_SECONDS = 0.1

    def __init__(self, request_data):
        """Create the clients, load the input image and publish a PENDING status.

        Args:
            request_data: request object providing ``prompt``, ``product_type``,
                ``direction``, ``image_url`` and ``tasks_id``.
        """
        self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
        self.category = "relight_image"
        self.batch_size = 1
        self.prompt = request_data.prompt
        self.seed = "1"
        self.product_type = request_data.product_type
        self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
        self.direction = request_data.direction
        self.image_url = request_data.image_url
        self.image = pre_processing_image(self.image_url)
        self.tasks_id = request_data.tasks_id
        # The user id is whatever follows the last '-' of the task id.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
        self._save_status()

    def _save_status(self):
        """Write ``gen_product_data`` to Redis and (re)arm the key's expiry."""
        self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data), ex=self.STATUS_TTL_SECONDS)

    def callback(self, result, error):
        """Completion callback passed to ``async_infer``.

        Records ``NO_FACE`` when the error text mentions ``mask_list``,
        ``FAILURE`` for any other error, and otherwise uploads the generated
        image and records ``SUCCESS`` together with the returned URL.

        Args:
            result: inference result exposing ``as_numpy(output_name)``.
            error: falsy on success, otherwise the inference error.
        """
        try:
            if error:
                message = str(error)
                self.gen_product_data['status'] = "NO_FACE" if 'mask_list' in message else "FAILURE"
                self.gen_product_data['message'] = message
            else:
                if self.product_type == 'single':
                    output_name = "generated_relight_image"
                else:
                    output_name = "generated_inpaint_image"
                # Convert the numpy output into a PIL image for the upload.
                image = result.as_numpy(output_name)
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
                image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
                self.gen_product_data['status'] = "SUCCESS"
                self.gen_product_data['message'] = "success"
                self.gen_product_data['image_url'] = str(image_url)
        except Exception as exc:
            # An exception escaping this callback would leave the task PENDING
            # until the poller gives up, so record the failure explicitly.
            logging.getLogger(__name__).exception("relight callback failed for task %s", self.tasks_id)
            self.gen_product_data['status'] = "FAILURE"
            self.gen_product_data['message'] = str(exc)
        self._save_status()

    def read_tasks_status(self):
        """Return the stored status as ``(parsed_dict, raw_json_string)``."""
        status_data = self.redis_client.get(self.tasks_id)
        return json.loads(status_data), status_data

    def _build_inputs(self):
        """Build the Triton input tensors for the current request."""
        image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (512, 768))
        # The 'single' model takes a leading batch dimension; the other model
        # takes unbatched tensors.
        if self.product_type == 'single':
            text_shape = (-1, 1)
            image_shape = (-1, 768, 512, 3)
        else:
            text_shape = (1,)
            image_shape = (768, 512, 3)

        def text_tensor(value):
            return np.array([value] * self.batch_size, dtype="object").reshape(text_shape)

        image_obj = np.array([image.astype(np.uint8)] * self.batch_size, dtype=np.uint8).reshape(image_shape)
        text_obj = text_tensor(self.prompt)
        na_text_obj = text_tensor(self.negative_prompt)
        seed_obj = text_tensor(self.seed)
        direction_obj = text_tensor(self.direction)

        specs = [
            ("prompt", text_obj, np_to_triton_dtype(text_obj.dtype)),
            ("negative_prompt", na_text_obj, np_to_triton_dtype(na_text_obj.dtype)),
            ("input_image", image_obj, "UINT8"),
            ("seed", seed_obj, np_to_triton_dtype(seed_obj.dtype)),
            ("direction", direction_obj, np_to_triton_dtype(direction_obj.dtype)),
        ]
        inputs = []
        for name, array, triton_dtype in specs:
            tensor = grpcclient.InferInput(name, array.shape, triton_dtype)
            tensor.set_data_from_numpy(array)
            inputs.append(tensor)
        return inputs

    def get_result(self):
        """Submit the inference request and poll Redis until it settles.

        Polling stops on ``SUCCESS``; on ``REVOKED``, ``FAILURE`` or
        ``NO_FACE`` the in-flight request is cancelled first. When the polling
        budget is exhausted, the last stored status is returned as-is and the
        request is not cancelled.

        Returns:
            dict: the status document read from Redis after polling stopped.

        Raises:
            Exception: wraps any error raised while preparing, submitting or
                polling; ``FAILURE`` is stored in Redis before re-raising.

        Unless ``settings.DEBUG`` is set, the final raw status is published to
        ``GRI_RABBITMQ_QUEUES`` in all cases.
        """
        try:
            inputs = self._build_inputs()
            if self.product_type == 'single':
                model_name = GRI_MODEL_NAME_SINGLE
            else:
                model_name = GRI_MODEL_NAME_OVERALL
            ctx = self.grpc_client.async_infer(model_name=model_name, inputs=inputs, callback=self.callback, priority=1)
            for _attempt in range(self.POLL_ATTEMPTS):
                gen_product_data, _raw = self.read_tasks_status()
                status = gen_product_data['status']
                if status in ("REVOKED", "FAILURE", "NO_FACE"):
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                time.sleep(self.POLL_INTERVAL_SECONDS)
            gen_product_data, _raw = self.read_tasks_status()
            return gen_product_data
        except Exception as e:
            self.gen_product_data['status'] = "FAILURE"
            self.gen_product_data['message'] = str(e)
            self._save_status()
            # Chain the cause so the original traceback is not lost.
            raise Exception(str(e)) from e
        finally:
            _parsed, str_gen_product_data = self.read_tasks_status()
            if not settings.DEBUG:
                publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
def pre_processing_image(image_url):
    """Fetch an image from object storage and letterbox it onto a 512x768 canvas.

    The image is scaled (aspect ratio preserved) so that it fits entirely
    inside the canvas, then pasted centred on an RGBA canvas whose background
    is ``(255, 255, 255, 0)``.

    Args:
        image_url: ``"<bucket>/<object name>"``; the text before the first
            ``/`` is used as the bucket and the remainder as the object name.

    Returns:
        numpy.ndarray: the RGBA canvas converted with ``np.array``.
    """
    # Without a '/', both parts fall back to the whole string.
    path_parts = image_url.split('/', 1)
    image = oss_get_image(bucket=path_parts[0], object_name=path_parts[-1], data_type="PIL")

    # Size of the target canvas.
    target_width = 512
    target_height = 768

    # Use the smaller ratio so the whole image fits inside the canvas.
    original_width, original_height = image.size
    scale_ratio = min(target_width / original_width, target_height / original_height)

    # Clamp to 1 px: for extreme aspect ratios the truncated size would be 0,
    # which the resize call rejects.
    new_width = max(1, int(original_width * scale_ratio))
    new_height = max(1, int(original_height * scale_ratio))
    resized_image = image.resize((new_width, new_height))

    # Transparent 512x768 canvas with the resized image centred on it.
    result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    x_offset = (target_width - new_width) // 2
    y_offset = (target_height - new_height) // 2
    if resized_image.mode == "RGBA":
        # Use the image's own alpha channel as the paste mask.
        result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
    else:
        result_image.paste(resized_image, (x_offset, y_offset))
    return np.array(result_image)
def infer_cancel(tasks_id):
    """Mark a task as revoked in Redis.

    Args:
        tasks_id: Redis key of the task to revoke.

    Returns:
        dict: the status document that was stored.
    """
    # Same lifetime the status key is given when the task is created; a plain
    # SET would discard the TTL and leave the record in Redis indefinitely.
    status_ttl_seconds = 600
    redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client.set(tasks_id, json.dumps(data), ex=status_ttl_seconds)
    return data
if __name__ == '__main__':
    # Manual smoke test: run one "overall" relight request and print the final status.
    # Alternative prompt for manual experiments:
    #   "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere"
    demo_request = GenerateRelightImageModel(
        product_type="overall",
        direction="Right Light",
        image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
        prompt="Colorful black",
        tasks_id="123-89",
    )
    demo_service = GenerateRelightImage(demo_request)
    final_status = demo_service.get_result()
    print(final_status)