feat image2sketch triton部署

fix
This commit is contained in:
zhouchengrong
2024-10-03 14:51:21 +08:00
parent 5a5bb07f3b
commit 9b415fc502
2 changed files with 98 additions and 13 deletions

View File

@@ -0,0 +1,94 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import DESIGN_MODEL_URL
from app.schemas.image2sketch import Image2SketchModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
class LineArtService:
def __init__(self, request_item):
self.line_style = int(request_item.default_style)
self.image_url = request_item.image_url
self.sketch_bucket = request_item.sketch_bucket
self.sketch_name = request_item.sketch_name
self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]
def get_result(self):
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
input_image = self.get_image()
input_img, ori_shape = self.line_art_preprocess(input_image)
transformed_img = input_img.astype(np.float32)
inputs = [httpclient.InferInput(f"input__0", transformed_img.shape, datatype="FP32")]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
outputs = [httpclient.InferRequestedOutput(f"output__0", binary_data=True)]
results = client.infer(model_name=f"lineart", inputs=inputs, outputs=outputs)
inference_output1 = results.as_numpy("output__0")
line_art_result = self.line_art_postprocess(inference_output1, ori_shape)
line_art_result = (line_art_result[0] * 255.0).round().astype(np.uint8)
if self.line_style != 0:
logger.info(self.line_style)
kernel = np.ones((3, 3), np.uint8)
dilated = cv2.erode(line_art_result, kernel, iterations=1)
# 将原图与膨胀后的图像进行混合,使用不同的权重
line_art_result = cv2.addWeighted(line_art_result, self.weights[self.line_style][0], dilated, self.weights[self.line_style][1], 0)
# cv2.imshow("", line_art_result)
# cv2.waitKey(0)
return self.put_image(line_art_result)
def get_image(self):
image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
return image
def put_image(self, image):
try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
oss_upload_image(bucket=self.sketch_bucket, object_name=self.sketch_name, image_bytes=image_bytes)
return f"{self.sketch_bucket}/{self.sketch_name}"
except Exception as e:
logger.warning(e)
@staticmethod
def line_art_preprocess(image):
img = mmcv.imread(image)
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
img_scale_w = 1024
if ori_shape[1] > 1024:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
@staticmethod
def line_art_postprocess(output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]
if __name__ == '__main__':
request_item = Image2SketchModel(
image_url="aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png",
default_style="4",
sketch_bucket="test",
sketch_name="test123.jpg"
)
service = LineArtService(request_item)
result_url = service.get_result()
print(result_url)