feat generate product image 部署

This commit is contained in:
zhouchengrong
2024-06-04 15:33:34 +08:00
parent 40a2e158e2
commit 1d94d485e9
5 changed files with 108 additions and 114 deletions

View File

@@ -19,7 +19,7 @@ class Settings(BaseSettings):
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
DEBUG = False DEBUG = True
if DEBUG: if DEBUG:
LOGS_PATH = "logs/" LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
@@ -119,6 +119,9 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f
# Generate Single Logo service config # Generate Single Logo service config
GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"GenProductImage{RABBITMQ_ENV}") GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"GenProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME = 'diffusion_ensemble_all'
GPI_MODEL_URL = '10.1.1.240:10061'
# SEG service config # SEG service config
SEG_MODEL_URL = '10.1.1.240:10000' SEG_MODEL_URL = '10.1.1.240:10000'

View File

@@ -5,9 +5,6 @@ class GenerateImageModel(BaseModel):
tasks_id: str tasks_id: str
prompt: str prompt: str
image_url: str image_url: str
mode: str
category: str
gender: str
class GenerateSingleLogoImageModel(BaseModel): class GenerateSingleLogoImageModel(BaseModel):

View File

@@ -7,38 +7,42 @@
@Date 2023/7/26 12:01:05 @Date 2023/7/26 12:01:05
@detail @detail
""" """
import io
import json import json
import logging import logging
import time import time
from io import BytesIO
import cv2 import cv2
import minio
import redis import redis
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
import numpy as np import numpy as np
from PIL import Image, ImageOps
from minio import Minio from minio import Minio
from tritonclient.utils import np_to_triton_dtype from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
logger = logging.getLogger() logger = logging.getLogger()
class GenerateProductImage: class GenerateProductImage:
def __init__(self, request_data): def __init__(self, request_data):
# if DEBUG is False: if DEBUG is False:
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel() self.channel = self.connection.channel()
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel() # self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
self.batch_size = 1
self.prompt = request_data.prompt
# TODO aida design 结果图背景改为白色
self.image, self.image_size = self.get_image(request_data.image_url)
# TODO image 填充并resize成512*768
self.tasks_id = request_data.tasks_id self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
@@ -46,63 +50,56 @@ class GenerateProductImage:
self.redis_client.expire(self.tasks_id, 600) self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url): def get_image(self, image_url):
# Get data of an object. response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
# Read data from response. image_bytes = io.BytesIO(response.read())
# read image use cv2
try: # 转换为PIL图像对象
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) image = Image.open(image_bytes)
image_file = BytesIO(response.data) target_height = 768
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) target_width = 512
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) aspect_ratio = image.width / image.height
image = cv2.resize(image_rbg, (1024, 1024)) new_width = int(target_height * aspect_ratio)
except minio.error.S3Error:
image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) resized_image = image.resize((new_width, target_height))
return image left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error): def callback(self, result, error):
if error: if error:
self.generate_data['status'] = "FAILURE" self.gen_product_data['status'] = "FAILURE"
self.generate_data['message'] = str(error) self.gen_product_data['message'] = str(error)
# self.generate_data['data'] = str(error) # self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else: else:
# pil图像转成numpy数组 # pil图像转成numpy数组
image = result.as_numpy("generated_image") image = result.as_numpy("generated_inpaint_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
is_smudge = True
if self.category == "sketch": image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# 色阶调整 # logger.info(f"upload image SUCCESS {image_url}")
cutoff = 1 self.gen_product_data['status'] = "SUCCESS"
levels_img = autoLevels(image_result, cutoff) self.gen_product_data['message'] = "success"
# 亮度调整 self.gen_product_data['image_url'] = str(image_url)
luminance = luminance_adjust(0.3, levels_img) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
# 去背景
remove_bg_image = remove_background(luminance)
# 人脸检测
if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0:
is_smudge = False
else:
# 污点/
is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id)
# 类型识别
category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender)
self.generate_data['category'] = str(category)
image_result = not_smudge_image
if is_smudge: # 无污点
# image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else: # 有污点 保存图片到本地 测试用
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
# logger.info(f"stain_detection result : {self.generate_data}")
def read_tasks_status(self): def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id) status_data = self.redis_client.get(self.tasks_id)
@@ -110,46 +107,43 @@ class GenerateProductImage:
def infer(self, inputs): def infer(self, inputs):
return self.grpc_client.async_infer( return self.grpc_client.async_infer(
model_name=GI_MODEL_NAME, model_name=GPI_MODEL_NAME,
inputs=inputs, inputs=inputs,
callback=self.callback callback=self.callback
) )
def get_result(self): def get_result(self):
try: try:
# prompts = [self.prompt] * self.batch_size prompts = [self.prompt] * self.batch_size
# modes = [self.mode] * self.batch_size self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
# images = [self.image.astype(np.float16)] * self.batch_size self.image = cv2.resize(self.image, (512, 768))
# images = [self.image.astype(np.uint8)] * self.batch_size
# text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
# mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) text_obj = np.array(prompts, dtype="object").reshape(1)
# image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
#
# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
# input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
# input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype))
# input_text.set_data_from_numpy(text_obj)
# input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj)
# input_image.set_data_from_numpy(image_obj) inputs = [input_text, input_image]
# input_mode.set_data_from_numpy(mode_obj)
# ctx = self.infer(inputs)
# inputs = [input_text, input_image, input_mode] time_out = 600
# ctx = self.infer(inputs) while time_out > 0:
# time_out = 600 gen_product_data, _ = self.read_tasks_status()
# generate_data = None # logger.info(gen_product_data)
# while time_out > 0: if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
# generate_data, _ = self.read_tasks_status() ctx.cancel()
# # logger.info(generate_data) break
# if generate_data['status'] in ["REVOKED", "FAILURE"]: elif gen_product_data['status'] == "SUCCESS":
# ctx.cancel() break
# break time_out -= 1
# elif generate_data['status'] == "SUCCESS": time.sleep(0.1)
# break # logger.info(time_out, gen_product_data)
# time_out -= 1 gen_product_data, _ = self.read_tasks_status()
# time.sleep(0.1) return gen_product_data
# # logger.info(time_out, generate_data)
generate_data, _ = self.read_tasks_status()
return generate_data
except Exception as e: except Exception as e:
self.gen_product_data['status'] = "FAILURE" self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e) self.gen_product_data['message'] = str(e)
@@ -157,25 +151,25 @@ class GenerateProductImage:
raise Exception(str(e)) raise Exception(str(e))
finally: finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status() dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
# if DEBUG is False: if DEBUG is False:
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_gen_product_data)
self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent {json.dumps(dict_gen_product_data, indent=4)}") logger.info(f" [x] Sent {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id): def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data) gen_product_data = json.dumps(data)
redis_client.set(tasks_id, generate_data) redis_client.set(tasks_id, gen_product_data)
return data return data
if __name__ == '__main__': if __name__ == '__main__':
rd = GenerateImageModel( rd = GenerateImageModel(
tasks_id="123-89", tasks_id="123-89",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="", image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png",
) )
server = GenerateImage(rd) server = GenerateProductImage(rd)
print(server.get_result()) print(server.get_result())

View File

@@ -21,7 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
from app.schemas.generate_image import GenerateSingleLogoImageModel from app.schemas.generate_image import GenerateSingleLogoImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_single_logo from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
logger = logging.getLogger() logger = logging.getLogger()
@@ -67,7 +67,7 @@ class GenerateSingleLogoImage:
else: else:
image = result.as_numpy("generated_image") image = result.as_numpy("generated_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_single_logo(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
self.gen_single_logo_data['status'] = "SUCCESS" self.gen_single_logo_data['status'] = "SUCCESS"
self.gen_single_logo_data['message'] = "success" self.gen_single_logo_data['message'] = "success"
self.gen_single_logo_data['image_url'] = str(image_url) self.gen_single_logo_data['image_url'] = str(image_url)
@@ -131,7 +131,7 @@ if __name__ == '__main__':
rd = GenerateSingleLogoImageModel( rd = GenerateSingleLogoImageModel(
tasks_id="123-89", tasks_id="123-89",
prompt='an apple', prompt='an apple',
seed="1", seed="2",
) )
server = GenerateSingleLogoImage(rd) server = GenerateSingleLogoImage(rd)
print(server.get_result()) print(server.get_result())

View File

@@ -34,7 +34,7 @@ s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S
# except Exception as e: # except Exception as e:
# print(f'上传到 S3 失败: {e}') # print(f'上传到 S3 失败: {e}')
def upload_single_logo(image, user_id, category, object_name): def upload_SDXL_image(image, user_id, category, object_name):
try: try:
image_data = io.BytesIO() image_data = io.BytesIO()
image.save(image_data, format='PNG') image.save(image_data, format='PNG')