feat generate 迁移
This commit is contained in:
@@ -10,15 +10,10 @@ import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
import tritonclient.grpc as grpcclient
|
||||
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, RABBITMQ_QUEUES, SR_TRITON_URL
|
||||
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, SR_RABBITMQ_QUEUES, SR_TRITON_URL
|
||||
from app.schemas.super_resolution import SuperResolutionModel
|
||||
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
|
||||
@@ -27,7 +22,6 @@ logger = logging.getLogger()
|
||||
|
||||
class SuperResolution:
|
||||
def __init__(self, data):
|
||||
logger.info(f"sr triton service url is : {SR_TRITON_URL}")
|
||||
self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.tasks_id = data.sr_tasks_id
|
||||
@@ -39,6 +33,13 @@ class SuperResolution:
|
||||
secret_key=MINIO_SECRET,
|
||||
secure=MINIO_SECURE)
|
||||
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
|
||||
def __del__(self):
|
||||
self.redis_client.close()
|
||||
self.triton_client.close()
|
||||
self.connection.close()
|
||||
|
||||
@RunTime
|
||||
def read_image(self):
|
||||
@@ -46,7 +47,8 @@ class SuperResolution:
|
||||
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
|
||||
except minio.error.S3Error as e:
|
||||
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
|
||||
publish_message(sr_data)
|
||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
||||
logger.info(f" [x] Sent {sr_data}")
|
||||
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
|
||||
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
||||
@@ -82,10 +84,16 @@ class SuperResolution:
|
||||
)
|
||||
|
||||
ctx = self.infer(inputs)
|
||||
time_out = 120
|
||||
while self.read_tasks_status()['status'] == "PENDING" and time_out > 0:
|
||||
if self.read_tasks_status()['status'] == "REVOKED":
|
||||
time_out = 60
|
||||
while time_out > 0:
|
||||
generate_data = self.read_tasks_status()
|
||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data))
|
||||
logger.info(f" [x] Sent {generate_data}")
|
||||
break
|
||||
elif generate_data['status'] == "SUCCESS":
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(1)
|
||||
return self.read_tasks_status()
|
||||
@@ -123,7 +131,8 @@ class SuperResolution:
|
||||
output = (output * 255.0).round().astype(np.uint8)
|
||||
output_url = self.upload_img_sr(output, generate_uuid())
|
||||
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
|
||||
publish_message(sr_data)
|
||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
||||
logger.info(f" [x] Sent {sr_data}")
|
||||
self.redis_client.set(self.tasks_id, sr_data)
|
||||
|
||||
|
||||
@@ -131,20 +140,10 @@ def infer_cancel(tasks_id):
|
||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||
sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
|
||||
publish_message(sr_data)
|
||||
redis_client.set(tasks_id, sr_data)
|
||||
return data
|
||||
|
||||
|
||||
def publish_message(sr_data):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
# 发布消息,并设置回调函数
|
||||
channel.basic_publish(exchange='', routing_key=RABBITMQ_QUEUES, body=sr_data)
|
||||
logger.info(f" [x] Sent {sr_data}")
|
||||
connection.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
|
||||
service = SuperResolution(request_data)
|
||||
|
||||
Reference in New Issue
Block a user