diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py new file mode 100644 index 0000000..a9aecec --- /dev/null +++ b/app/api/api_generate_image.py @@ -0,0 +1,27 @@ +import logging +from fastapi import APIRouter, BackgroundTasks +from app.schemas.generate_image import GenerateImageModel +from app.service.generate_image.service import GenerateImage, infer_cancel + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/generate_image") +def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks): + try: + service = GenerateImage(request_item) + background_tasks.add_task(service.get_result) + code = 200 + message = "access" + except Exception as e: + code = 400 + message = e + logger.warning(e) + return {"code": code, "message": message} + + +@router.get("/generate_cancel/{tasks_id}>") +def generate_image(tasks_id): + result = infer_cancel(tasks_id) + return {"code": 200, "message": result['message'], "data": result['data']} diff --git a/app/api/api_route.py b/app/api/api_route.py index fd299ee..ddb0d60 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -2,8 +2,10 @@ from fastapi import APIRouter from app.api import api_test from app.api import api_super_resolution +from app.api import api_generate_image router = APIRouter() router.include_router(api_test.router, tags=["test"], prefix="/test") -router.include_router(api_super_resolution.router, tags=["api_super_resolution"], prefix="/api") +router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api") +router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api") diff --git a/app/api/api_test.py b/app/api/api_test.py index d6e7dcc..63ef1aa 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,8 +1,6 @@ import logging - from fastapi import APIRouter - -from app.core.config import RABBITMQ_QUEUES +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES logger = logging.getLogger() router = APIRouter() @@ -10,6 +8,6 @@ router = APIRouter() @router.get("") def test(): - logger.info(RABBITMQ_QUEUES) + logger.info(SR_RABBITMQ_QUEUES) logger.info("test") - return {"message": RABBITMQ_QUEUES} + return {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES} diff --git a/app/core/config.py b/app/core/config.py index 9f625c4..9c2682c 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,59 +19,56 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') +DEBUG = True +ENV = 0 +if DEBUG: + LOGS_PATH = "logs/errors.log" +else: + LOGS_PATH = "app/logs/errors.log" + +RABBITMQ_ENV = "" + +if ENV == 1: + RABBITMQ_ENV = "dev" +elif ENV == 2: + RABBITMQ_ENV = "local" + settings = Settings() -ckpt = 'service/super_resolution_ccsr/weights/real-world_ccsr.ckpt' -config = 'service/super_resolution_ccsr/configs/model/ccsr_stage2.yaml' -steps = 45 -sr_scale = 4 -repeat_times = 1 -tiled = False -tile_size = 512 -tile_stride = 256 -color_fix_type = "adain" -t_max = 0.6667 -t_min = 0.3333 -show_lq = False -skip_if_exist = False -seed = 233 -device = "cuda" -tile_diffusion = False # -tile_diffusion_size = 512 -tile_diffusion_stride = 256 -tile_vae = True -vae_decoder_tile_size = 224 -vae_encoder_tile_size = 1024 -strength = 1 # minio 配置 -sr_bucket = "test" MINIO_IP = "www.minio.aida.com.hk" MINIO_PORT = 9000 MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB' MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR' +MINIO_SECURE = True # redis 配置 REDIS_HOST = "10.1.1.240" REDIS_PORT = "6379" REDIS_DB = "2" -MINIO_SECURE = True - -SR_MODEL_NAME = "super_resolution" -SR_TRITON_URL = "10.1.1.240:10031" - # rabbitmq config - RABBITMQ_PARAMS = { "host": "18.167.251.121", "port": 5672, "credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'), "virtual_host": "/" } -RABBITMQ_QUEUES = os.getenv("RABBITMQ_QUEUES", "SuperResolution-local") -DEBUG = True -if DEBUG: - LOGS_PATH = "logs/errors.log" -else: - LOGS_PATH = "app/logs/errors.log" +# SR service config +SR_MODEL_NAME = "super_resolution" +SR_TRITON_URL = "10.1.1.240:10031" +SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", "SuperResolution-local") + +# GenerateImage service config +GI_MODEL_NAME = '_stable_diffusion' +GI_MODEL_URL = '10.1.1.240:7001' +GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage-{RABBITMQ_ENV}") + +# SEG service config +SEG_MODEL_URL = '10.1.1.240:10000' +SEGMENTATION = { + "name": "seg_ocrnet_hr18", + "input": "seg_input__0", + "output": "seg_output__0", +} diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py new file mode 100644 index 0000000..a142d9b --- /dev/null +++ b/app/schemas/generate_image.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + + +class GenerateImageModel(BaseModel): + category: str + content: str + gender: str + image_url: str + mode: int + tasks_id: str + user_id: int + version: str diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py new file mode 100644 index 0000000..fe3a9b8 --- /dev/null +++ b/app/service/generate_image/service.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import numpy as np +import random +import redis +import tritonclient +import tritonclient.grpc as grpc_client +from minio import Minio +import cv2 +from PIL import Image +import time +from app.core.config import * +from app.schemas.generate_image import GenerateImageModel +from app.service.generate_image.utils.remove_background import remove_background +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.decorator import RunTime +from app.service.utils.generate_uuid import generate_uuid + +logger = logging.getLogger() + + +class GenerateImage: + def __init__(self, request_data): + self.tasks_id = request_data.tasks_id + self.image_url = request_data.image_url + self.user_id = request_data.user_id + self.content = request_data.content + self.category = request_data.category + self.model_name = f"{self.category}{GI_MODEL_NAME}" + self.mode = request_data.mode + self.version = request_data.version + self.triton_client = grpc_client.InferenceServerClient(url=f"{GI_MODEL_URL}") + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.minio_client = Minio( + f"{MINIO_IP}:{MINIO_PORT}", + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + + self.samples = 4 # no.of images to generate + self.steps = 24 + self.guidance_scale = 7 + self.seed = random.randint(0, 2000000000) + self.batch_size = 1 + self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}) + self.redis_client.set(self.tasks_id, self.generate_data) + + def __del__(self): + self.redis_client.close() + self.triton_client.close() + self.connection.close() + + @staticmethod + def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new('RGB', size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + @staticmethod + def preprocess_image(image, category): + height, width, _ = image.shape + + if category == "print" or category == "moodboard": + square_size = min(height, width) + start_x = (width - square_size) // 2 + start_y = (height - square_size) // 2 + cropped = image[start_y: start_y + square_size, start_x: start_x + square_size] + resized_image = cv2.resize(cropped, (512, 512)) + + elif category == "sketch": + # below is the way that get "bigger" square image. + max_dimension = max(height, width) + square_image = np.ones((max_dimension, max_dimension, 3), dtype=np.uint8) * 255 + start_h = (max_dimension - height) // 2 + start_w = (max_dimension - width) // 2 + square_image[start_h:start_h + height, start_w:start_w + width] = image + resized_image = cv2.resize(square_image, (512, 512)) + + else: + raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!") + + return resized_image + + def get_image(self): + # Get data of an object. + # Read data from response. + try: + response = self.minio_client.get_object(self.image_url.split('/')[0], self.image_url[self.image_url.find('/') + 1:]) + img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 + img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + img = self.preprocess_image(img, self.category) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + except: + img = np.random.randn(512, 512, 3) + return img + + def callback(self, result, error): + if error: + generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"}) + self.redis_client.set(self.tasks_id, generate_data) + else: + images = result.as_numpy("IMAGES") + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + # for i in range(len(pil_images)): + # pil = pil_images[i] + # pil.save(f'./temp_i2_{i}.png') + # self.image_grid(pil_images, rows, cols) + url_list = [] + for i, image in enumerate(pil_images): + + if self.category == "sketch": + image = remove_background(np.asarray(image)) + image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", + object_name=f"{generate_uuid()}_{i}.png", ) + url_list.append(image_url) + generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'}) + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data) + logger.info(f" [x] Sent {generate_data}") + self.redis_client.set(self.tasks_id, generate_data) + + def read_tasks_status(self): + status_data = json.loads(self.redis_client.get(self.tasks_id)) + logging.info(f"{self.tasks_id} ===> {status_data}") + return status_data + + @RunTime + def get_result(self): + self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version) + self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version) + + image = self.get_image() + + # Input placeholder + prompt_in = tritonclient.grpc.InferInput(name="PROMPT", shape=(self.batch_size,), datatype="BYTES") + samples_in = tritonclient.grpc.InferInput("SAMPLES", (self.batch_size,), "INT32") + steps_in = tritonclient.grpc.InferInput("STEPS", (self.batch_size,), "INT32") + guidance_scale_in = tritonclient.grpc.InferInput("GUIDANCE_SCALE", (self.batch_size,), "FP32") + seed_in = tritonclient.grpc.InferInput("SEED", (self.batch_size,), "INT64") + input_images_in = tritonclient.grpc.InferInput("INPUT_IMAGES", image.shape, "FP16") + images = tritonclient.grpc.InferRequestedOutput(name="IMAGES", + # binary_data=False + ) + mode_in = tritonclient.grpc.InferInput("MODE", (self.batch_size,), "INT32") + + # Setting inputs + prompt_in.set_data_from_numpy(np.asarray([self.content] * self.batch_size, dtype=object)) + samples_in.set_data_from_numpy(np.asarray([self.samples], dtype=np.int32)) + steps_in.set_data_from_numpy(np.asarray([self.steps], dtype=np.int32)) + guidance_scale_in.set_data_from_numpy(np.asarray([self.guidance_scale], dtype=np.float32)) + seed_in.set_data_from_numpy(np.asarray([self.seed], dtype=np.int64)) + input_images_in.set_data_from_numpy(image.astype(np.float16)) + mode_in.set_data_from_numpy(np.asarray([self.mode], dtype=np.int32)) + + # inference + @RunTime + def infer(): + return self.triton_client.async_infer( + model_name=self.model_name, + model_version=self.version, + inputs=[prompt_in, samples_in, steps_in, guidance_scale_in, seed_in, input_images_in, mode_in], + outputs=[images], + callback=self.callback + ) + + ctx = infer() + time_out = 60 + while time_out > 0: + generate_data = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data)) + logger.info(f" [x] Sent {generate_data}") + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(1) + return self.read_tasks_status() + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + # request_data = { + # "user_id": 78, + # "image_url": "123_123.png", + # "category": "print", + # "mode": 1, + # "str": "a simple print", + # "version": "1" + # } + request_data = GenerateImageModel( + mode=1, + content='a blouse', + gender='', + user_id=89, + image_url='test/微信图片_20231206133428.jpg', + category='sketch', + version='1', + tasks_id='123456' + ) + server = GenerateImage(request_data) + server.get_result() + # print(infer_cancel(123456)) diff --git a/app/service/generate_image/utils/remove_background.py b/app/service/generate_image/utils/remove_background.py new file mode 100644 index 0000000..abc45a0 --- /dev/null +++ b/app/service/generate_image/utils/remove_background.py @@ -0,0 +1,115 @@ +import cv2 +import mmcv +import numpy as np +import torch +from PIL import Image + +import tritonclient.http as httpclient + +import torch.nn.functional as F + +from app.core.config import * + + +def seg_preprocess(img_path): + img = mmcv.imread(img_path) + ori_shape = img.shape[:2] + img_scale = (224, 224) + scale_factor = [] + img, x, y = mmcv.imresize(img, img_scale, return_scale=True) + scale_factor.append(x) + scale_factor.append(y) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) + return preprocessed_img, ori_shape + + +def get_mask(image_obj): + pre_mask = None + if len(image_obj.shape) == 2: + image_obj = cv2.cvtColor(image_obj, cv2.COLOR_GRAY2RGB) + if image_obj.shape[2] == 4: # 如果是四通道 mask + pre_mask = image_obj[:, :, 3] + image_obj = image_obj[:, :, :3] + + Contour = get_contours(image_obj) + Mask = np.zeros(image_obj.shape[:2], np.uint8) + if len(Contour): + Max_contour = Contour[0] + Epsilon = 0.001 * cv2.arcLength(Max_contour, True) + Approx = cv2.approxPolyDP(Max_contour, Epsilon, True) + cv2.drawContours(Mask, [Approx], -1, 255, -1) + else: + Mask = np.ones(image_obj.shape[:2], np.uint8) * 255 + + if pre_mask is None: + mask = Mask + else: + mask = cv2.bitwise_and(Mask, pre_mask) + return image_obj, mask + + +def get_contours(image): + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + Edge = cv2.Canny(gray, 10, 150) + kernel = np.ones((5, 5), np.uint8) + Edge = cv2.dilate(Edge, kernel=kernel, iterations=1) + Edge = cv2.erode(Edge, kernel=kernel, iterations=1) + Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + Contour = sorted(Contour, key=cv2.contourArea, reverse=True) + return Contour + + +def seg_infer_image(image_obj): + image, ori_shape = seg_preprocess(image_obj) + client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}") + transformed_img = image.astype(np.float32) + # 输入集 + inputs = [ + httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(transformed_img, binary_data=True) + # 输出集 + outputs = [ + httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True), + ] + results = client.infer(model_name=SEGMENTATION['name'], inputs=inputs, outputs=outputs) + # 推理 + # 取结果 + inference_output1 = torch.from_numpy(results.as_numpy(SEGMENTATION['output'])) + seg_result = seg_postprocess(inference_output1, ori_shape) + return seg_result + + +def seg_postprocess(output, ori_shape): + seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) + seg_logit = F.softmax(seg_logit, dim=1) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + return seg_pred + + +def remove_background(image): + image_obj, mask = get_mask(image) + seg_result = seg_infer_image(image_obj) + + temp_front = seg_result == 1 + front_mask = (mask * (temp_front + 0).astype(np.uint8)) + temp_back = seg_result == 2 + back_mask = (mask * (temp_back + 0).astype(np.uint8)) + + if len(front_mask.shape) > 2: + front_mask = front_mask[0] + else: + front_mask = front_mask + + if len(back_mask.shape) > 2: + back_mask = back_mask[0] + else: + back_mask = back_mask + + result_mask = front_mask + back_mask + white_background = np.ones_like(image_obj) * 255 + result_image = np.where(result_mask[:, :, None].astype(bool), image_obj, white_background) + + return Image.fromarray(result_image) diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py new file mode 100644 index 0000000..3209c98 --- /dev/null +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :upload_image.py +@Author :周成融 +@Date :2023/8/28 13:49:20 +@detail : +""" +import io +import logging +from minio import Minio + +from app.core.config import * + +minio_client = Minio( + f"{MINIO_IP}:{MINIO_PORT}", + access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE) + + +def upload_png_sd(image, user_id, category, object_name): + try: + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + image_url = f"aida-users/{minio_client.put_object(f'aida-users', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + + return image_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index 9e27cdf..08862b9 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -10,15 +10,10 @@ import json import cv2 import numpy as np import torch -import tritonclient.http as httpclient import tritonclient.grpc as grpcclient - -from PIL import Image from minio import Minio - -from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, RABBITMQ_QUEUES, SR_TRITON_URL +from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, SR_RABBITMQ_QUEUES, SR_TRITON_URL from app.schemas.super_resolution import SuperResolutionModel - from app.service.utils.decorator import RunTime from app.service.utils.generate_uuid import generate_uuid @@ -27,7 +22,6 @@ logger = logging.getLogger() class SuperResolution: def __init__(self, data): - logger.info(f"sr triton service url is : {SR_TRITON_URL}") self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.tasks_id = data.sr_tasks_id @@ -39,6 +33,13 @@ class SuperResolution: secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})) + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + + def __del__(self): + self.redis_client.close() + self.triton_client.close() + self.connection.close() @RunTime def read_image(self): @@ -46,7 +47,8 @@ class SuperResolution: image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) except minio.error.S3Error as e: sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) - publish_message(sr_data) + self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) + logger.info(f" [x] Sent {sr_data}") raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码 @@ -82,10 +84,16 @@ class SuperResolution: ) ctx = self.infer(inputs) - time_out = 120 - while self.read_tasks_status()['status'] == "PENDING" and time_out > 0: - if self.read_tasks_status()['status'] == "REVOKED": + time_out = 60 + while time_out > 0: + generate_data = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() + self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data)) + logger.info(f" [x] Sent {generate_data}") + break + elif generate_data['status'] == "SUCCESS": + break time_out -= 1 time.sleep(1) return self.read_tasks_status() @@ -123,7 +131,8 @@ class SuperResolution: output = (output * 255.0).round().astype(np.uint8) output_url = self.upload_img_sr(output, generate_uuid()) sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'}) - publish_message(sr_data) + self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) + logger.info(f" [x] Sent {sr_data}") self.redis_client.set(self.tasks_id, sr_data) @@ -131,20 +140,10 @@ def infer_cancel(tasks_id): redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}) - publish_message(sr_data) redis_client.set(tasks_id, sr_data) return data -def publish_message(sr_data): - connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - channel = connection.channel() - # 发布消息,并设置回调函数 - channel.basic_publish(exchange='', routing_key=RABBITMQ_QUEUES, body=sr_data) - logger.info(f" [x] Sent {sr_data}") - connection.close() - - if __name__ == '__main__': request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123") service = SuperResolution(request_data)