feat generate 逻辑调整

This commit is contained in:
zhouchengrong
2024-04-16 16:36:17 +08:00
parent b596692b35
commit 6759b873d5
2 changed files with 24 additions and 7 deletions

View File

@@ -19,7 +19,7 @@ class Settings(BaseSettings):
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
DEBUG = False DEBUG = True
if DEBUG: if DEBUG:
LOGS_PATH = "logs/" LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
@@ -29,7 +29,7 @@ else:
# RABBITMQ_ENV = "" # 生产环境 # RABBITMQ_ENV = "" # 生产环境
# RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-dev" # 开发环境
RABBITMQ_ENV = "-local" # 本地测试环境 RABBITMQ_ENV = "-local" # 本地测试环境
settings = Settings() settings = Settings()

View File

@@ -10,7 +10,10 @@
import json import json
import logging import logging
import time import time
from io import BytesIO
import cv2
import minio
import redis import redis
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
import numpy as np import numpy as np
@@ -20,7 +23,6 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger() logger = logging.getLogger()
@@ -32,13 +34,15 @@ class GenerateImage:
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel() self.channel = self.connection.channel()
if request_data.mode == "txt2img": if request_data.mode == "img2img":
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else: else:
self.image = request_data.image self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
self.tasks_id = request_data.tasks_id self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.prompt = request_data.prompt
self.mode = request_data.mode self.mode = request_data.mode
self.batch_size = 1 self.batch_size = 1
self.category = request_data.category self.category = request_data.category
@@ -49,9 +53,22 @@ class GenerateImage:
self.grpc_client.close() self.grpc_client.close()
self.connection.close() self.connection.close()
def get_image(self, image_url):
# Get data of an object.
# Read data from response.
try:
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
except minio.error.S3Error:
image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image_cv2
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}) self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
self.redis_client.set(self.tasks_id, self.generate_data) self.redis_client.set(self.tasks_id, self.generate_data)
self.redis_client.expire(self.tasks_id, 600)
def callback(self, result, error): def callback(self, result, error):
if error: if error: