Merge branch 'local' into develop

This commit is contained in:
zhouchengrong
2024-06-17 10:46:09 +08:00
72 changed files with 5667 additions and 45 deletions

View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
logger = logging.getLogger()
class GenerateProductImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
self.batch_size = 1
self.prompt = request_data.prompt
# TODO aida design 结果图背景改为白色
self.image, self.image_size = self.get_image(request_data.image_url)
# TODO image 填充并resize成512*768
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# 转换为PIL图像对象
image = Image.open(image_bytes)
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
# self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GPI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
self.image = cv2.resize(self.image, (512, 768))
images = [self.image.astype(np.uint8)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
inputs = [input_text, input_image]
ctx = self.infer(inputs)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
# logger.info(gen_product_data)
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif gen_product_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, gen_product_data)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
return data
if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png",
)
server = GenerateProductImage(rd)
print(server.get_result())

View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
logger = logging.getLogger()
class GenerateRelightImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "relight_image"
self.batch_size = 1
self.prompt = request_data.prompt
self.seed = "12345"
# TODO aida design 结果图背景改为白色
# self.image, self.image_size = self.get_image(request_data.image_url)
self.image = request_data.image_url
# TODO image 填充并resize成512*768
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# 转换为PIL图像对象
image = Image.open(image_bytes)
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
# self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GRI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
direction = "Right Light"
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere'
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput(
"prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
)
input_text.set_data_from_numpy(text_obj)
negative_prompts = [negative_prompt] * self.batch_size
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput(
"negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
)
input_text_neg.set_data_from_numpy(text_obj_neg)
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput(
"seed", seed.shape, np_to_triton_dtype(seed.dtype)
)
input_seed.set_data_from_numpy(seed)
input_images = [self.image] * self.batch_size
text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
input_input_images = grpcclient.InferInput(
"input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
)
input_input_images.set_data_from_numpy(text_obj_images)
directions = [direction] * self.batch_size
text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
input_directions = grpcclient.InferInput(
"direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
)
input_directions.set_data_from_numpy(text_obj_directions)
output_img = grpcclient.InferRequestedOutput("generated_image")
request_start = time.time()
inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]
ctx = self.infer(inputs)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
# logger.info(gen_product_data)
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif gen_product_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, gen_product_data)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
return data
if __name__ == '__main__':
rd = GenerateRelightImageModel(
tasks_id="123-89",
prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
image_url="/workspace/i3.png",
)
server = GenerateRelightImage(rd)
print(server.get_result())

View File

@@ -0,0 +1,137 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
import cv2
import numpy as np
import redis
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
import tritonclient.grpc as grpcclient
from app.schemas.generate_image import GenerateSingleLogoImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
logger = logging.getLogger()
class GenerateSingleLogoImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.batch_size = 1
self.category = "single_logo"
self.negative_prompts = "bad, ugly"
self.seed = request_data.seed
self.tasks_id = request_data.tasks_id
self.prompt = request_data.prompt
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
self.redis_client.expire(self.tasks_id, 600)
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GSL_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def callback(self, result, error):
if error:
self.gen_single_logo_data['status'] = "FAILURE"
self.gen_single_logo_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
else:
image = result.as_numpy("generated_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
self.gen_single_logo_data['status'] = "SUCCESS"
self.gen_single_logo_data['message'] = "success"
self.gen_single_logo_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
def get_result(self):
try:
# prompt
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj)
# negative_prompts
text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1))
# print('text obj neg: ', text_obj_neg)
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
input_text_neg.set_data_from_numpy(text_obj_neg)
# seed
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype))
input_seed.set_data_from_numpy(seed)
inputs = [input_text, input_text_neg, input_seed]
ctx = self.infer(inputs)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
rd = GenerateSingleLogoImageModel(
tasks_id="123-89",
prompt='an apple',
seed="2",
)
server = GenerateSingleLogoImage(rd)
print(server.get_result())

View File

@@ -382,7 +382,7 @@ if __name__ == '__main__':
remove_bg_img = remove_background(luminance)
# cv2.imwrite("remove_bg_img.png", remove_bg_img)
print(1)
# print(1)
cv2.imshow("source", img)
cv2.imshow("levels", equAuto)
cv2.imshow("luminance", luminance)

View File

@@ -10,6 +10,7 @@
import io
import logging
import boto3
import cv2
from PIL import Image
from minio import Minio
@@ -17,6 +18,39 @@ from minio import Minio
from app.core.config import *
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
# def upload_single_logo(image, user_id, category, object_name):
# with io.BytesIO() as output:
# image.save(output, format='PNG')
# data = output.getvalue()
# # 创建一个 S3 客户端
# try:
# key = f'{user_id}/{category}/{object_name}'
# image_url = f"{AIDA_CLOTHING}/{key}"
# s3.put_object(Bucket=GSL_MINIO_BUCKET, Key=key, Body=data, ContentType='image/png')
# return image_url
# except Exception as e:
# print(f'上传到 S3 失败: {e}')
def upload_SDXL_image(image, user_id, category, object_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
minio_req = minio_client.put_object(
GI_MINIO_BUCKET,
f'{user_id}/{category}/{object_name}',
io.BytesIO(image_bytes),
len(image_bytes),
content_type='image/jpeg'
)
image_url = f"aida-users/{minio_req.object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
def upload_png_sd(image, user_id, category, object_name):