feat(新功能):
fix(修复bug): 图片生成服务优化,避免mq连接超时 docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -1,19 +0,0 @@
|
|||||||
import io
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from starlette.responses import StreamingResponse
|
|
||||||
|
|
||||||
from app.schemas.response_template import ResponseModel
|
|
||||||
from app.service.generate_image.agent_generate import GenerateImage
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
logger = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/agent_generate_image")
|
|
||||||
def generate_image(prompt: str):
|
|
||||||
server = GenerateImage()
|
|
||||||
byte_stream = server.get_result(prompt)
|
|
||||||
# 返回流式响应
|
|
||||||
return StreamingResponse(byte_stream, media_type="image/png")
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api import api_agent_generate_image
|
|
||||||
from app.api import api_attribute_retrieve, api_query_image
|
from app.api import api_attribute_retrieve, api_query_image
|
||||||
from app.api import api_brand_dna
|
from app.api import api_brand_dna
|
||||||
from app.api import api_brighten
|
from app.api import api_brighten
|
||||||
@@ -34,7 +33,6 @@ router.include_router(api_query_image.router, tags=['api_query_image'], prefix="
|
|||||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||||
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
|
|
||||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import logging
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
|
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -14,10 +14,12 @@ router = APIRouter()
|
|||||||
@router.get("{id}")
|
@router.get("{id}")
|
||||||
def test(id: int):
|
def test(id: int):
|
||||||
data = {
|
data = {
|
||||||
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
|
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||||
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||||
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
|
||||||
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||||
|
"to product image GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||||
|
"relight GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||||
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||||
"local_oss_server": OSS
|
"local_oss_server": OSS
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: UTF-8 -*-
|
|
||||||
"""
|
|
||||||
@Project :trinity_client
|
|
||||||
@File :service_att_recognition.py
|
|
||||||
@Author :周成融
|
|
||||||
@Date :2023/7/26 12:01:05
|
|
||||||
@detail :
|
|
||||||
"""
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import tritonclient.grpc as grpcclient
|
|
||||||
from minio import Minio
|
|
||||||
from tritonclient.utils import np_to_triton_dtype
|
|
||||||
|
|
||||||
from app.core.config import *
|
|
||||||
from app.service.utils.oss_client import oss_upload_image
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateImage:
|
|
||||||
def __init__(self):
|
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
|
||||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
|
||||||
self.batch_size = 1
|
|
||||||
self.mode = 'txt2img'
|
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
|
||||||
|
|
||||||
def get_result(self, prompt):
|
|
||||||
prompts = [prompt] * self.batch_size
|
|
||||||
modes = [self.mode] * self.batch_size
|
|
||||||
images = [self.image.astype(np.float16)] * self.batch_size
|
|
||||||
|
|
||||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
|
||||||
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
|
||||||
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
|
||||||
|
|
||||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
|
||||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
|
||||||
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
|
||||||
|
|
||||||
input_text.set_data_from_numpy(text_obj)
|
|
||||||
input_image.set_data_from_numpy(image_obj)
|
|
||||||
input_mode.set_data_from_numpy(mode_obj)
|
|
||||||
|
|
||||||
inputs = [input_text, input_image, input_mode]
|
|
||||||
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
|
|
||||||
image = result.as_numpy("generated_image")
|
|
||||||
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
|
||||||
_, img_byte_array = cv2.imencode('.jpg', image_result)
|
|
||||||
byte_stream = io.BytesIO(img_byte_array)
|
|
||||||
byte_stream.seek(0)
|
|
||||||
|
|
||||||
# object_name = f'test.jpg'
|
|
||||||
# req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array)
|
|
||||||
# url = self.minio_client.get_presigned_url(
|
|
||||||
# "GET",
|
|
||||||
# "test",
|
|
||||||
# object_name,
|
|
||||||
# expires=timedelta(hours=2),
|
|
||||||
# )
|
|
||||||
return byte_stream
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
server = GenerateImage()
|
|
||||||
print(server.get_result("rabbit"))
|
|
||||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
|||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateImageModel
|
from app.schemas.generate_image import GenerateImageModel
|
||||||
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
|
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||||
from app.service.utils.oss_client import oss_get_image
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
@@ -29,12 +30,6 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
class GenerateImage:
|
class GenerateImage:
|
||||||
def __init__(self, request_data):
|
def __init__(self, request_data):
|
||||||
if DEBUG is False:
|
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
self.channel = self.connection.channel()
|
|
||||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
# self.channel = self.connection.channel()
|
|
||||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
|
||||||
self.version = request_data.version
|
self.version = request_data.version
|
||||||
if request_data.version == "fast":
|
if request_data.version == "fast":
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
|
||||||
@@ -161,7 +156,6 @@ class GenerateImage:
|
|||||||
generate_data = None
|
generate_data = None
|
||||||
while time_out > 0:
|
while time_out > 0:
|
||||||
generate_data, _ = self.read_tasks_status()
|
generate_data, _ = self.read_tasks_status()
|
||||||
# logger.info(generate_data)
|
|
||||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
ctx.cancel()
|
ctx.cancel()
|
||||||
break
|
break
|
||||||
@@ -169,7 +163,6 @@ class GenerateImage:
|
|||||||
break
|
break
|
||||||
time_out -= 1
|
time_out -= 1
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
# logger.info(time_out, generate_data)
|
|
||||||
return generate_data
|
return generate_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.generate_data['status'] = "FAILURE"
|
self.generate_data['status'] = "FAILURE"
|
||||||
@@ -178,11 +171,8 @@ class GenerateImage:
|
|||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
finally:
|
finally:
|
||||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if not DEBUG:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
|
||||||
self.connection.close()
|
|
||||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
|
||||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import tritonclient.grpc as grpcclient
|
|||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateMultiViewModel
|
from app.schemas.generate_image import GenerateMultiViewModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||||
from app.service.utils.oss_client import oss_get_image
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
@@ -25,14 +26,7 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
class GenerateMultiView:
|
class GenerateMultiView:
|
||||||
def __init__(self, request_data):
|
def __init__(self, request_data):
|
||||||
if DEBUG is False:
|
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
self.channel = self.connection.channel()
|
|
||||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
# self.channel = self.connection.channel()
|
|
||||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
|
||||||
|
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.image = self.get_image(request_data.image_url)
|
self.image = self.get_image(request_data.image_url)
|
||||||
self.tasks_id = request_data.tasks_id
|
self.tasks_id = request_data.tasks_id
|
||||||
@@ -52,16 +46,11 @@ class GenerateMultiView:
|
|||||||
if error:
|
if error:
|
||||||
self.generate_data['status'] = "FAILURE"
|
self.generate_data['status'] = "FAILURE"
|
||||||
self.generate_data['message'] = str(error)
|
self.generate_data['message'] = str(error)
|
||||||
# self.generate_data['data'] = str(error)
|
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
else:
|
else:
|
||||||
# pil图像转成numpy数组
|
# pil图像转成numpy数组
|
||||||
images = result.as_numpy("generated_image")
|
images = result.as_numpy("generated_image")
|
||||||
# for id, img in enumerate(images):
|
|
||||||
# cv2.imwrite(f"{id}.png", img)
|
|
||||||
# image_url = ""
|
|
||||||
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
|
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
|
||||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
|
||||||
self.generate_data['status'] = "SUCCESS"
|
self.generate_data['status'] = "SUCCESS"
|
||||||
self.generate_data['message'] = "success"
|
self.generate_data['message'] = "success"
|
||||||
self.generate_data['image_url'] = str(image_url)
|
self.generate_data['image_url'] = str(image_url)
|
||||||
@@ -103,11 +92,8 @@ class GenerateMultiView:
|
|||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
finally:
|
finally:
|
||||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if not DEBUG:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
|
publish_status(str_generate_data, GMV_RABBITMQ_QUEUES)
|
||||||
self.connection.close()
|
|
||||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
|
||||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ from tritonclient.utils import np_to_triton_dtype
|
|||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateProductImageModel
|
from app.schemas.generate_image import GenerateProductImageModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
from app.service.utils.oss_client import oss_get_image
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
@@ -220,12 +221,6 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
class GenerateProductImage:
|
class GenerateProductImage:
|
||||||
def __init__(self, request_data):
|
def __init__(self, request_data):
|
||||||
if DEBUG is False:
|
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
self.channel = self.connection.channel()
|
|
||||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
# self.channel = self.connection.channel()
|
|
||||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.category = "product_image"
|
self.category = "product_image"
|
||||||
@@ -318,10 +313,8 @@ class GenerateProductImage:
|
|||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
finally:
|
finally:
|
||||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if not DEBUG:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
publish_status(str_gen_product_data, GPI_RABBITMQ_QUEUES)
|
||||||
self.connection.close()
|
|
||||||
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from tritonclient.utils import np_to_triton_dtype
|
|||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateRelightImageModel
|
from app.schemas.generate_image import GenerateRelightImageModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
from app.service.utils.oss_client import oss_get_image
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
@@ -28,10 +29,6 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
class GenerateRelightImage:
|
class GenerateRelightImage:
|
||||||
def __init__(self, request_data):
|
def __init__(self, request_data):
|
||||||
if DEBUG is False:
|
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
self.channel = self.connection.channel()
|
|
||||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.category = "relight_image"
|
self.category = "relight_image"
|
||||||
@@ -137,10 +134,8 @@ class GenerateRelightImage:
|
|||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
finally:
|
finally:
|
||||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if not DEBUG:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
|
||||||
self.connection.close()
|
|
||||||
logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
|||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
import tritonclient.grpc as grpcclient
|
import tritonclient.grpc as grpcclient
|
||||||
from app.schemas.generate_image import GenerateSingleLogoImageModel
|
from app.schemas.generate_image import GenerateSingleLogoImageModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -28,10 +29,6 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
class GenerateSingleLogoImage:
|
class GenerateSingleLogoImage:
|
||||||
def __init__(self, request_data):
|
def __init__(self, request_data):
|
||||||
if DEBUG is False:
|
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
self.channel = self.connection.channel()
|
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
@@ -96,10 +93,8 @@ class GenerateSingleLogoImage:
|
|||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
finally:
|
finally:
|
||||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if not DEBUG:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
|
||||||
self.connection.close()
|
|
||||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
|||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.pose_transform import PoseTransformModel
|
from app.schemas.pose_transform import PoseTransformModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
|
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
|
||||||
from app.service.utils.oss_client import oss_get_image
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
@@ -114,23 +115,12 @@ class PoseTransformService:
|
|||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
finally:
|
finally:
|
||||||
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
|
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if not DEBUG:
|
||||||
publish_status(str_pose_transform_data)
|
publish_status(json.dumps(str_pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||||
logger.info(
|
logger.info(
|
||||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
|
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
|
||||||
|
|
||||||
|
|
||||||
def publish_status(message):
|
|
||||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
channel = connection.channel()
|
|
||||||
channel.queue_declare(queue=PS_RABBITMQ_QUEUES, durable=True)
|
|
||||||
channel.basic_publish(exchange='',
|
|
||||||
routing_key=PS_RABBITMQ_QUEUES,
|
|
||||||
body=json.dumps(message),
|
|
||||||
properties=pika.BasicProperties(
|
|
||||||
delivery_mode=2,
|
|
||||||
))
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
|
|||||||
23
app/service/generate_image/utils/mq.py
Normal file
23
app/service/generate_image/utils/mq.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import pika
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.core.config import RABBITMQ_PARAMS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def publish_status(message, queue_name):
|
||||||
|
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
channel = connection.channel()
|
||||||
|
channel.queue_declare(queue=queue_name, durable=True)
|
||||||
|
channel.basic_publish(exchange='',
|
||||||
|
routing_key=queue_name,
|
||||||
|
body=message,
|
||||||
|
properties=pika.BasicProperties(
|
||||||
|
delivery_mode=2,
|
||||||
|
))
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
logger.info(f" [x] Queue : {queue_name} | Sent message : {json.dumps(json.loads(message), indent=4)}")
|
||||||
18
pyproject.toml
Executable file
18
pyproject.toml
Executable file
@@ -0,0 +1,18 @@
|
|||||||
|
[project]
|
||||||
|
name = "trinity-client-aida"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"apscheduler>=3.11.0",
|
||||||
|
"celery>=5.5.3",
|
||||||
|
"geventhttpclient>=2.3.4",
|
||||||
|
"google-search-results>=2.4.2",
|
||||||
|
"moviepy>=2.2.1",
|
||||||
|
"numpy==1.26.4",
|
||||||
|
"pandas-stubs==2.2.3.250527",
|
||||||
|
"pika-stubs==0.1.3",
|
||||||
|
"python-multipart>=0.0.20",
|
||||||
|
"tritonclient[all]>=2.58.0",
|
||||||
|
"types-urllib3==1.26.25.14",
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user