diff --git a/Dockerfile b/Dockerfile index c577312..0bd3e74 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,6 +6,7 @@ RUN apt install -y libgl1-mesa-glx COPY ./requirements.txt /requirements.txt RUN pip install --upgrade pip RUN pip install -r requirements.txt +RUN mkdir -p app/logs RUN pip install gunicorn RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 RUN #pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html @@ -19,4 +20,4 @@ LABEL maintainer="zchengrong@yeah.net" \ name="trinity_aida" -CMD ["gunicorn", "-c", "gunicorn_config.py", "app.main:app" , "-e", "SR_RABBITMQ_QUEUES=SuperResolution-dev" ,"-e", "GI_RABBITMQ_QUEUES=GenerateImage-dev"] \ No newline at end of file +CMD ["gunicorn", "-c", "gunicorn_config.py", "app.main:app" , "-e", "SR_RABBITMQ_QUEUES=SuperResolution-local" ,"-e", "GI_RABBITMQ_QUEUES=GenerateImage-local"] \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py index 23ff405..4889db8 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -28,8 +28,8 @@ else: CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" # RABBITMQ_ENV = "" # 生产环境 -RABBITMQ_ENV = "-dev" # 开发环境 -# RABBITMQ_ENV = "-local" # 本地测试环境 +# RABBITMQ_ENV = "-dev" # 开发环境 +RABBITMQ_ENV = "-local" # 本地测试环境 settings = Settings() diff --git a/app/service/attribute/service_att_recognition.py b/app/service/attribute/service_att_recognition.py index d307474..da71c16 100644 --- a/app/service/attribute/service_att_recognition.py +++ b/app/service/attribute/service_att_recognition.py @@ -30,9 +30,6 @@ class AttributeRecognition: self.const = const self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}") - def __del__(self): - self.triton_client.close() - def get_result(self): for sketch in self.request_data: if sketch['category'] == "Tops" or sketch['category'] == "Blouse": diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index fed3d41..46c96a9 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -10,7 +10,10 @@ import json import logging import time +from io import BytesIO +import cv2 +import minio import redis import tritonclient.grpc as grpcclient import numpy as np @@ -20,7 +23,6 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel from app.service.generate_image.utils.upload_sd_image import upload_png_sd -from app.service.utils.generate_uuid import generate_uuid logger = logging.getLogger() @@ -32,26 +34,36 @@ class GenerateImage: self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - if request_data.mode == "txt2img": - self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + if request_data.mode == "img2img": + self.image = self.get_image(request_data.image_url) + self.prompt = request_data.prompt else: - self.image = request_data.image + self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + self.prompt = request_data.prompt + self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] - self.prompt = request_data.prompt self.mode = request_data.mode self.batch_size = 1 self.category = request_data.category self.index = 0 - def __del__(self): - self.redis_client.close() - self.grpc_client.close() - self.connection.close() + def get_image(self, image_url): + # Get data of an object. + # Read data from response. + try: + response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + image_file = BytesIO(response.data) + image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + except minio.error.S3Error: + image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + return image_cv2 def __call__(self, *args, **kwargs): self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}) self.redis_client.set(self.tasks_id, self.generate_data) + self.redis_client.expire(self.tasks_id, 600) def callback(self, result, error): if error: diff --git a/app/service/generate_image/test.py b/app/service/generate_image/test.py index 0e03900..ab2dc43 100644 --- a/app/service/generate_image/test.py +++ b/app/service/generate_image/test.py @@ -64,11 +64,6 @@ class GenerateImage: pass - def __del__(self): - self.redis_client.close() - self.triton_client.close() - self.connection.close() - @staticmethod def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index e20eb70..95b2811 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -29,11 +29,6 @@ class SuperResolution: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - def __del__(self): - self.redis_client.close() - self.triton_client.close() - self.connection.close() - # @RunTime def read_image(self): try: diff --git a/requirements.txt b/requirements.txt index b77dd7d..1529082 100644 Binary files a/requirements.txt and b/requirements.txt differ