feat 产品图打光模型部署

fix
This commit is contained in:
zhouchengrong
2024-06-19 16:44:04 +08:00
parent 8476bb3727
commit d04c3857fc
3 changed files with 41 additions and 81 deletions

View File

@@ -124,7 +124,7 @@ GPI_MODEL_URL = '10.1.1.240:10061'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
GRI_MODEL_NAME = 'stable_diffusion_1_5'
GRI_MODEL_NAME = 'diffusion_relight_ensemble'
GRI_MODEL_URL = '10.1.1.150:8001'
# SEG service config

View File

@@ -20,7 +20,7 @@ from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
logger = logging.getLogger()
@@ -166,10 +166,11 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
rd = GenerateImageModel(
rd = GenerateProductImageModel(
tasks_id="123-89",
prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png",
prompt="",
# prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
)
server = GenerateProductImage(rd)
print(server.get_result())

View File

@@ -38,9 +38,10 @@ class GenerateRelightImage:
self.batch_size = 1
self.prompt = request_data.prompt
self.seed = "12345"
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.direction = "Right Light"
# TODO aida design 结果图背景改为白色
# self.image, self.image_size = self.get_image(request_data.image_url)
self.image = request_data.image_url
self.image = self.get_image(request_data.image_url)
# TODO image 填充并resize成512*768
self.tasks_id = request_data.tasks_id
@@ -51,37 +52,8 @@ class GenerateRelightImage:
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# 转换为PIL图像对象
image = Image.open(image_bytes)
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
image = cv2.imdecode(np.frombuffer(response.data, np.uint8), 1)
return image
def callback(self, result, error):
if error:
@@ -92,7 +64,7 @@ class GenerateRelightImage:
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
@@ -114,47 +86,33 @@ class GenerateRelightImage:
def get_result(self):
try:
direction = "Right Light"
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere'
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput(
"prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
)
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (512, 768))
images = [image.astype(np.uint8)] * self.batch_size
seeds = [self.seed] * self.batch_size
nagetive_prompts = [self.negative_prompt] * self.batch_size
directions = [self.direction] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_natext.set_data_from_numpy(na_text_obj)
input_seed.set_data_from_numpy(seed_obj)
input_direction.set_data_from_numpy(direction_obj)
negative_prompts = [negative_prompt] * self.batch_size
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput(
"negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
)
input_text_neg.set_data_from_numpy(text_obj_neg)
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput(
"seed", seed.shape, np_to_triton_dtype(seed.dtype)
)
input_seed.set_data_from_numpy(seed)
input_images = [self.image] * self.batch_size
text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
input_input_images = grpcclient.InferInput(
"input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
)
input_input_images.set_data_from_numpy(text_obj_images)
directions = [direction] * self.batch_size
text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
input_directions = grpcclient.InferInput(
"direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
)
input_directions.set_data_from_numpy(text_obj_directions)
output_img = grpcclient.InferRequestedOutput("generated_image")
request_start = time.time()
inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
ctx = self.infer(inputs)
time_out = 600
@@ -179,9 +137,9 @@ class GenerateRelightImage:
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
@@ -195,8 +153,9 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
rd = GenerateRelightImageModel(
tasks_id="123-89",
prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
image_url="/workspace/i3.png",
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
prompt="",
image_url='aida-users/89/product_image/123-89.png'
)
server = GenerateRelightImage(rd)
print(server.get_result())