95 lines
4.0 KiB
Python
95 lines
4.0 KiB
Python
import logging
|
||
|
||
import cv2
|
||
import mmcv
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import tritonclient.http as httpclient
|
||
|
||
from app.core.config import DESIGN_MODEL_URL
|
||
from app.schemas.image2sketch import Image2SketchModel
|
||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||
|
||
# Module-scoped logger; records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LineArtService:
    """Produce a line-art sketch of an OSS-stored image with the Triton "lineart" model.

    The request item is read for ``default_style`` (int-convertible style index),
    ``image_url`` ("<bucket>/<object name>"), ``sketch_bucket`` and ``sketch_name``
    (upload destination of the resulting JPEG).
    """

    # Normalization statistics handed to mmcv.imnormalize (applied after BGR->RGB).
    _MEAN = np.array([123.675, 116.28, 103.53])
    _STD = np.array([58.395, 57.12, 57.375])

    def __init__(self, request_item):
        """Copy request parameters onto the instance.

        Raises:
            ValueError: if ``default_style`` is not an int in ``[0, len(self.weights) - 1]``.
        """
        self.line_style = int(request_item.default_style)
        self.image_url = request_item.image_url
        self.sketch_bucket = request_item.sketch_bucket
        self.sketch_name = request_item.sketch_name
        # (weight of original, weight of eroded copy) per style index. Index 0 is
        # never used for blending because style 0 skips the blend step entirely.
        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]
        # Validate up front: a negative style would otherwise silently wrap
        # (weights[-1]) and a too-large one would fail only after inference.
        if not 0 <= self.line_style < len(self.weights):
            raise ValueError(
                f"default_style must be in [0, {len(self.weights) - 1}], got {self.line_style}"
            )

    def get_result(self):
        """Run the pipeline: download, preprocess, infer, postprocess, style, upload.

        Returns:
            ``"<sketch_bucket>/<sketch_name>"`` on success, or ``None`` if the upload
            failed (see ``put_image``).
        """
        input_image = self.get_image()
        input_img, ori_shape = self.line_art_preprocess(input_image)
        transformed_img = input_img.astype(np.float32)

        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        try:
            inputs = [httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")]
            inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
            outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
            results = client.infer(model_name="lineart", inputs=inputs, outputs=outputs)
            inference_output = results.as_numpy("output__0")
        finally:
            # A fresh client is built per call; release its connection pool.
            client.close()

        line_art = self.line_art_postprocess(inference_output, ori_shape)
        # Keep only the first channel of the (C, H, W) map.
        line_art_result = self._to_uint8(line_art[0])

        if self.line_style != 0:
            logger.info("line_style=%s", self.line_style)
            line_art_result = self._apply_line_style(line_art_result)

        return self.put_image(line_art_result)

    def _apply_line_style(self, image):
        """Blend ``image`` with a 3x3-eroded copy using ``self.weights[self.line_style]``."""
        kernel = np.ones((3, 3), np.uint8)
        # cv2.erode is a local-minimum filter: dark pixels spread, so dark strokes
        # thicken (assumes dark lines on a light background -- TODO confirm).
        eroded = cv2.erode(image, kernel, iterations=1)
        original_weight, eroded_weight = self.weights[self.line_style]
        return cv2.addWeighted(image, original_weight, eroded, eroded_weight, 0)

    @staticmethod
    def _to_uint8(prob_map):
        """Scale a float map expected in [0, 1] to uint8, saturating out-of-range values."""
        # Clip before casting: astype(np.uint8) wraps instead of saturating
        # (e.g. 1.02 * 255 -> 260 -> 4).
        return np.clip(np.round(prob_map * 255.0), 0, 255).astype(np.uint8)

    def get_image(self):
        """Download the source image referenced by ``self.image_url``.

        ``image_url`` is split at its first '/' into bucket and object name.

        Raises:
            ValueError: if ``image_url`` lacks a '/' or has an empty bucket/object part.
        """
        bucket, sep, object_name = self.image_url.partition('/')
        if not (sep and bucket and object_name):
            raise ValueError(
                f"image_url must look like '<bucket>/<object>', got {self.image_url!r}"
            )
        return oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")

    def put_image(self, image):
        """Encode ``image`` as JPEG and upload it to ``sketch_bucket/sketch_name``.

        Best-effort: any failure is logged (with traceback) and ``None`` is returned.

        Returns:
            ``"<sketch_bucket>/<sketch_name>"`` on success, otherwise ``None``.
        """
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                raise RuntimeError("cv2.imencode failed to encode the sketch as JPEG")
            oss_upload_image(
                bucket=self.sketch_bucket,
                object_name=self.sketch_name,
                image_bytes=encoded.tobytes(),
            )
            return f"{self.sketch_bucket}/{self.sketch_name}"
        except Exception as e:
            logger.warning(e, exc_info=True)
            return None

    @staticmethod
    def line_art_preprocess(image, *, max_side=1024):
        """Prepare an image for the lineart model.

        Any side longer than ``max_side`` is clamped to ``max_side`` independently
        (aspect ratio is not preserved; postprocess resizes back to the original
        size). The image is then mean/std normalized, converted to RGB and laid
        out as a (1, 3, H, W) batch.

        Args:
            image: anything ``mmcv.imread`` accepts (path or decoded ndarray).
            max_side: maximum allowed height/width before resizing.

        Returns:
            ``(batch, ori_shape)`` where ``ori_shape`` is the original ``(height, width)``.
        """
        img = mmcv.imread(image)
        # The 3-element mean/std and the CHW transpose need exactly 3 channels.
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)  # alpha is discarded

        ori_shape = img.shape[:2]
        ori_h, ori_w = ori_shape
        out_h, out_w = min(ori_h, max_side), min(ori_w, max_side)
        if (out_h, out_w) != ori_shape:
            # cv2.resize takes dsize as (width, height).
            img = cv2.resize(img, (out_w, out_h))

        img = mmcv.imnormalize(img, mean=LineArtService._MEAN, std=LineArtService._STD, to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img, ori_shape

    @staticmethod
    def line_art_postprocess(output, ori_shape):
        """Bilinearly resize the model output back to ``ori_shape``.

        Args:
            output: model output, presumably (1, C, h, w) -- TODO confirm against the model.
            ori_shape: target ``(height, width)``.

        Returns:
            The first batch element as a float32 ndarray of shape (C, height, width).
        """
        seg_logit = F.interpolate(
            torch.tensor(output).float(), size=ori_shape, mode='bilinear', align_corners=False
        )
        return seg_logit.numpy()[0]
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke run: sketch one stored image with style 4 and print the upload location.
    demo_request = Image2SketchModel(
        image_url="aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png",
        default_style="4",
        sketch_bucket="test",
        sketch_name="test123.jpg",
    )
    print(LineArtService(demo_request).get_result())
|