feat generate 迁移

This commit is contained in:
zhouchengrong
2024-04-15 18:26:48 +08:00
parent f8493dbdb6
commit 83d79c14ef
5 changed files with 12 additions and 43 deletions

View File

@@ -58,11 +58,13 @@ RABBITMQ_PARAMS = {
# SR service config # SR service config
SR_MODEL_NAME = "super_resolution" SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031" SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", "SuperResolution-local") SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", "SuperResolution-local")
# GenerateImage service config # GenerateImage service config
GI_MODEL_NAME = '_stable_diffusion' GI_MODEL_NAME = '_stable_diffusion'
GI_MODEL_URL = '10.1.1.240:7001' GI_MODEL_URL = '10.1.1.240:7001'
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage-{RABBITMQ_ENV}") GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage-{RABBITMQ_ENV}")
# SEG service config # SEG service config

View File

@@ -130,8 +130,7 @@ class GenerateImage:
if self.category == "sketch": if self.category == "sketch":
image = remove_background(np.asarray(image)) image = remove_background(np.asarray(image))
image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", object_name=f"{generate_uuid()}_{i}.png", )
object_name=f"{generate_uuid()}_{i}.png", )
url_list.append(image_url) url_list.append(image_url)
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'}) generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data) self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)

View File

@@ -26,7 +26,7 @@ def upload_png_sd(image, user_id, category, object_name):
image.save(image_data, format='PNG') image.save(image_data, format='PNG')
image_data.seek(0) image_data.seek(0)
image_bytes = image_data.read() image_bytes = image_data.read()
image_url = f"aida-users/{minio_client.put_object(f'aida-users', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" image_url = f"aida-users/{minio_client.put_object(f'{GI_MINIO_BUCKET}', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
return image_url return image_url
except Exception as e: except Exception as e:

View File

@@ -1,10 +1,7 @@
import io import io
import logging import logging
import time import time
from io import BytesIO
import minio.error import minio.error
import pika
import redis import redis
import json import json
import cv2 import cv2
@@ -12,10 +9,9 @@ import numpy as np
import torch import torch
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
from minio import Minio from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, SR_RABBITMQ_QUEUES, SR_TRITON_URL from app.core.config import *
from app.schemas.super_resolution import SuperResolutionModel from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger() logger = logging.getLogger()
@@ -25,6 +21,7 @@ class SuperResolution:
self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL) self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.tasks_id = data.sr_tasks_id self.tasks_id = data.sr_tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.sr_image_url = data.sr_image_url self.sr_image_url = data.sr_image_url
self.sr_xn = data.sr_xn self.sr_xn = data.sr_xn
self.minio_client = Minio( self.minio_client = Minio(
@@ -108,11 +105,11 @@ class SuperResolution:
# output_url = self.upload_img_sr(output, generate_uuid()) # output_url = self.upload_img_sr(output, generate_uuid())
# return output_url # return output_url
def upload_img_sr(self, image, object_name): def upload_img_sr(self, image):
try: try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes() image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
image_url = f"test/{self.minio_client.put_object(f'test', f'{object_name}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
image_url = f"aida-users/{res.object_name}"
return image_url return image_url
except Exception as e: except Exception as e:
logger.warning(f"upload_png_mask runtime exception : {e}") logger.warning(f"upload_png_mask runtime exception : {e}")
@@ -129,7 +126,7 @@ class SuperResolution:
if output.ndim == 3: if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) output = (output * 255.0).round().astype(np.uint8)
output_url = self.upload_img_sr(output, generate_uuid()) output_url = self.upload_img_sr(output)
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'}) sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
logger.info(f" [x] Sent {sr_data}") logger.info(f" [x] Sent {sr_data}")

View File

@@ -1,31 +1,2 @@
import time a = "123-86"
print(a[a.rfind('-') + 1:])
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
from PIL import Image
triton_client = grpcclient.InferenceServerClient(url=f"10.1.1.150:7001")
sample = cv2.imread("1709713346.806274.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
grpcclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample
# , binary_data=True
)
start_time = time.time()
results = triton_client.async_infer(model_name="super_resolution", inputs=inputs)
print(time.time() - start_time)
sr_output = torch.from_numpy(results.as_numpy(f"output"))
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
cv2.imshow("", output)
cv2.waitKey(0)