Files
AiDA_Python/app/service/generate_image/service_generate_single_logo.py
zchengrong 6203dde267 feat(新功能):
fix(修复bug):  图片生成服务优化,避免mq连接超时
docs(文档变更):
refactor(重构):
test(增加测试):
2025-06-24 16:58:05 +08:00

116 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
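@File service_generate_single_logo.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail Single-logo image generation through a Triton inference server, with task status tracked in Redis.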
"""
import json
import logging
import time
import cv2
import numpy as np
import redis
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
import tritonclient.grpc as grpcclient
from app.schemas.generate_image import GenerateSingleLogoImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
# Module-wide logger; getLogger() called without a name returns the root logger.
logger = logging.getLogger()
class GenerateSingleLogoImage:
    """Run one single-logo generation request and track its status in Redis.

    The constructor registers the task in Redis as ``PENDING``.
    :meth:`get_result` submits an asynchronous Triton inference and polls the
    Redis record until it reaches ``SUCCESS``, ``FAILURE`` or ``REVOKED``, or
    until the polling budget is used up. :meth:`callback` is handed to the
    Triton client and writes the outcome of the inference back to Redis.
    """

    # Lifetime of the Redis status record, in seconds.
    STATUS_TTL_SECONDS = 600
    # get_result() polls at most MAX_POLLS times, sleeping POLL_INTERVAL_SECONDS
    # between polls (600 * 0.1 s = 60 s of sleeping plus Redis round trips).
    MAX_POLLS = 600
    POLL_INTERVAL_SECONDS = 0.1

    def __init__(self, request_data):
        """Create the Triton/Redis clients and register the task as PENDING.

        Args:
            request_data: Request object exposing ``seed``, ``tasks_id`` and
                ``prompt`` attributes.
        """
        self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
        )
        self.batch_size = 1
        self.category = "single_logo"
        self.negative_prompts = "bad, ugly"
        self.seed = request_data.seed
        self.tasks_id = request_data.tasks_id
        self.prompt = request_data.prompt
        # The user id is whatever follows the last '-' in the task id
        # (the whole id when it contains no '-').
        self.user_id = self.tasks_id.rpartition('-')[2]
        self.gen_single_logo_data = {
            'tasks_id': self.tasks_id,
            'status': 'PENDING',
            'message': "pending",
            'image_url': '',
        }
        self._save_status()

    def _save_status(self):
        """Write the in-memory status record to Redis and (re)arm its expiry.

        A plain ``SET`` discards any TTL already on the key, so the expiry is
        passed with every write; value and TTL are stored in one command.
        """
        self.redis_client.set(
            self.tasks_id,
            json.dumps(self.gen_single_logo_data),
            ex=self.STATUS_TTL_SECONDS,
        )

    @staticmethod
    def _make_input(name, value):
        """Wrap *value* as an object-dtype column tensor Triton input called *name*."""
        tensor = np.array(value, dtype="object").reshape((-1, 1))
        infer_input = grpcclient.InferInput(name, tensor.shape, np_to_triton_dtype(tensor.dtype))
        infer_input.set_data_from_numpy(tensor)
        return infer_input

    def read_tasks_status(self):
        """Return the current task status as ``(parsed_dict, raw_json_string)``.

        When the Redis key is missing (expired or deleted) the last status
        this instance wrote is returned instead of failing on ``None``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.gen_single_logo_data)
        return json.loads(status_data), status_data

    def callback(self, result, error):
        """Record the outcome of the asynchronous inference in Redis.

        Passed as ``callback`` to ``async_infer``.

        Args:
            result: Inference result; its ``generated_image`` output is
                converted to a PIL image and uploaded.
            error: Truthy when the inference failed; its string form becomes
                the status message.
        """
        # NOTE(review): the record is written unconditionally, so a REVOKED
        # status stored by infer_cancel can be overwritten here - confirm
        # whether that is intended.
        if error:
            self.gen_single_logo_data['status'] = "FAILURE"
            self.gen_single_logo_data['message'] = str(error)
        else:
            try:
                image = result.as_numpy("generated_image")
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
                image_url = upload_SDXL_image(
                    image_result,
                    user_id=self.user_id,
                    category=f"{self.category}",
                    file_name=f"{self.tasks_id}.png",
                )
            except Exception as exc:
                # Without this the record would stay PENDING and get_result()
                # would keep polling until its budget runs out.
                logger.exception("single logo post-processing failed for task %s", self.tasks_id)
                self.gen_single_logo_data['status'] = "FAILURE"
                self.gen_single_logo_data['message'] = str(exc)
            else:
                self.gen_single_logo_data['status'] = "SUCCESS"
                self.gen_single_logo_data['message'] = "success"
                self.gen_single_logo_data['image_url'] = str(image_url)
        self._save_status()

    def get_result(self):
        """Submit the inference and poll Redis until the task settles.

        Returns:
            The last status dict read while polling. It can still be
            ``PENDING`` when the polling budget is exhausted; the request is
            not cancelled in that case.

        Raises:
            Exception: Any failure while building inputs, submitting the
                request or polling, re-raised with the original message and
                the original exception attached as ``__cause__``.

        The status stored in Redis when the method exits is published to
        ``GI_RABBITMQ_QUEUES`` unless ``DEBUG`` is set.
        """
        try:
            inputs = [
                self._make_input("prompt", [self.prompt] * self.batch_size),
                self._make_input("negative_prompt", self.negative_prompts),
                self._make_input("seed", self.seed),
            ]
            ctx = self.grpc_client.async_infer(
                model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback
            )
            generate_data = None
            for _attempt in range(self.MAX_POLLS):
                generate_data, _raw = self.read_tasks_status()
                status = generate_data['status']
                if status in ("REVOKED", "FAILURE"):
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                time.sleep(self.POLL_INTERVAL_SECONDS)
            return generate_data
        except Exception as e:
            # Same exception type as before; chaining keeps the original traceback.
            raise Exception(str(e)) from e
        finally:
            _parsed, str_generate_data = self.read_tasks_status()
            if not DEBUG:
                publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):
    """Mark the task *tasks_id* as revoked in Redis.

    The polling loop of a running generation treats a ``REVOKED`` status as a
    request to cancel the in-flight inference.

    Args:
        tasks_id: Redis key of the task to revoke.

    Returns:
        The status dict that was stored.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # A plain SET drops the key's existing TTL, which would leave revoked
    # records in Redis forever; re-apply the 600-second lifetime task records
    # are created with.
    redis_client.set(tasks_id, json.dumps(data), ex=600)
    return data
if __name__ == '__main__':
    # Manual smoke run: build a sample request, run it through the generator
    # (talks to the configured Redis and Triton endpoints) and print the
    # final status dict.
    demo_request = GenerateSingleLogoImageModel(tasks_id="123-89", prompt='an apple', seed="2")
    outcome = GenerateSingleLogoImage(demo_request).get_result()
    print(outcome)