Merge remote-tracking branch 'origin/local' into develop

This commit is contained in:
zhouchengrong
2024-04-17 09:40:56 +08:00
7 changed files with 25 additions and 25 deletions

View File

@@ -6,6 +6,7 @@ RUN apt install -y libgl1-mesa-glx
COPY ./requirements.txt /requirements.txt COPY ./requirements.txt /requirements.txt
RUN pip install --upgrade pip RUN pip install --upgrade pip
RUN pip install -r requirements.txt RUN pip install -r requirements.txt
RUN mkdir -p app/logs
RUN pip install gunicorn RUN pip install gunicorn
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
RUN #pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html RUN #pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
@@ -19,4 +20,4 @@ LABEL maintainer="zchengrong@yeah.net" \
name="trinity_aida" name="trinity_aida"
CMD ["gunicorn", "-c", "gunicorn_config.py", "app.main:app" , "-e", "SR_RABBITMQ_QUEUES=SuperResolution-dev" ,"-e", "GI_RABBITMQ_QUEUES=GenerateImage-dev"] CMD ["gunicorn", "-c", "gunicorn_config.py", "app.main:app" , "-e", "SR_RABBITMQ_QUEUES=SuperResolution-local" ,"-e", "GI_RABBITMQ_QUEUES=GenerateImage-local"]

View File

@@ -28,8 +28,8 @@ else:
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
# RABBITMQ_ENV = "" # 生产环境 # RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-dev" # 开发环境
# RABBITMQ_ENV = "-local" # 本地测试环境 RABBITMQ_ENV = "-local" # 本地测试环境
settings = Settings() settings = Settings()

View File

@@ -30,9 +30,6 @@ class AttributeRecognition:
self.const = const self.const = const
self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}") self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}")
def __del__(self):
self.triton_client.close()
def get_result(self): def get_result(self):
for sketch in self.request_data: for sketch in self.request_data:
if sketch['category'] == "Tops" or sketch['category'] == "Blouse": if sketch['category'] == "Tops" or sketch['category'] == "Blouse":

View File

@@ -10,7 +10,10 @@
import json import json
import logging import logging
import time import time
from io import BytesIO
import cv2
import minio
import redis import redis
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
import numpy as np import numpy as np
@@ -20,7 +23,6 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger() logger = logging.getLogger()
@@ -32,26 +34,36 @@ class GenerateImage:
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel() self.channel = self.connection.channel()
if request_data.mode == "txt2img": if request_data.mode == "img2img":
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else: else:
self.image = request_data.image self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
self.tasks_id = request_data.tasks_id self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.prompt = request_data.prompt
self.mode = request_data.mode self.mode = request_data.mode
self.batch_size = 1 self.batch_size = 1
self.category = request_data.category self.category = request_data.category
self.index = 0 self.index = 0
def __del__(self): def get_image(self, image_url):
self.redis_client.close() # Get data of an object.
self.grpc_client.close() # Read data from response.
self.connection.close() try:
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
except minio.error.S3Error:
image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image_cv2
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}) self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
self.redis_client.set(self.tasks_id, self.generate_data) self.redis_client.set(self.tasks_id, self.generate_data)
self.redis_client.expire(self.tasks_id, 600)
def callback(self, result, error): def callback(self, result, error):
if error: if error:

View File

@@ -64,11 +64,6 @@ class GenerateImage:
pass pass
def __del__(self):
self.redis_client.close()
self.triton_client.close()
self.connection.close()
@staticmethod @staticmethod
def image_grid(imgs, rows, cols): def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols assert len(imgs) == rows * cols

View File

@@ -29,11 +29,6 @@ class SuperResolution:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel() self.channel = self.connection.channel()
def __del__(self):
self.redis_client.close()
self.triton_client.close()
self.connection.close()
# @RunTime # @RunTime
def read_image(self): def read_image(self):
try: try:

Binary file not shown.