feat(新功能):
fix(修复bug): 图片生成服务优化,避免mq连接超时 docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -1,19 +0,0 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.generate_image.agent_generate import GenerateImage
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.get("/agent_generate_image")
|
||||
def generate_image(prompt: str):
|
||||
server = GenerateImage()
|
||||
byte_stream = server.get_result(prompt)
|
||||
# 返回流式响应
|
||||
return StreamingResponse(byte_stream, media_type="image/png")
|
||||
@@ -1,6 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import api_agent_generate_image
|
||||
from app.api import api_attribute_retrieve, api_query_image
|
||||
from app.api import api_brand_dna
|
||||
from app.api import api_brighten
|
||||
@@ -34,7 +33,6 @@ router.include_router(api_query_image.router, tags=['api_query_image'], prefix="
|
||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
|
||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES
|
||||
from app.schemas.response_template import ResponseModel
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -14,10 +14,12 @@ router = APIRouter()
|
||||
@router.get("{id}")
|
||||
def test(id: int):
|
||||
data = {
|
||||
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
|
||||
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
|
||||
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
"to product image GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||
"relight GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||
"local_oss_server": OSS
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class GenerateImage:
|
||||
def __init__(self):
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||
self.batch_size = 1
|
||||
self.mode = 'txt2img'
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
def get_result(self, prompt):
|
||||
prompts = [prompt] * self.batch_size
|
||||
modes = [self.mode] * self.batch_size
|
||||
images = [self.image.astype(np.float16)] * self.batch_size
|
||||
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_mode.set_data_from_numpy(mode_obj)
|
||||
|
||||
inputs = [input_text, input_image, input_mode]
|
||||
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
|
||||
image = result.as_numpy("generated_image")
|
||||
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||
_, img_byte_array = cv2.imencode('.jpg', image_result)
|
||||
byte_stream = io.BytesIO(img_byte_array)
|
||||
byte_stream.seek(0)
|
||||
|
||||
# object_name = f'test.jpg'
|
||||
# req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array)
|
||||
# url = self.minio_client.get_presigned_url(
|
||||
# "GET",
|
||||
# "test",
|
||||
# object_name,
|
||||
# expires=timedelta(hours=2),
|
||||
# )
|
||||
return byte_stream
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
server = GenerateImage()
|
||||
print(server.get_result("rabbit"))
|
||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateImageModel
|
||||
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -29,12 +30,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.version = request_data.version
|
||||
if request_data.version == "fast":
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
|
||||
@@ -161,7 +156,6 @@ class GenerateImage:
|
||||
generate_data = None
|
||||
while time_out > 0:
|
||||
generate_data, _ = self.read_tasks_status()
|
||||
# logger.info(generate_data)
|
||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
break
|
||||
@@ -169,7 +163,6 @@ class GenerateImage:
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(0.1)
|
||||
# logger.info(time_out, generate_data)
|
||||
return generate_data
|
||||
except Exception as e:
|
||||
self.generate_data['status'] = "FAILURE"
|
||||
@@ -178,11 +171,8 @@ class GenerateImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
self.connection.close()
|
||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
@@ -17,6 +17,7 @@ import tritonclient.grpc as grpcclient
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateMultiViewModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -25,14 +26,7 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateMultiView:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
|
||||
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.image = self.get_image(request_data.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
@@ -52,16 +46,11 @@ class GenerateMultiView:
|
||||
if error:
|
||||
self.generate_data['status'] = "FAILURE"
|
||||
self.generate_data['message'] = str(error)
|
||||
# self.generate_data['data'] = str(error)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||
else:
|
||||
# pil图像转成numpy数组
|
||||
images = result.as_numpy("generated_image")
|
||||
# for id, img in enumerate(images):
|
||||
# cv2.imwrite(f"{id}.png", img)
|
||||
# image_url = ""
|
||||
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
|
||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||
self.generate_data['status'] = "SUCCESS"
|
||||
self.generate_data['message'] = "success"
|
||||
self.generate_data['image_url'] = str(image_url)
|
||||
@@ -103,11 +92,8 @@ class GenerateMultiView:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
self.connection.close()
|
||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_generate_data, GMV_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
@@ -212,6 +212,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateProductImageModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -220,12 +221,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateProductImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "product_image"
|
||||
@@ -318,10 +313,8 @@ class GenerateProductImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||
self.connection.close()
|
||||
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_gen_product_data, GPI_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
@@ -20,6 +20,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateRelightImageModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -28,10 +29,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateRelightImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "relight_image"
|
||||
@@ -137,10 +134,8 @@ class GenerateRelightImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||
self.connection.close()
|
||||
logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
from app.core.config import *
|
||||
import tritonclient.grpc as grpcclient
|
||||
from app.schemas.generate_image import GenerateSingleLogoImageModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -28,10 +29,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateSingleLogoImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.batch_size = 1
|
||||
@@ -96,10 +93,8 @@ class GenerateSingleLogoImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
self.connection.close()
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.pose_transform import PoseTransformModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -114,23 +115,12 @@ class PoseTransformService:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
publish_status(str_pose_transform_data)
|
||||
if not DEBUG:
|
||||
publish_status(json.dumps(str_pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
|
||||
|
||||
|
||||
def publish_status(message):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=PS_RABBITMQ_QUEUES, durable=True)
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=PS_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
))
|
||||
connection.close()
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
23
app/service/generate_image/utils/mq.py
Normal file
23
app/service/generate_image/utils/mq.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import json
|
||||
|
||||
import pika
|
||||
import logging
|
||||
|
||||
from app.core.config import RABBITMQ_PARAMS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def publish_status(message, queue_name):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=queue_name, durable=True)
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=queue_name,
|
||||
body=message,
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
))
|
||||
connection.close()
|
||||
|
||||
logger.info(f" [x] Queue : {queue_name} | Sent message : {json.dumps(json.loads(message), indent=4)}")
|
||||
Reference in New Issue
Block a user