feat 产品图打光模型部署
fix
This commit is contained in:
@@ -124,7 +124,7 @@ GPI_MODEL_URL = '10.1.1.240:10061'
|
|||||||
|
|
||||||
# Generate Single Logo service config
|
# Generate Single Logo service config
|
||||||
GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
|
GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
|
||||||
GRI_MODEL_NAME = 'stable_diffusion_1_5'
|
GRI_MODEL_NAME = 'diffusion_relight_ensemble'
|
||||||
GRI_MODEL_URL = '10.1.1.150:8001'
|
GRI_MODEL_URL = '10.1.1.150:8001'
|
||||||
|
|
||||||
# SEG service config
|
# SEG service config
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from minio import Minio
|
|||||||
from tritonclient.utils import np_to_triton_dtype
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateImageModel
|
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -166,10 +166,11 @@ def infer_cancel(tasks_id):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
rd = GenerateImageModel(
|
rd = GenerateProductImageModel(
|
||||||
tasks_id="123-89",
|
tasks_id="123-89",
|
||||||
prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
prompt="",
|
||||||
image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png",
|
# prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||||
|
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||||
)
|
)
|
||||||
server = GenerateProductImage(rd)
|
server = GenerateProductImage(rd)
|
||||||
print(server.get_result())
|
print(server.get_result())
|
||||||
|
|||||||
@@ -38,9 +38,10 @@ class GenerateRelightImage:
|
|||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.prompt = request_data.prompt
|
self.prompt = request_data.prompt
|
||||||
self.seed = "12345"
|
self.seed = "12345"
|
||||||
|
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
||||||
|
self.direction = "Right Light"
|
||||||
# TODO aida design 结果图背景改为白色
|
# TODO aida design 结果图背景改为白色
|
||||||
# self.image, self.image_size = self.get_image(request_data.image_url)
|
self.image = self.get_image(request_data.image_url)
|
||||||
self.image = request_data.image_url
|
|
||||||
# TODO image 填充并resize成512*768
|
# TODO image 填充并resize成512*768
|
||||||
|
|
||||||
self.tasks_id = request_data.tasks_id
|
self.tasks_id = request_data.tasks_id
|
||||||
@@ -51,37 +52,8 @@ class GenerateRelightImage:
|
|||||||
|
|
||||||
def get_image(self, image_url):
|
def get_image(self, image_url):
|
||||||
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
||||||
image_bytes = io.BytesIO(response.read())
|
image = cv2.imdecode(np.frombuffer(response.data, np.uint8), 1)
|
||||||
|
return image
|
||||||
# 转换为PIL图像对象
|
|
||||||
image = Image.open(image_bytes)
|
|
||||||
target_height = 768
|
|
||||||
target_width = 512
|
|
||||||
|
|
||||||
aspect_ratio = image.width / image.height
|
|
||||||
new_width = int(target_height * aspect_ratio)
|
|
||||||
|
|
||||||
resized_image = image.resize((new_width, target_height))
|
|
||||||
left = (target_width - resized_image.width) // 2
|
|
||||||
top = (target_height - resized_image.height) // 2
|
|
||||||
right = target_width - resized_image.width - left
|
|
||||||
bottom = target_height - resized_image.height - top
|
|
||||||
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
|
|
||||||
image_size = image.size
|
|
||||||
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
|
|
||||||
# 创建白色背景
|
|
||||||
background = Image.new("RGB", image.size, (255, 255, 255))
|
|
||||||
# 将图片粘贴到白色背景上
|
|
||||||
background.paste(image, mask=image.split()[3])
|
|
||||||
image = np.array(background)
|
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
||||||
|
|
||||||
# image_file = BytesIO(response.data)
|
|
||||||
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
|
|
||||||
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
|
||||||
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
|
||||||
# image = cv2.resize(image_rbg, (1024, 1024))
|
|
||||||
return image, image_size
|
|
||||||
|
|
||||||
def callback(self, result, error):
|
def callback(self, result, error):
|
||||||
if error:
|
if error:
|
||||||
@@ -92,7 +64,7 @@ class GenerateRelightImage:
|
|||||||
else:
|
else:
|
||||||
# pil图像转成numpy数组
|
# pil图像转成numpy数组
|
||||||
image = result.as_numpy("generated_inpaint_image")
|
image = result.as_numpy("generated_inpaint_image")
|
||||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||||
|
|
||||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
||||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||||
@@ -114,47 +86,33 @@ class GenerateRelightImage:
|
|||||||
|
|
||||||
def get_result(self):
|
def get_result(self):
|
||||||
try:
|
try:
|
||||||
direction = "Right Light"
|
|
||||||
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
|
||||||
self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere'
|
|
||||||
prompts = [self.prompt] * self.batch_size
|
prompts = [self.prompt] * self.batch_size
|
||||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
||||||
input_text = grpcclient.InferInput(
|
image = cv2.resize(image, (512, 768))
|
||||||
"prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
|
images = [image.astype(np.uint8)] * self.batch_size
|
||||||
)
|
seeds = [self.seed] * self.batch_size
|
||||||
|
nagetive_prompts = [self.negative_prompt] * self.batch_size
|
||||||
|
directions = [self.direction] * self.batch_size
|
||||||
|
|
||||||
|
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||||
|
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||||
|
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
|
||||||
|
seed_obj = np.array(seeds, dtype="object").reshape((1))
|
||||||
|
direction_obj = np.array(directions, dtype="object").reshape((1))
|
||||||
|
|
||||||
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||||
|
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
|
||||||
|
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
|
||||||
|
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
|
||||||
|
|
||||||
input_text.set_data_from_numpy(text_obj)
|
input_text.set_data_from_numpy(text_obj)
|
||||||
|
input_image.set_data_from_numpy(image_obj)
|
||||||
|
input_natext.set_data_from_numpy(na_text_obj)
|
||||||
|
input_seed.set_data_from_numpy(seed_obj)
|
||||||
|
input_direction.set_data_from_numpy(direction_obj)
|
||||||
|
|
||||||
negative_prompts = [negative_prompt] * self.batch_size
|
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
|
||||||
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
|
|
||||||
input_text_neg = grpcclient.InferInput(
|
|
||||||
"negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
|
|
||||||
)
|
|
||||||
input_text_neg.set_data_from_numpy(text_obj_neg)
|
|
||||||
|
|
||||||
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
|
|
||||||
input_seed = grpcclient.InferInput(
|
|
||||||
"seed", seed.shape, np_to_triton_dtype(seed.dtype)
|
|
||||||
)
|
|
||||||
input_seed.set_data_from_numpy(seed)
|
|
||||||
|
|
||||||
input_images = [self.image] * self.batch_size
|
|
||||||
text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
|
|
||||||
input_input_images = grpcclient.InferInput(
|
|
||||||
"input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
|
|
||||||
)
|
|
||||||
input_input_images.set_data_from_numpy(text_obj_images)
|
|
||||||
|
|
||||||
directions = [direction] * self.batch_size
|
|
||||||
text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
|
|
||||||
input_directions = grpcclient.InferInput(
|
|
||||||
"direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
|
|
||||||
)
|
|
||||||
input_directions.set_data_from_numpy(text_obj_directions)
|
|
||||||
|
|
||||||
output_img = grpcclient.InferRequestedOutput("generated_image")
|
|
||||||
request_start = time.time()
|
|
||||||
|
|
||||||
inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]
|
|
||||||
|
|
||||||
ctx = self.infer(inputs)
|
ctx = self.infer(inputs)
|
||||||
time_out = 600
|
time_out = 600
|
||||||
@@ -179,9 +137,9 @@ class GenerateRelightImage:
|
|||||||
finally:
|
finally:
|
||||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if DEBUG is False:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||||
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
|
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||||
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
@@ -195,8 +153,9 @@ def infer_cancel(tasks_id):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
rd = GenerateRelightImageModel(
|
rd = GenerateRelightImageModel(
|
||||||
tasks_id="123-89",
|
tasks_id="123-89",
|
||||||
prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||||
image_url="/workspace/i3.png",
|
prompt="",
|
||||||
|
image_url='aida-users/89/product_image/123-89.png'
|
||||||
)
|
)
|
||||||
server = GenerateRelightImage(rd)
|
server = GenerateRelightImage(rd)
|
||||||
print(server.get_result())
|
print(server.get_result())
|
||||||
|
|||||||
Reference in New Issue
Block a user