Files
AiDA_Python/app/service/generate_image/service_generate_relight_image.py
2024-06-20 16:23:02 +08:00

144 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_generate_relight_image.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
class GenerateRelightImage:
    """Submit one relight request to the Triton server and mirror its status in Redis.

    The status record is a JSON dict stored under ``tasks_id`` with the keys
    ``tasks_id``, ``status``, ``message`` and ``image_url``.  ``status`` moves
    from ``PENDING`` to ``SUCCESS`` or ``FAILURE``; an external writer may also
    set it to ``REVOKED`` to ask for cancellation.
    """

    # Lifetime of the status key in Redis, in seconds.
    STATUS_TTL_SECONDS = 600
    # Polling budget of get_result(): 600 sleeps of 0.1 s (about 60 s plus Redis round-trips).
    POLL_ATTEMPTS = 600
    POLL_INTERVAL_SECONDS = 0.1

    def __init__(self, request_data):
        """Prepare clients, load the source image and publish the PENDING status.

        Args:
            request_data: object exposing ``prompt``, ``image_url`` and
                ``tasks_id``.  ``image_url`` is read as ``<bucket>/<object path>``;
                the text after the last ``-`` of ``tasks_id`` is used as the user id.
        """
        if DEBUG is False:
            # NOTE(review): `pika` is not imported explicitly in this module; it presumably
            # arrives through the star import from app.core.config - confirm.
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
        self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)

        # Fixed generation settings for this service.
        self.category = "relight_image"
        self.batch_size = 1
        self.prompt = request_data.prompt
        self.seed = "1"
        self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
        self.direction = "Right Light"

        # Split "<bucket>/<object path>" on the first slash.
        self.image_url = request_data.image_url
        self.image = oss_get_image(
            bucket=self.image_url.split('/')[0],
            object_name=self.image_url[self.image_url.find('/') + 1:],
            data_type="cv2",
        )

        self.tasks_id = request_data.tasks_id
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
        self._store_status()

    def _store_status(self):
        """Write the in-memory status snapshot to Redis and re-arm its expiry."""
        # A plain SET discards any TTL already on the key, so the expiry has to
        # travel with every write or finished tasks would stay in Redis forever.
        self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data), ex=self.STATUS_TTL_SECONDS)

    def _mark_failure(self, reason):
        """Record ``reason`` as a FAILURE status in memory and in Redis."""
        self.gen_product_data['status'] = "FAILURE"
        self.gen_product_data['message'] = str(reason)
        self._store_status()

    def callback(self, result, error):
        """Completion callback handed to ``async_infer``.

        Args:
            result: inference result; its ``generated_inpaint_image`` output is
                converted to an image and uploaded.
            error: truthy when the inference failed; its string form becomes
                the status message.
        """
        if error:
            self._mark_failure(error)
            return
        try:
            # Convert the numpy output tensor into a PIL image before uploading.
            image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
        except Exception as exc:
            # This runs outside get_result()'s try block; without this guard a
            # decode/upload error would leave the task PENDING until the poll budget ends.
            logger.exception("post-processing of relight result failed")
            self._mark_failure(exc)
            return
        self.gen_product_data['status'] = "SUCCESS"
        self.gen_product_data['message'] = "success"
        self.gen_product_data['image_url'] = str(image_url)
        self._store_status()

    def read_tasks_status(self):
        """Return the task status as ``(parsed_dict, raw_json_string)``.

        When the Redis key is missing (for example, it expired) the in-memory
        snapshot is returned instead of failing on ``json.loads(None)``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.gen_product_data)
        return json.loads(status_data), status_data

    def _text_input(self, name, value):
        """Build a one-element object (string) input tensor called ``name``."""
        values = np.array([value] * self.batch_size, dtype="object").reshape((1))
        tensor = grpcclient.InferInput(name, values.shape, np_to_triton_dtype(values.dtype))
        tensor.set_data_from_numpy(values)
        return tensor

    def _build_inputs(self):
        """Assemble the five model inputs in the order the original request used."""
        rgb = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        # cv2.resize takes (width, height); the resulting array is (768, 512, 3).
        rgb = cv2.resize(rgb, (512, 768))
        image_obj = np.array([rgb.astype(np.uint8)] * self.batch_size, dtype=np.uint8).reshape((768, 512, 3))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
        input_image.set_data_from_numpy(image_obj)

        return [
            self._text_input("prompt", self.prompt),
            self._text_input("negative_prompt", self.negative_prompt),
            input_image,
            self._text_input("seed", self.seed),
            self._text_input("direction", self.direction),
        ]

    def get_result(self):
        """Start the inference, poll Redis until it settles, and return the status dict.

        Polling stops on SUCCESS, or on REVOKED/FAILURE (the in-flight request
        is cancelled then).  If the polling budget runs out, the status is
        returned as it stands (possibly still PENDING) and the request is left
        running.  The final status is always published to RabbitMQ when
        ``DEBUG`` is False.

        Raises:
            Exception: any error while building or submitting the request; the
                failure is written to Redis first and the original error is
                attached as the cause.
        """
        try:
            inputs = self._build_inputs()
            ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback)
            for _ in range(self.POLL_ATTEMPTS):
                gen_product_data, _raw = self.read_tasks_status()
                status = gen_product_data['status']
                if status in ("REVOKED", "FAILURE"):
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                time.sleep(self.POLL_INTERVAL_SECONDS)
            gen_product_data, _raw = self.read_tasks_status()
            return gen_product_data
        except Exception as e:
            self._mark_failure(e)
            raise Exception(str(e)) from e
        finally:
            dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
            if DEBUG is False:
                # NOTE(review): the connection opened in __init__ is never closed; confirm
                # whether instances are reused before adding a close() here.
                self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
                logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
    """Mark ``tasks_id`` as REVOKED in Redis and return the record that was written.

    Args:
        tasks_id: Redis key of the task whose status record is overwritten.

    Returns:
        The dict that was serialised into Redis.
    """
    # Same lifetime the task record is created with; a plain SET would drop the
    # key's existing TTL and leave the REVOKED record in Redis indefinitely.
    expire_seconds = 600
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client.set(tasks_id, json.dumps(data), ex=expire_seconds)
    return data
if __name__ == '__main__':
    # Manual smoke run: relight one stored product image and print the final status.
    # Alternative prompt: "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere"
    demo_request = GenerateRelightImageModel(
        tasks_id="123-89",
        prompt="Colorful black",
        image_url='aida-users/89/product_image/123-89.png',
    )
    final_status = GenerateRelightImage(demo_request).get_result()
    print(final_status)