feat generate 迁移

This commit is contained in:
zhouchengrong
2024-04-15 18:07:25 +08:00
parent b17a1768f8
commit f8493dbdb6
9 changed files with 476 additions and 63 deletions

View File

@@ -0,0 +1,27 @@
import logging
from fastapi import APIRouter, BackgroundTasks
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.service import GenerateImage, infer_cancel
router = APIRouter()
logger = logging.getLogger()
@router.post("/generate_image")
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
try:
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
code = 200
message = "access"
except Exception as e:
code = 400
message = e
logger.warning(e)
return {"code": code, "message": message}
@router.get("/generate_cancel/{tasks_id}>")
def generate_image(tasks_id):
result = infer_cancel(tasks_id)
return {"code": 200, "message": result['message'], "data": result['data']}

View File

@@ -2,8 +2,10 @@ from fastapi import APIRouter
from app.api import api_test from app.api import api_test
from app.api import api_super_resolution from app.api import api_super_resolution
from app.api import api_generate_image
router = APIRouter() router = APIRouter()
router.include_router(api_test.router, tags=["test"], prefix="/test") router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_super_resolution.router, tags=["api_super_resolution"], prefix="/api") router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")

View File

@@ -1,8 +1,6 @@
import logging import logging
from fastapi import APIRouter from fastapi import APIRouter
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES
from app.core.config import RABBITMQ_QUEUES
logger = logging.getLogger() logger = logging.getLogger()
router = APIRouter() router = APIRouter()
@@ -10,6 +8,6 @@ router = APIRouter()
@router.get("") @router.get("")
def test(): def test():
logger.info(RABBITMQ_QUEUES) logger.info(SR_RABBITMQ_QUEUES)
logger.info("test") logger.info("test")
return {"message": RABBITMQ_QUEUES} return {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES}

View File

@@ -19,59 +19,56 @@ class Settings(BaseSettings):
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
DEBUG = True
ENV = 0
if DEBUG:
LOGS_PATH = "logs/errors.log"
else:
LOGS_PATH = "app/logs/errors.log"
RABBITMQ_ENV = ""
if ENV == 1:
RABBITMQ_ENV = "dev"
elif ENV == 2:
RABBITMQ_ENV = "local"
settings = Settings() settings = Settings()
ckpt = 'service/super_resolution_ccsr/weights/real-world_ccsr.ckpt'
config = 'service/super_resolution_ccsr/configs/model/ccsr_stage2.yaml'
steps = 45
sr_scale = 4
repeat_times = 1
tiled = False
tile_size = 512
tile_stride = 256
color_fix_type = "adain"
t_max = 0.6667
t_min = 0.3333
show_lq = False
skip_if_exist = False
seed = 233
device = "cuda"
tile_diffusion = False #
tile_diffusion_size = 512
tile_diffusion_stride = 256
tile_vae = True
vae_decoder_tile_size = 224
vae_encoder_tile_size = 1024
strength = 1
# minio 配置 # minio 配置
sr_bucket = "test"
MINIO_IP = "www.minio.aida.com.hk" MINIO_IP = "www.minio.aida.com.hk"
MINIO_PORT = 9000 MINIO_PORT = 9000
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB' MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR' MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# redis 配置 # redis 配置
REDIS_HOST = "10.1.1.240" REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379" REDIS_PORT = "6379"
REDIS_DB = "2" REDIS_DB = "2"
MINIO_SECURE = True
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
# rabbitmq config # rabbitmq config
RABBITMQ_PARAMS = { RABBITMQ_PARAMS = {
"host": "18.167.251.121", "host": "18.167.251.121",
"port": 5672, "port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'), "credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/" "virtual_host": "/"
} }
RABBITMQ_QUEUES = os.getenv("RABBITMQ_QUEUES", "SuperResolution-local")
DEBUG = True # SR service config
if DEBUG: SR_MODEL_NAME = "super_resolution"
LOGS_PATH = "logs/errors.log" SR_TRITON_URL = "10.1.1.240:10031"
else: SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", "SuperResolution-local")
LOGS_PATH = "app/logs/errors.log"
# GenerateImage service config
GI_MODEL_NAME = '_stable_diffusion'
GI_MODEL_URL = '10.1.1.240:7001'
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage-{RABBITMQ_ENV}")
# SEG service config
SEG_MODEL_URL = '10.1.1.240:10000'
SEGMENTATION = {
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
class GenerateImageModel(BaseModel):
category: str
content: str
gender: str
image_url: str
mode: int
tasks_id: str
user_id: int
version: str

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import numpy as np
import random
import redis
import tritonclient
import tritonclient.grpc as grpc_client
from minio import Minio
import cv2
from PIL import Image
import time
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.remove_background import remove_background
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger()
class GenerateImage:
def __init__(self, request_data):
self.tasks_id = request_data.tasks_id
self.image_url = request_data.image_url
self.user_id = request_data.user_id
self.content = request_data.content
self.category = request_data.category
self.model_name = f"{self.category}{GI_MODEL_NAME}"
self.mode = request_data.mode
self.version = request_data.version
self.triton_client = grpc_client.InferenceServerClient(url=f"{GI_MODEL_URL}")
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(
f"{MINIO_IP}:{MINIO_PORT}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
self.samples = 4 # no.of images to generate
self.steps = 24
self.guidance_scale = 7
self.seed = random.randint(0, 2000000000)
self.batch_size = 1
self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
self.redis_client.set(self.tasks_id, self.generate_data)
def __del__(self):
self.redis_client.close()
self.triton_client.close()
self.connection.close()
@staticmethod
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
@staticmethod
def preprocess_image(image, category):
height, width, _ = image.shape
if category == "print" or category == "moodboard":
square_size = min(height, width)
start_x = (width - square_size) // 2
start_y = (height - square_size) // 2
cropped = image[start_y: start_y + square_size, start_x: start_x + square_size]
resized_image = cv2.resize(cropped, (512, 512))
elif category == "sketch":
# below is the way that get "bigger" square image.
max_dimension = max(height, width)
square_image = np.ones((max_dimension, max_dimension, 3), dtype=np.uint8) * 255
start_h = (max_dimension - height) // 2
start_w = (max_dimension - width) // 2
square_image[start_h:start_h + height, start_w:start_w + width] = image
resized_image = cv2.resize(square_image, (512, 512))
else:
raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!")
return resized_image
def get_image(self):
# Get data of an object.
# Read data from response.
try:
response = self.minio_client.get_object(self.image_url.split('/')[0], self.image_url[self.image_url.find('/') + 1:])
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
img = self.preprocess_image(img, self.category)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
except:
img = np.random.randn(512, 512, 3)
return img
def callback(self, result, error):
if error:
generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
self.redis_client.set(self.tasks_id, generate_data)
else:
images = result.as_numpy("IMAGES")
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
# for i in range(len(pil_images)):
# pil = pil_images[i]
# pil.save(f'./temp_i2_{i}.png')
# self.image_grid(pil_images, rows, cols)
url_list = []
for i, image in enumerate(pil_images):
if self.category == "sketch":
image = remove_background(np.asarray(image))
image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}",
object_name=f"{generate_uuid()}_{i}.png", )
url_list.append(image_url)
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
logger.info(f" [x] Sent {generate_data}")
self.redis_client.set(self.tasks_id, generate_data)
def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id))
logging.info(f"{self.tasks_id} ===> {status_data}")
return status_data
@RunTime
def get_result(self):
self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version)
self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version)
image = self.get_image()
# Input placeholder
prompt_in = tritonclient.grpc.InferInput(name="PROMPT", shape=(self.batch_size,), datatype="BYTES")
samples_in = tritonclient.grpc.InferInput("SAMPLES", (self.batch_size,), "INT32")
steps_in = tritonclient.grpc.InferInput("STEPS", (self.batch_size,), "INT32")
guidance_scale_in = tritonclient.grpc.InferInput("GUIDANCE_SCALE", (self.batch_size,), "FP32")
seed_in = tritonclient.grpc.InferInput("SEED", (self.batch_size,), "INT64")
input_images_in = tritonclient.grpc.InferInput("INPUT_IMAGES", image.shape, "FP16")
images = tritonclient.grpc.InferRequestedOutput(name="IMAGES",
# binary_data=False
)
mode_in = tritonclient.grpc.InferInput("MODE", (self.batch_size,), "INT32")
# Setting inputs
prompt_in.set_data_from_numpy(np.asarray([self.content] * self.batch_size, dtype=object))
samples_in.set_data_from_numpy(np.asarray([self.samples], dtype=np.int32))
steps_in.set_data_from_numpy(np.asarray([self.steps], dtype=np.int32))
guidance_scale_in.set_data_from_numpy(np.asarray([self.guidance_scale], dtype=np.float32))
seed_in.set_data_from_numpy(np.asarray([self.seed], dtype=np.int64))
input_images_in.set_data_from_numpy(image.astype(np.float16))
mode_in.set_data_from_numpy(np.asarray([self.mode], dtype=np.int32))
# inference
@RunTime
def infer():
return self.triton_client.async_infer(
model_name=self.model_name,
model_version=self.version,
inputs=[prompt_in, samples_in, steps_in, guidance_scale_in, seed_in, input_images_in, mode_in],
outputs=[images],
callback=self.callback
)
ctx = infer()
time_out = 60
while time_out > 0:
generate_data = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data))
logger.info(f" [x] Sent {generate_data}")
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(1)
return self.read_tasks_status()
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
# request_data = {
# "user_id": 78,
# "image_url": "123_123.png",
# "category": "print",
# "mode": 1,
# "str": "a simple print",
# "version": "1"
# }
request_data = GenerateImageModel(
mode=1,
content='a blouse',
gender='',
user_id=89,
image_url='test/微信图片_20231206133428.jpg',
category='sketch',
version='1',
tasks_id='123456'
)
server = GenerateImage(request_data)
server.get_result()
# print(infer_cancel(123456))

View File

@@ -0,0 +1,115 @@
import cv2
import mmcv
import numpy as np
import torch
from PIL import Image
import tritonclient.http as httpclient
import torch.nn.functional as F
from app.core.config import *
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
ori_shape = img.shape[:2]
img_scale = (224, 224)
scale_factor = []
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
scale_factor.append(x)
scale_factor.append(y)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
def get_mask(image_obj):
pre_mask = None
if len(image_obj.shape) == 2:
image_obj = cv2.cvtColor(image_obj, cv2.COLOR_GRAY2RGB)
if image_obj.shape[2] == 4: # 如果是四通道 mask
pre_mask = image_obj[:, :, 3]
image_obj = image_obj[:, :, :3]
Contour = get_contours(image_obj)
Mask = np.zeros(image_obj.shape[:2], np.uint8)
if len(Contour):
Max_contour = Contour[0]
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
cv2.drawContours(Mask, [Approx], -1, 255, -1)
else:
Mask = np.ones(image_obj.shape[:2], np.uint8) * 255
if pre_mask is None:
mask = Mask
else:
mask = cv2.bitwise_and(Mask, pre_mask)
return image_obj, mask
def get_contours(image):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
Edge = cv2.Canny(gray, 10, 150)
kernel = np.ones((5, 5), np.uint8)
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
return Contour
def seg_infer_image(image_obj):
image, ori_shape = seg_preprocess(image_obj)
client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
transformed_img = image.astype(np.float32)
# 输入集
inputs = [
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# 输出集
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['name'], inputs=inputs, outputs=outputs)
# 推理
# 取结果
inference_output1 = torch.from_numpy(results.as_numpy(SEGMENTATION['output']))
seg_result = seg_postprocess(inference_output1, ori_shape)
return seg_result
def seg_postprocess(output, ori_shape):
seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_logit = F.softmax(seg_logit, dim=1)
seg_pred = seg_logit.argmax(dim=1)
seg_pred = seg_pred.cpu().numpy()
return seg_pred
def remove_background(image):
image_obj, mask = get_mask(image)
seg_result = seg_infer_image(image_obj)
temp_front = seg_result == 1
front_mask = (mask * (temp_front + 0).astype(np.uint8))
temp_back = seg_result == 2
back_mask = (mask * (temp_back + 0).astype(np.uint8))
if len(front_mask.shape) > 2:
front_mask = front_mask[0]
else:
front_mask = front_mask
if len(back_mask.shape) > 2:
back_mask = back_mask[0]
else:
back_mask = back_mask
result_mask = front_mask + back_mask
white_background = np.ones_like(image_obj) * 255
result_image = np.where(result_mask[:, :, None].astype(bool), image_obj, white_background)
return Image.fromarray(result_image)

View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File upload_image.py
@Author :周成融
@Date 2023/8/28 13:49:20
@detail
"""
import io
import logging
from minio import Minio
from app.core.config import *
minio_client = Minio(
f"{MINIO_IP}:{MINIO_PORT}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
def upload_png_sd(image, user_id, category, object_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
image_url = f"aida-users/{minio_client.put_object(f'aida-users', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -10,15 +10,10 @@ import json
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, SR_RABBITMQ_QUEUES, SR_TRITON_URL
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, RABBITMQ_QUEUES, SR_TRITON_URL
from app.schemas.super_resolution import SuperResolutionModel from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid from app.service.utils.generate_uuid import generate_uuid
@@ -27,7 +22,6 @@ logger = logging.getLogger()
class SuperResolution: class SuperResolution:
def __init__(self, data): def __init__(self, data):
logger.info(f"sr triton service url is : {SR_TRITON_URL}")
self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL) self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.tasks_id = data.sr_tasks_id self.tasks_id = data.sr_tasks_id
@@ -39,6 +33,13 @@ class SuperResolution:
secret_key=MINIO_SECRET, secret_key=MINIO_SECRET,
secure=MINIO_SECURE) secure=MINIO_SECURE)
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})) self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
def __del__(self):
self.redis_client.close()
self.triton_client.close()
self.connection.close()
@RunTime @RunTime
def read_image(self): def read_image(self):
@@ -46,7 +47,8 @@ class SuperResolution:
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
except minio.error.S3Error as e: except minio.error.S3Error as e:
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
publish_message(sr_data) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
logger.info(f" [x] Sent {sr_data}")
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码 img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
@@ -82,10 +84,16 @@ class SuperResolution:
) )
ctx = self.infer(inputs) ctx = self.infer(inputs)
time_out = 120 time_out = 60
while self.read_tasks_status()['status'] == "PENDING" and time_out > 0: while time_out > 0:
if self.read_tasks_status()['status'] == "REVOKED": generate_data = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel() ctx.cancel()
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data))
logger.info(f" [x] Sent {generate_data}")
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1 time_out -= 1
time.sleep(1) time.sleep(1)
return self.read_tasks_status() return self.read_tasks_status()
@@ -123,7 +131,8 @@ class SuperResolution:
output = (output * 255.0).round().astype(np.uint8) output = (output * 255.0).round().astype(np.uint8)
output_url = self.upload_img_sr(output, generate_uuid()) output_url = self.upload_img_sr(output, generate_uuid())
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'}) sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
publish_message(sr_data) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
logger.info(f" [x] Sent {sr_data}")
self.redis_client.set(self.tasks_id, sr_data) self.redis_client.set(self.tasks_id, sr_data)
@@ -131,20 +140,10 @@ def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}) sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
publish_message(sr_data)
redis_client.set(tasks_id, sr_data) redis_client.set(tasks_id, sr_data)
return data return data
def publish_message(sr_data):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
# 发布消息,并设置回调函数
channel.basic_publish(exchange='', routing_key=RABBITMQ_QUEUES, body=sr_data)
logger.info(f" [x] Sent {sr_data}")
connection.close()
if __name__ == '__main__': if __name__ == '__main__':
request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123") request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
service = SuperResolution(request_data) service = SuperResolution(request_data)