feat: update response templates

fix
zhouchengrong
2024-06-17 10:45:45 +08:00
parent b3081359b7
commit 756894baff
5 changed files with 326 additions and 9 deletions

View File

@@ -122,6 +122,10 @@ GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProduct
GPI_MODEL_NAME = 'diffusion_ensemble_all'
GPI_MODEL_URL = '10.1.1.240:10061'
# Generate Relight Image service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
GRI_MODEL_NAME = 'stable_diffusion_1_5'
GRI_MODEL_URL = '10.1.1.150:8001'
# SEG service config
SEG_MODEL_URL = '10.1.1.240:10000'
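
For reference, the new GRI_* block follows the same override pattern as the GPI_* settings above: the queue name comes from an environment variable when set, otherwise from a RABBITMQ_ENV-suffixed default. A minimal sketch of that behavior, assuming RABBITMQ_ENV is defined earlier in this config module:

import os

RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "Dev")  # assumption: env suffix defined elsewhere in config
# GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES=RelightStaging  -> queue "RelightStaging"
# variable unset                                    -> queue f"Relight{RABBITMQ_ENV}", e.g. "RelightDev"
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")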

View File

@@ -20,3 +20,9 @@ class GenerateProductImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str
class GenerateRelightImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str
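
A minimal usage sketch of the new schema (a pydantic BaseModel mirroring GenerateProductImageModel; the field values are hypothetical):

from app.schemas.generate_image import GenerateRelightImageModel

request = GenerateRelightImageModel(
    tasks_id="20240617-001-89",  # the suffix after the last '-' is treated as the user id downstream
    prompt="warm sunlight, outdoor",
    image_url="bucket-name/relight/input.png",
)
print(request.dict())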

View File

@@ -152,16 +152,14 @@ class PrintPainting(object):
rotated_resized_source = resized_source.rotate(result['print']['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(result['print']['print_angle_list'][i])
source_image_pil = Image.fromarray(print_background)
source_image_pil_mask = Image.fromarray(mask_background)
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask)
print_background = np.array(source_image_pil)
mask_background = np.array(source_image_pil_mask)
# print(1)
# the pasted PIL images are RGB (3-channel), so convert with RGB2BGR; RGBA2BGR would raise on a 3-channel array
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGB2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGB2BGR)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
@@ -241,7 +239,6 @@ class PrintPainting(object):
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
else:
painting_dict = {}
painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2]
@@ -260,6 +257,112 @@ class PrintPainting(object):
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if "element" in result.keys():
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(result['element']['element_path_list'])):
image, image_mode = self.read_image(result['element']['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * result['element']['element_scale_list'][i]), int(image.height * result['element']['element_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(result['element']['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(result['element']['element_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGB2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGB2BGR)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# the paste coordinates must be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(result['element']['location'][i][0] - rotated_new_size[0]), int(result['element']['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# buggy version, kept commented out for reference:
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# the two rounds of bounds checks must not be treated as independent:
# the first round (the ifs at original lines 108 and 115) checked only the bottom and right bounds, and the second round checked whether the top-left ran off-canvas; in that order, cropping the right edge first and then shifting the region left corrupts the visible area
# correct order: shift first, then check bounds, then crop (see the sketch after this diff)
# if the print is rotated, or sits flush against an edge, the left and top bounds can fall below 0 and must be checked
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# if the print size exceeds the image size, crop the print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO: element information is lost here
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
@@ -301,6 +404,7 @@ class PrintPainting(object):
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
tile = None
if not trigger:
tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
else:
@@ -351,6 +455,7 @@ class PrintPainting(object):
print_mask = result['mask']
img_fg = result['final_image']
if print_ and not painting_dict['Trigger']:
index_ = None
try:
index_ = len(painting_dict['location'])
except:
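
The shift-then-check-then-crop ordering above is easy to get wrong, as the inline comments note. Below is a self-contained sketch of the same idea, independent of PrintPainting (all names are illustrative, not from the codebase): clamp a paste position that may be negative first, then crop whatever still overhangs the right/bottom edge, and only then blit.

import numpy as np

def paste_with_clipping(background, patch, x, y):
    """Paste patch onto background at (x, y), clipping at all four edges.

    Order matters: clamp negative left/top offsets first (which shifts the
    visible region of the patch), then crop the right/bottom overhang against
    the background size. Cropping the right edge before shifting left, as the
    commented-out version did, miscomputes the visible region.
    """
    bg_h, bg_w = background.shape[:2]
    # 1) clamp left/top: a negative offset means the patch starts off-canvas
    if x < 0:
        patch = patch[:, -x:]
        x = 0
    if y < 0:
        patch = patch[-y:, :]
        y = 0
    # 2) crop any right/bottom overhang against the background bounds
    patch = patch[:max(0, bg_h - y), :max(0, bg_w - x)]
    if patch.size == 0:
        return background  # patch is fully off-canvas
    background[y:y + patch.shape[0], x:x + patch.shape[1]] = patch
    return background

# usage: a 100x100 white patch pasted at (-20, 460) onto a 512x512 canvas
canvas = np.zeros((512, 512, 3), dtype=np.uint8)
patch = np.full((100, 100, 3), 255, dtype=np.uint8)
canvas = paste_with_clipping(canvas, patch, x=-20, y=460)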

View File

@@ -25,7 +25,7 @@ class Scaling(object):
#
# distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
if distance_clo == 0:
result['scale'] = 10
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
elif result['keypoint'] == 'toe':
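
The changed fallback means a degenerate garment keypoint pair (zero spacing) now yields scale 1, keeping the original size, rather than blowing the garment up tenfold. A minimal sketch of the ratio computation, with illustrative parameter names:

def keypoint_scale(distance_bdy, distance_clo):
    # ratio of body keypoint spacing to garment keypoint spacing
    if distance_clo == 0:
        return 1  # degenerate garment keypoints: fall back to the original size
    return distance_bdy / distance_clo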

View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import pika
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
logger = logging.getLogger()
class GenerateRelightImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "relight_image"
self.batch_size = 1
self.prompt = request_data.prompt
self.seed = "12345"
# TODO: aida design: make the result image background white
# self.image, self.image_size = self.get_image(request_data.image_url)
self.image = request_data.image_url
self.image_size = None  # only set once get_image() is re-enabled; callback() skips the resize otherwise
# TODO: pad and resize the image to 512x768
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# decode the bytes into a PIL image
image = Image.open(image_bytes)
target_height = 768
target_width = 512
# scale to fit inside the 512x768 canvas; scaling to the target height alone
# overflows horizontally for wide images and yields negative padding below
scale = min(target_width / image.width, target_height / image.height)
resized_image = image.resize((int(image.width * scale), int(image.height * scale)))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
    # composite onto a white background
    background = Image.new("RGB", image.size, (255, 255, 255))
    # paste using the alpha channel as the mask ('LA'/'P' images have no 4th band, so convert first)
    background.paste(image, mask=image.convert('RGBA').split()[-1])
    image = np.array(background)
else:
    image = np.array(image.convert("RGB"))
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # PIL arrays are RGB; swap to BGR for cv2
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
# self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# convert the model output array back to a PIL image
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
if self.image_size is not None:  # the original size is only known when get_image() ran
    image_result = image_result.resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GRI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
direction = "Right Light"
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
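# NOTE: the next line hard-codes the prompt, overriding request_data.prompt; this looks like test residue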
self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere'
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput(
"prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
)
input_text.set_data_from_numpy(text_obj)
negative_prompts = [negative_prompt] * self.batch_size
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput(
"negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
)
input_text_neg.set_data_from_numpy(text_obj_neg)
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput(
"seed", seed.shape, np_to_triton_dtype(seed.dtype)
)
input_seed.set_data_from_numpy(seed)
input_images = [self.image] * self.batch_size
text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
input_input_images = grpcclient.InferInput(
"input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
)
input_input_images.set_data_from_numpy(text_obj_images)
directions = [direction] * self.batch_size
text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
input_directions = grpcclient.InferInput(
"direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
)
input_directions.set_data_from_numpy(text_obj_directions)
output_img = grpcclient.InferRequestedOutput("generated_image")  # currently unused: infer() does not forward requested outputs, so Triton returns all outputs; note callback() reads "generated_inpaint_image"
request_start = time.time()
inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]
ctx = self.infer(inputs)
time_out = 600
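# 600 iterations x 0.1s sleep gives roughly a 60-second polling budget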
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
# logger.info(gen_product_data)
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif gen_product_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, gen_product_data)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data: {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
return data
if __name__ == '__main__':
rd = GenerateRelightImageModel(
tasks_id="123-89",
prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
image_url="/workspace/i3.png",
)
server = GenerateRelightImage(rd)
print(server.get_result())
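
For reference, the task record that get_result() polls moves through a small state machine in Redis. A sketch of reading it from the consumer side (host/port are placeholders; the key is the tasks_id):

# PENDING  -> written in __init__ when the task is registered (600s TTL)
# SUCCESS  -> written by callback() after the Triton result is uploaded
# FAILURE  -> written by callback() or get_result() on any error
# REVOKED  -> written externally via infer_cancel(); get_result() then cancels the in-flight call
import json
import redis

r = redis.StrictRedis(host="localhost", port=6379, db=0, decode_responses=True)
status = json.loads(r.get("123-89") or "{}")
if status.get("status") == "SUCCESS":
    print(status["image_url"])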