feat generate 迁移
This commit is contained in:
@@ -58,11 +58,13 @@ RABBITMQ_PARAMS = {
|
|||||||
# SR service config
|
# SR service config
|
||||||
SR_MODEL_NAME = "super_resolution"
|
SR_MODEL_NAME = "super_resolution"
|
||||||
SR_TRITON_URL = "10.1.1.240:10031"
|
SR_TRITON_URL = "10.1.1.240:10031"
|
||||||
|
SR_MINIO_BUCKET = "aida-users"
|
||||||
SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", "SuperResolution-local")
|
SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", "SuperResolution-local")
|
||||||
|
|
||||||
# GenerateImage service config
|
# GenerateImage service config
|
||||||
GI_MODEL_NAME = '_stable_diffusion'
|
GI_MODEL_NAME = '_stable_diffusion'
|
||||||
GI_MODEL_URL = '10.1.1.240:7001'
|
GI_MODEL_URL = '10.1.1.240:7001'
|
||||||
|
GI_MINIO_BUCKET = "aida-users"
|
||||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage-{RABBITMQ_ENV}")
|
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage-{RABBITMQ_ENV}")
|
||||||
|
|
||||||
# SEG service config
|
# SEG service config
|
||||||
|
|||||||
@@ -130,8 +130,7 @@ class GenerateImage:
|
|||||||
|
|
||||||
if self.category == "sketch":
|
if self.category == "sketch":
|
||||||
image = remove_background(np.asarray(image))
|
image = remove_background(np.asarray(image))
|
||||||
image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}",
|
image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", object_name=f"{generate_uuid()}_{i}.png", )
|
||||||
object_name=f"{generate_uuid()}_{i}.png", )
|
|
||||||
url_list.append(image_url)
|
url_list.append(image_url)
|
||||||
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
|
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
|
||||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
|
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def upload_png_sd(image, user_id, category, object_name):
|
|||||||
image.save(image_data, format='PNG')
|
image.save(image_data, format='PNG')
|
||||||
image_data.seek(0)
|
image_data.seek(0)
|
||||||
image_bytes = image_data.read()
|
image_bytes = image_data.read()
|
||||||
image_url = f"aida-users/{minio_client.put_object(f'aida-users', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
image_url = f"aida-users/{minio_client.put_object(f'{GI_MINIO_BUCKET}', f'{user_id}/{category}/{object_name}', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||||
|
|
||||||
return image_url
|
return image_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
import minio.error
|
import minio.error
|
||||||
import pika
|
|
||||||
import redis
|
import redis
|
||||||
import json
|
import json
|
||||||
import cv2
|
import cv2
|
||||||
@@ -12,10 +9,9 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import tritonclient.grpc as grpcclient
|
import tritonclient.grpc as grpcclient
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, SR_RABBITMQ_QUEUES, SR_TRITON_URL
|
from app.core.config import *
|
||||||
from app.schemas.super_resolution import SuperResolutionModel
|
from app.schemas.super_resolution import SuperResolutionModel
|
||||||
from app.service.utils.decorator import RunTime
|
from app.service.utils.decorator import RunTime
|
||||||
from app.service.utils.generate_uuid import generate_uuid
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
@@ -25,6 +21,7 @@ class SuperResolution:
|
|||||||
self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
|
self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.tasks_id = data.sr_tasks_id
|
self.tasks_id = data.sr_tasks_id
|
||||||
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
self.sr_image_url = data.sr_image_url
|
self.sr_image_url = data.sr_image_url
|
||||||
self.sr_xn = data.sr_xn
|
self.sr_xn = data.sr_xn
|
||||||
self.minio_client = Minio(
|
self.minio_client = Minio(
|
||||||
@@ -108,11 +105,11 @@ class SuperResolution:
|
|||||||
# output_url = self.upload_img_sr(output, generate_uuid())
|
# output_url = self.upload_img_sr(output, generate_uuid())
|
||||||
# return output_url
|
# return output_url
|
||||||
|
|
||||||
def upload_img_sr(self, image, object_name):
|
def upload_img_sr(self, image):
|
||||||
try:
|
try:
|
||||||
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||||
image_url = f"test/{self.minio_client.put_object(f'test', f'{object_name}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
|
||||||
|
image_url = f"aida-users/{res.object_name}"
|
||||||
return image_url
|
return image_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"upload_png_mask runtime exception : {e}")
|
logger.warning(f"upload_png_mask runtime exception : {e}")
|
||||||
@@ -129,7 +126,7 @@ class SuperResolution:
|
|||||||
if output.ndim == 3:
|
if output.ndim == 3:
|
||||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||||
output = (output * 255.0).round().astype(np.uint8)
|
output = (output * 255.0).round().astype(np.uint8)
|
||||||
output_url = self.upload_img_sr(output, generate_uuid())
|
output_url = self.upload_img_sr(output)
|
||||||
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
|
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
|
||||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
||||||
logger.info(f" [x] Sent {sr_data}")
|
logger.info(f" [x] Sent {sr_data}")
|
||||||
|
|||||||
@@ -1,31 +1,2 @@
|
|||||||
import time
|
a = "123-86"
|
||||||
|
print(a[a.rfind('-') + 1:])
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tritonclient.http as httpclient
|
|
||||||
import tritonclient.grpc as grpcclient
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
triton_client = grpcclient.InferenceServerClient(url=f"10.1.1.150:7001")
|
|
||||||
|
|
||||||
sample = cv2.imread("1709713346.806274.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
|
||||||
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
|
||||||
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
|
||||||
inputs = [
|
|
||||||
grpcclient.InferInput("input", sample.shape, datatype="FP32")
|
|
||||||
]
|
|
||||||
inputs[0].set_data_from_numpy(sample
|
|
||||||
# , binary_data=True
|
|
||||||
)
|
|
||||||
start_time = time.time()
|
|
||||||
results = triton_client.async_infer(model_name="super_resolution", inputs=inputs)
|
|
||||||
print(time.time() - start_time)
|
|
||||||
sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
|
||||||
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
if output.ndim == 3:
|
|
||||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
|
||||||
output = (output * 255.0).round().astype(np.uint8)
|
|
||||||
cv2.imshow("", output)
|
|
||||||
cv2.waitKey(0)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user