From 83d79c14ef75e9ce1d28f4f8904551bc743c931b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 15 Apr 2024 18:26:48 +0800 Subject: [PATCH] =?UTF-8?q?feat=20generate=20=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 ++ app/service/generate_image/service.py | 3 +- .../generate_image/utils/upload_sd_image.py | 2 +- app/service/super_resolution/service.py | 15 ++++----- app/service/super_resolution/test.py | 33 ++----------------- 5 files changed, 12 insertions(+), 43 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 9c2682c..0444760 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -58,11 +58,13 @@ RABBITMQ_PARAMS = { # SR service config SR_MODEL_NAME = "super_resolution" SR_TRITON_URL = "10.1.1.240:10031" +SR_MINIO_BUCKET = "aida-users" SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", "SuperResolution-local") # GenerateImage service config GI_MODEL_NAME = '_stable_diffusion' GI_MODEL_URL = '10.1.1.240:7001' +GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage-{RABBITMQ_ENV}") # SEG service config diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index fe3a9b8..67e8532 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -130,8 +130,7 @@ class GenerateImage: if self.category == "sketch": image = remove_background(np.asarray(image)) - image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", - object_name=f"{generate_uuid()}_{i}.png", ) + image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", object_name=f"{generate_uuid()}_{i}.png", ) url_list.append(image_url) generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'}) self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data) diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index 3209c98..1bf7af9 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -26,7 +26,7 @@ def upload_png_sd(image, user_id, category, object_name): image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - image_url = f"aida-users/{minio_client.put_object(f'aida-users', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + image_url = f"aida-users/{minio_client.put_object(f'{GI_MINIO_BUCKET}', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" return image_url except Exception as e: diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index 08862b9..d6d54ca 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -1,10 +1,7 @@ import io import logging import time -from io import BytesIO - import minio.error -import pika import redis import json import cv2 @@ -12,10 +9,9 @@ import numpy as np import torch import tritonclient.grpc as grpcclient from minio import Minio -from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, SR_RABBITMQ_QUEUES, SR_TRITON_URL +from app.core.config import * from app.schemas.super_resolution import SuperResolutionModel from app.service.utils.decorator import RunTime -from app.service.utils.generate_uuid import generate_uuid logger = logging.getLogger() @@ -25,6 +21,7 @@ class SuperResolution: self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.tasks_id = data.sr_tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.sr_image_url = data.sr_image_url self.sr_xn = data.sr_xn self.minio_client = Minio( @@ -108,11 +105,11 @@ class SuperResolution: # output_url = self.upload_img_sr(output, generate_uuid()) # return output_url - def upload_img_sr(self, image, object_name): + def upload_img_sr(self, image): try: image_bytes = cv2.imencode('.jpg', image)[1].tobytes() - image_url = f"test/{self.minio_client.put_object(f'test', f'{object_name}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - + res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png') + image_url = f"aida-users/{res.object_name}" return image_url except Exception as e: logger.warning(f"upload_png_mask runtime exception : {e}") @@ -129,7 +126,7 @@ class SuperResolution: if output.ndim == 3: output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR output = (output * 255.0).round().astype(np.uint8) - output_url = self.upload_img_sr(output, generate_uuid()) + output_url = self.upload_img_sr(output) sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'}) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) logger.info(f" [x] Sent {sr_data}") diff --git a/app/service/super_resolution/test.py b/app/service/super_resolution/test.py index 14675a7..cfbfcda 100644 --- a/app/service/super_resolution/test.py +++ b/app/service/super_resolution/test.py @@ -1,31 +1,2 @@ -import time - -import cv2 -import numpy as np -import torch -import tritonclient.http as httpclient -import tritonclient.grpc as grpcclient - -from PIL import Image - -triton_client = grpcclient.InferenceServerClient(url=f"10.1.1.150:7001") - -sample = cv2.imread("1709713346.806274.png", cv2.IMREAD_COLOR).astype(np.float32) / 255. -sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1)) -sample = torch.from_numpy(sample).float().unsqueeze(0).numpy() -inputs = [ - grpcclient.InferInput("input", sample.shape, datatype="FP32") -] -inputs[0].set_data_from_numpy(sample - # , binary_data=True - ) -start_time = time.time() -results = triton_client.async_infer(model_name="super_resolution", inputs=inputs) -print(time.time() - start_time) -sr_output = torch.from_numpy(results.as_numpy(f"output")) -output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy() -if output.ndim == 3: - output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR -output = (output * 255.0).round().astype(np.uint8) -cv2.imshow("", output) -cv2.waitKey(0) +a = "123-86" +print(a[a.rfind('-') + 1:])