diff --git a/app/core/config.py b/app/core/config.py index a114fa6..73e6e67 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -29,7 +29,7 @@ else: # RABBITMQ_ENV = "" # 生产环境 # RABBITMQ_ENV = "-dev" # 开发环境 -RABBITMQ_ENV = "-local" # 本地测试环境 +RABBITMQ_ENV = "-local" # 本地测试环境 settings = Settings() diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index fed3d41..b1448b2 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -10,7 +10,10 @@ import json import logging import time +from io import BytesIO +import cv2 +import minio import redis import tritonclient.grpc as grpcclient import numpy as np @@ -20,7 +23,6 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel from app.service.generate_image.utils.upload_sd_image import upload_png_sd -from app.service.utils.generate_uuid import generate_uuid logger = logging.getLogger() @@ -32,13 +34,15 @@ class GenerateImage: self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - if request_data.mode == "txt2img": - self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + if request_data.mode == "img2img": + self.image = self.get_image(request_data.image_url) + self.prompt = request_data.prompt else: - self.image = request_data.image + self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + self.prompt = request_data.prompt + self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] - self.prompt = request_data.prompt self.mode = request_data.mode self.batch_size = 1 self.category = request_data.category @@ -49,9 +53,22 @@ class GenerateImage: self.grpc_client.close() self.connection.close() + def get_image(self, image_url): + # Get data of an object. + # Read data from response. + try: + response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + image_file = BytesIO(response.data) + image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + except minio.error.S3Error: + image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + return image_cv2 + def __call__(self, *args, **kwargs): self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}) self.redis_client.set(self.tasks_id, self.generate_data) + self.redis_client.expire(self.tasks_id, 600) def callback(self, result, error): if error: