feat generate slogan | to product image | slogan 接口部署

This commit is contained in:
zhouchengrong
2024-05-30 15:02:35 +08:00
parent 401b76bd95
commit 5092a8c7bc
2 changed files with 253 additions and 0 deletions

View File

@@ -0,0 +1,181 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
from io import BytesIO
import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
logger = logging.getLogger()
class GenerateProductImage:
def __init__(self, request_data):
# if DEBUG is False:
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
# Get data of an object.
# Read data from response.
# read image use cv2
try:
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image = cv2.resize(image_rbg, (1024, 1024))
except minio.error.S3Error:
image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image
def callback(self, result, error):
if error:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
is_smudge = True
if self.category == "sketch":
# 色阶调整
cutoff = 1
levels_img = autoLevels(image_result, cutoff)
# 亮度调整
luminance = luminance_adjust(0.3, levels_img)
# 去背景
remove_bg_image = remove_background(luminance)
# 人脸检测
if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0:
is_smudge = False
else:
# 污点/
is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id)
# 类型识别
category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender)
self.generate_data['category'] = str(category)
image_result = not_smudge_image
if is_smudge: # 无污点
# image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else: # 有污点 保存图片到本地 测试用
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
# logger.info(f"stain_detection result : {self.generate_data}")
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
# prompts = [self.prompt] * self.batch_size
# modes = [self.mode] * self.batch_size
# images = [self.image.astype(np.float16)] * self.batch_size
#
# text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
# mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
# image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
#
# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
# input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
# input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype))
#
# input_text.set_data_from_numpy(text_obj)
# input_image.set_data_from_numpy(image_obj)
# input_mode.set_data_from_numpy(mode_obj)
#
# inputs = [input_text, input_image, input_mode]
# ctx = self.infer(inputs)
# time_out = 600
# generate_data = None
# while time_out > 0:
# generate_data, _ = self.read_tasks_status()
# # logger.info(generate_data)
# if generate_data['status'] in ["REVOKED", "FAILURE"]:
# ctx.cancel()
# break
# elif generate_data['status'] == "SUCCESS":
# break
# time_out -= 1
# time.sleep(0.1)
# # logger.info(time_out, generate_data)
generate_data, _ = self.read_tasks_status()
return generate_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
# if DEBUG is False:
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
image_url="",
)
server = GenerateImage(rd)
print(server.get_result())

View File

@@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import redis
from minio import Minio
from app.core.config import *
from app.schemas.generate_image import GenerateSingleLogoImageModel
logger = logging.getLogger()
class GenerateSingleLogoImage:
def __init__(self, request_data):
# if DEBUG is False:
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
self.redis_client.expire(self.tasks_id, 600)
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def get_result(self):
try:
generate_data, _ = self.read_tasks_status()
return generate_data
except Exception as e:
self.gen_single_logo_data['status'] = "FAILURE"
self.gen_single_logo_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
# if DEBUG is False:
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
self.channel.basic_publish(exchange='', routing_key=GEN_SINGLE_LOGO_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
rd = GenerateSingleLogoImageModel(
tasks_id="123-8",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
image_url="",
)
server = GenerateSingleLogoImage(rd)
print(server.get_result())