From 756894baff6dbc9671951b9cfcc0fbe39482d90d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 10:45:45 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D=E5=BA=94?= =?UTF-8?q?=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 4 + app/schemas/generate_image.py | 6 + .../design/items/pipelines/painting.py | 121 ++++++++++- app/service/design/items/pipelines/scale.py | 2 +- .../service_generate_relight_image.py | 202 ++++++++++++++++++ 5 files changed, 326 insertions(+), 9 deletions(-) create mode 100644 app/service/generate_image/service_generate_relight_image.py diff --git a/app/core/config.py b/app/core/config.py index 30bdd30..651dd8b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -122,6 +122,10 @@ GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProduct GPI_MODEL_NAME = 'diffusion_ensemble_all' GPI_MODEL_URL = '10.1.1.240:10061' +# Generate Single Logo service config +GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") +GRI_MODEL_NAME = 'stable_diffusion_1_5' +GRI_MODEL_URL = '10.1.1.150:8001' # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 49cf9ce..4f85002 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -20,3 +20,9 @@ class GenerateProductImageModel(BaseModel): tasks_id: str prompt: str image_url: str + + +class GenerateRelightImageModel(BaseModel): + tasks_id: str + prompt: str + image_url: str diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index aa310a3..43b42e4 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -152,16 +152,14 @@ class PrintPainting(object): rotated_resized_source = 
resized_source.rotate(result['print']['print_angle_list'][i]) rotated_resized_source_mask = resized_source_mask.rotate(result['print']['print_angle_list'][i]) - source_image_pil = Image.fromarray(print_background) - source_image_pil_mask = Image.fromarray(mask_background) + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source) source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask) - print_background = np.array(source_image_pil) - mask_background = np.array(source_image_pil_mask) - - # print(1) + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) else: mask = self.get_mask_inv(image) mask = np.expand_dims(mask, axis=2) @@ -241,7 +239,6 @@ class PrintPainting(object): temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) result['single_image'] = cv2.add(tmp1, tmp2) - return result else: painting_dict = {} painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2] @@ -260,7 +257,113 @@ class PrintPainting(object): temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) result['single_image'] = cv2.add(tmp1, tmp2) - return result + + if "element" in result.keys(): + print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) + mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), 
dtype=np.uint8) + for i in range(len(result['element']['element_path_list'])): + image, image_mode = self.read_image(result['element']['element_path_list'][i]) + if image_mode == "RGBA": + new_size = (int(image.width * result['element']['element_scale_list'][i]), int(image.height * result['element']['element_scale_list'][i])) + + mask = image.split()[3] + resized_source = image.resize(new_size) + resized_source_mask = mask.resize(new_size) + + rotated_resized_source = resized_source.rotate(result['element']['element_angle_list'][i]) + rotated_resized_source_mask = resized_source_mask.rotate(result['element']['element_angle_list'][i]) + + source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB)) + source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB)) + + source_image_pil.paste(rotated_resized_source, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source) + source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source_mask) + + print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) + mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) + print(1) + else: + mask = self.get_mask_inv(image) + mask = np.expand_dims(mask, axis=2) + mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + mask = cv2.bitwise_not(mask) + # 旋转后的坐标需要重新算 + rotate_mask, _ = self.img_rotate(mask, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i]) + # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - 
image.shape[1]) / 2) + x, y = int(result['element']['location'][i][0] - rotated_new_size[0]), int(result['element']['location'][i][1] - rotated_new_size[1]) + + image_x = print_background.shape[1] + image_y = print_background.shape[0] + print_x = rotate_image.shape[1] + print_y = rotate_image.shape[0] + + # 有bug + # if x + print_x > image_x: + # rotate_image = rotate_image[:, :x + print_x - image_x] + # rotate_mask = rotate_mask[:, :x + print_x - image_x] + # # + # if y + print_y > image_y: + # rotate_image = rotate_image[:y + print_y - image_y] + # rotate_mask = rotate_mask[:y + print_y - image_y] + + # 不能是并行 + # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题 + # 先挪 再判断 最后裁剪 + + # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0 + if x <= 0: + rotate_image = rotate_image[:, -x:] + rotate_mask = rotate_mask[:, -x:] + start_x = x = 0 + else: + start_x = x + + if y <= 0: + rotate_image = rotate_image[-y:, :] + rotate_mask = rotate_mask[-y:, :] + start_y = y = 0 + else: + start_y = y + + # ------------------ + # 如果print-size大于image-size 则需要裁剪print + + if x + print_x > image_x: + rotate_image = rotate_image[:, :image_x - x] + rotate_mask = rotate_mask[:, :image_x - x] + + if y + print_y > image_y: + rotate_image = rotate_image[:image_y - y, :] + rotate_mask = rotate_mask[:image_y - y, :] + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask) + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image) + + # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask + # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = 
self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + + # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) + # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + + print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) + img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) + # TODO element 丢失信息 + three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)]) + img_bg = cv2.bitwise_and(result['final_image'], three_channel_image) + # mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + # gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + # img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + result['final_image'] = cv2.add(img_bg, img_fg) + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) + return result @staticmethod def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x): @@ -301,6 +404,7 @@ class PrintPainting(object): return painting_dict def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False): + tile = None if not trigger: tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA) else: @@ -351,6 +455,7 @@ class PrintPainting(object): print_mask = result['mask'] img_fg = result['final_image'] if print_ and not painting_dict['Trigger']: + index_ = None try: index_ = len(painting_dict['location']) 
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_generate_relight_image.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :Relight-image generation worker. Sends a prompt plus a product image
          reference to a Triton-hosted Stable Diffusion relighting model,
          tracks task status in Redis, and publishes the final status to
          RabbitMQ when not in DEBUG mode.
"""
import io
import json
import logging
import time

import cv2
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from minio import Minio
from PIL import Image, ImageOps
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image

logger = logging.getLogger()


class GenerateRelightImage:
    """One relight-generation task: builds the Triton inputs, fires an async
    inference request, and mirrors the task state into Redis (and RabbitMQ)."""

    def __init__(self, request_data):
        # NOTE(review): `pika`, DEBUG and the *_URL/*_PARAMS constants are
        # expected to arrive via the star-import of app.core.config — confirm
        # pika is actually importable from there.
        if DEBUG is False:
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.category = "relight_image"
        self.batch_size = 1
        self.prompt = request_data.prompt
        self.seed = "12345"
        # TODO aida design: change the result image background to white.
        # self.image, self.image_size = self.get_image(request_data.image_url)
        self.image = request_data.image_url  # passed through as a path/URL, not pixel data
        # get_image() is currently bypassed, so no original size is recorded.
        # Fix: the original never set this attribute yet callback() read it,
        # which raised AttributeError on every successful inference.
        self.image_size = None
        # TODO pad and resize the input image to 512x768.

        self.tasks_id = request_data.tasks_id
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
        self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
        self.redis_client.expire(self.tasks_id, 600)  # task record lives 10 minutes

    def get_image(self, image_url):
        """Fetch an image from MinIO and letterbox it to 512x768 on white.

        Args:
            image_url: "<bucket>/<object-path>" style key inside MinIO.

        Returns:
            tuple[np.ndarray, tuple[int, int]]: image in BGR channel order and
            the (width, height) of the padded image, for restoring later.
        """
        response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
        image_bytes = io.BytesIO(response.read())

        image = Image.open(image_bytes)
        target_height = 768
        target_width = 512

        # Scale to the target height, then pad symmetrically with white so the
        # result is exactly 512x768.
        aspect_ratio = image.width / image.height
        new_width = int(target_height * aspect_ratio)

        resized_image = image.resize((new_width, target_height))
        left = (target_width - resized_image.width) // 2
        top = (target_height - resized_image.height) // 2
        right = target_width - resized_image.width - left
        bottom = target_height - resized_image.height - top
        image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
        image_size = image.size
        if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
            # Flatten transparency onto a white background.
            background = Image.new("RGB", image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            image = np.array(background)
        else:
            # Fix: the original returned a PIL image on this path but an
            # ndarray on the transparency path; normalise to an ndarray.
            image = np.array(image.convert("RGB"))
        # PIL arrays are RGB; swap to BGR for cv2 consumers.  (The original
        # used COLOR_BGR2RGB, which performs the identical channel swap but
        # was mislabelled for this direction.)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return image, image_size

    def callback(self, result, error):
        """Triton async-infer completion hook.

        Uploads the generated image and writes SUCCESS/FAILURE into the Redis
        task record that get_result() polls.
        """
        if error:
            self.gen_product_data['status'] = "FAILURE"
            self.gen_product_data['message'] = str(error)
            self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
        else:
            # Fix: read the tensor actually requested in get_result()
            # ("generated_image"); the original asked for
            # "generated_inpaint_image", a name this request never produces.
            image = result.as_numpy("generated_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            if self.image_size is not None:
                # Restore the pre-padding size only when get_image() recorded one.
                image_result = image_result.resize(self.image_size)

            image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
            self.gen_product_data['status'] = "SUCCESS"
            self.gen_product_data['message'] = "success"
            self.gen_product_data['image_url'] = str(image_url)
            self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))

    def read_tasks_status(self):
        """Return the task record as (parsed dict, raw JSON string)."""
        status_data = self.redis_client.get(self.tasks_id)
        return json.loads(status_data), status_data

    def infer(self, inputs, outputs=None):
        """Fire the non-blocking Triton request; completion lands in callback().

        `outputs` is new but defaults to None (Triton then returns all
        outputs), so existing callers are unaffected.
        """
        return self.grpc_client.async_infer(
            model_name=GRI_MODEL_NAME,
            inputs=inputs,
            outputs=outputs,
            callback=self.callback
        )

    def get_result(self):
        """Build the Triton inputs, run inference, and poll Redis until the
        task succeeds, fails, is revoked, or the poll window expires.

        Returns:
            dict: final task record {'tasks_id', 'status', 'message', 'image_url'}.

        Raises:
            Exception: re-raised after recording FAILURE in Redis.
        """
        try:
            direction = "Right Light"
            negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
            # Fix: the original overwrote self.prompt with a hard-coded debug
            # prompt here, making the caller-supplied prompt dead code.
            prompts = [self.prompt] * self.batch_size
            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
            input_text = grpcclient.InferInput(
                "prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
            )
            input_text.set_data_from_numpy(text_obj)

            negative_prompts = [negative_prompt] * self.batch_size
            text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
            input_text_neg = grpcclient.InferInput(
                "negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
            )
            input_text_neg.set_data_from_numpy(text_obj_neg)

            seed = np.array(self.seed, dtype="object").reshape((-1, 1))
            input_seed = grpcclient.InferInput(
                "seed", seed.shape, np_to_triton_dtype(seed.dtype)
            )
            input_seed.set_data_from_numpy(seed)

            input_images = [self.image] * self.batch_size
            text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
            input_input_images = grpcclient.InferInput(
                "input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
            )
            input_input_images.set_data_from_numpy(text_obj_images)

            directions = [direction] * self.batch_size
            text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
            input_directions = grpcclient.InferInput(
                "direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
            )
            input_directions.set_data_from_numpy(text_obj_directions)

            output_img = grpcclient.InferRequestedOutput("generated_image")

            inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]

            # Fix: the requested output was built but never attached to the
            # request in the original, so Triton chose the outputs itself.
            ctx = self.infer(inputs, outputs=[output_img])

            # Poll every 0.1 s; 600 ticks ≈ 60 seconds of waiting.
            # NOTE(review): the Redis record expires after 600 s — confirm
            # whether this window was meant to match that TTL.
            time_out = 600
            while time_out > 0:
                gen_product_data, _ = self.read_tasks_status()
                if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
                    ctx.cancel()
                    break
                elif gen_product_data['status'] == "SUCCESS":
                    break
                time_out -= 1
                time.sleep(0.1)
            gen_product_data, _ = self.read_tasks_status()
            return gen_product_data
        except Exception as e:
            self.gen_product_data['status'] = "FAILURE"
            self.gen_product_data['message'] = str(e)
            self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
            raise Exception(str(e))
        finally:
            dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
            if DEBUG is False:
                # Fix: publish to the relight queue (GRI_*); the original
                # copy-pasted the product-image queue name (GPI_*).
                self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
                logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")


def infer_cancel(tasks_id):
    """Mark a task REVOKED in Redis; the poll loop in get_result() will see
    the status flip and cancel the in-flight Triton request."""
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client.set(tasks_id, json.dumps(data))
    return data


if __name__ == '__main__':
    rd = GenerateRelightImageModel(
        tasks_id="123-89",
        prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
        image_url="/workspace/i3.png",
    )
    server = GenerateRelightImage(rd)
    print(server.get_result())