feat(新功能):

fix(修复bug):  图片生成服务优化,避免mq连接超时
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-06-24 16:58:05 +08:00
parent c77540678b
commit 6203dde267
13 changed files with 1259 additions and 167 deletions

View File

@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
@@ -29,12 +30,6 @@ logger = logging.getLogger()
class GenerateImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.version = request_data.version
if request_data.version == "fast":
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
@@ -161,7 +156,6 @@ class GenerateImage:
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
@@ -169,7 +163,6 @@ class GenerateImage:
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
self.generate_data['status'] = "FAILURE"
@@ -178,11 +171,8 @@ class GenerateImage:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
self.connection.close()
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
if not DEBUG:
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):