Merge branch 'local'

# Conflicts:
#	Dockerfile
#	app/core/config.py
This commit is contained in:
zhouchengrong
2024-04-22 14:32:18 +08:00
9 changed files with 78 additions and 97 deletions

4
.gitignore vendored
View File

@@ -125,7 +125,9 @@ seg_result/
seg_result seg_result
*.png *.png
uwsgi uwsgi
#*.yaml *.yaml
*.yml
Dockerfile
.conf .conf
app/logs app/logs

View File

@@ -1,22 +0,0 @@
FROM python:3.9
ENV TZ=Asia/Shanghai
RUN apt update
RUN apt install -y vim
RUN apt install -y libgl1-mesa-glx
COPY ./requirements.txt /requirements.txt
RUN pip install --upgrade pip
RUN pip install -r requirements.txt
RUN pip install gunicorn
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
RUN #pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
WORKDIR /app
COPY . .
ENV FLASK_APP=manage.py
LABEL maintainer="zchengrong@yeah.net" \
description="My Python 3.9 - trinity aida " \
version="1.0" \
name="trinity_aida"
CMD ["gunicorn", "-c", "gunicorn_config.py", "app.main:app" , "-e", "SR_RABBITMQ_QUEUES=SuperResolution" ,"-e", "GI_RABBITMQ_QUEUES=GenerateImage"]

View File

@@ -27,9 +27,9 @@ else:
LOGS_PATH = "app/logs/" LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
RABBITMQ_ENV = "" # 生产环境 # RABBITMQ_ENV = "" # 生产环境
# RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-dev" # 开发环境
# RABBITMQ_ENV = "-local" # 本地测试环境 RABBITMQ_ENV = "-local" # 本地测试环境
settings = Settings() settings = Settings()

View File

@@ -30,9 +30,6 @@ class AttributeRecognition:
self.const = const self.const = const
self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}") self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}")
def __del__(self):
self.triton_client.close()
def get_result(self): def get_result(self):
for sketch in self.request_data: for sketch in self.request_data:
if sketch['category'] == "Tops" or sketch['category'] == "Blouse": if sketch['category'] == "Tops" or sketch['category'] == "Blouse":

View File

@@ -10,7 +10,10 @@
import json import json
import logging import logging
import time import time
from io import BytesIO
import cv2
import minio
import redis import redis
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
import numpy as np import numpy as np
@@ -20,7 +23,6 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger() logger = logging.getLogger()
@@ -32,40 +34,52 @@ class GenerateImage:
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel() self.channel = self.connection.channel()
if request_data.mode == "txt2img": if request_data.mode == "img2img":
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else: else:
self.image = request_data.image self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
self.tasks_id = request_data.tasks_id self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.prompt = request_data.prompt
self.mode = request_data.mode self.mode = request_data.mode
self.batch_size = 1 self.batch_size = 1
self.category = request_data.category self.category = request_data.category
self.index = 0 self.index = 0
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
self.redis_client.expire(self.tasks_id, 600)
def __del__(self): def get_image(self, image_url):
self.redis_client.close() # Get data of an object.
self.grpc_client.close() # Read data from response.
self.connection.close() try:
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
def __call__(self, *args, **kwargs): image_file = BytesIO(response.data)
self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}) image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
self.redis_client.set(self.tasks_id, self.generate_data) image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
except minio.error.S3Error:
image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image_cv2
def callback(self, result, error): def callback(self, result, error):
if error: if error:
generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"}) self.generate_data['status'] = "FAILURE"
self.redis_client.set(self.tasks_id, generate_data) self.generate_data['message'] = str(error)
self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else: else:
image_result = result.as_numpy("generated_image")[0] image_result = result.as_numpy("generated_image")[0]
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{image_url}'}) self.generate_data['status'] = "SUCCESS"
self.redis_client.set(self.tasks_id, generate_data) self.generate_data['message'] = "success"
self.generate_data['data'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
def read_tasks_status(self): def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id)) status_data = self.redis_client.get(self.tasks_id)
return status_data return json.loads(status_data), status_data
def infer(self, inputs): def infer(self, inputs):
return self.grpc_client.async_infer( return self.grpc_client.async_infer(
@@ -75,45 +89,53 @@ class GenerateImage:
) )
def get_result(self): def get_result(self):
prompts = [self.prompt] * self.batch_size try:
modes = [self.mode] * self.batch_size prompts = [self.prompt] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size modes = [self.mode] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj) input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj) input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj) input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode] inputs = [input_text, input_image, input_mode]
ctx = self.infer(inputs) ctx = self.infer(inputs)
time_out = 60 time_out = 60
while time_out > 0: generate_data = None
generate_data = self.read_tasks_status() while time_out > 0:
if generate_data['status'] in ["REVOKED", "FAILURE"]: generate_data, _ = self.read_tasks_status()
ctx.cancel() if generate_data['status'] in ["REVOKED", "FAILURE"]:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data)) ctx.cancel()
logger.info(f" [x] Sent {json.dumps(generate_data, indent=4)}") break
break elif generate_data['status'] == "SUCCESS":
elif generate_data['status'] == "SUCCESS": break
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data)) time_out -= 1
logger.info(f" [x] Sent {json.dumps(generate_data, indent=4)}") time.sleep(0.1)
break return generate_data
time_out -= 1 except Exception as e:
time.sleep(0.1) self.generate_data['status'] = "FAILURE"
return self.read_tasks_status() self.generate_data['message'] = "failure"
self.generate_data['data'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id): def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}) generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data) redis_client.set(tasks_id, generate_data)
return data return data

View File

@@ -64,11 +64,6 @@ class GenerateImage:
pass pass
def __del__(self):
self.redis_client.close()
self.triton_client.close()
self.connection.close()
@staticmethod @staticmethod
def image_grid(imgs, rows, cols): def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols assert len(imgs) == rows * cols

View File

@@ -26,14 +26,10 @@ class SuperResolution:
self.sr_xn = data.sr_xn self.sr_xn = data.sr_xn
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})) self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
self.redis_client.expire(self.tasks_id, 600)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel() self.channel = self.connection.channel()
def __del__(self):
self.redis_client.close()
self.triton_client.close()
self.connection.close()
# @RunTime # @RunTime
def read_image(self): def read_image(self):
try: try:

View File

@@ -1,9 +0,0 @@
version: "3"
services:
trinity_aida_local:
build: .
container_name: trinity_aida_local
volumes:
- ./trinity_client_aida:/trinity
ports:
- "10201:4562"

Binary file not shown.