Files
AiDA_Python/app/service/lineart/service.py
zhouchengrong 4e420f8ae8 feat
fix      sketch 提取修复没有文件后缀问题
2024-10-04 17:43:08 +08:00

95 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import DESIGN_MODEL_URL
from app.schemas.image2sketch import Image2SketchModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image
# Module-level logger. getLogger() with no argument returns the ROOT logger;
# NOTE(review): getLogger(__name__) is the usual convention, but switching
# would change which logger config applies — kept as-is deliberately.
logger = logging.getLogger()
class LineArtService:
    """Produce a line-art (sketch) rendering of an image via a Triton-served model.

    Pipeline: download the source image from OSS, preprocess it, run the
    ``lineart`` model on the Triton inference server, optionally blend the raw
    result with an eroded copy to vary line thickness, then upload the JPEG
    result back to OSS and return its "<bucket>/<name>.jpg" path.
    """

    def __init__(self, request_item):
        # Index into ``self.weights`` selecting how strongly the eroded copy
        # is blended in; 0 means "no blending" (see get_result).
        self.line_style = int(request_item.default_style)
        # "<bucket>/<object_name>" location of the source image.
        self.image_url = request_item.image_url
        self.sketch_bucket = request_item.sketch_bucket
        self.sketch_name = request_item.sketch_name
        # (original_weight, eroded_weight) pairs for cv2.addWeighted;
        # index 0 is unused in practice because line_style == 0 skips blending.
        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]

    def get_result(self):
        """Run the full pipeline and return the OSS path of the uploaded sketch."""
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        input_image = self.get_image()
        input_img, ori_shape = self.line_art_preprocess(input_image)
        transformed_img = input_img.astype(np.float32)
        inputs = [httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
        results = client.infer(model_name="lineart", inputs=inputs, outputs=outputs)
        inference_output1 = results.as_numpy("output__0")
        line_art_result = self.line_art_postprocess(inference_output1, ori_shape)
        # First channel, scaled from [0, 1] to 8-bit grayscale.
        line_art_result = (line_art_result[0] * 255.0).round().astype(np.uint8)
        if self.line_style != 0:
            logger.info(self.line_style)
            kernel = np.ones((3, 3), np.uint8)
            # NOTE(review): despite the original variable name ("dilated") and
            # comment, this is an EROSION — it thickens dark lines on a light
            # background. Kept as erosion to preserve behavior.
            eroded = cv2.erode(line_art_result, kernel, iterations=1)
            # Blend the raw result with the eroded copy using the style weights.
            line_art_result = cv2.addWeighted(
                line_art_result, self.weights[self.line_style][0],
                eroded, self.weights[self.line_style][1], 0,
            )
        return self.put_image(line_art_result)

    def get_image(self):
        """Fetch the source image from OSS as a cv2 image (ndarray).

        ``self.image_url`` is "<bucket>/<object_name>": the bucket is the text
        before the first '/', the object name everything after it.
        """
        bucket = self.image_url.split('/')[0]
        object_name = self.image_url[self.image_url.find('/') + 1:]
        return oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")

    def put_image(self, image):
        """Encode ``image`` as JPEG and upload it to OSS.

        Returns the "<bucket>/<name>.jpg" path on success, or None if the
        encode/upload failed (best-effort: failures are logged, not raised).
        """
        try:
            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
            oss_upload_image(bucket=self.sketch_bucket, object_name=f"{self.sketch_name}.jpg", image_bytes=image_bytes)
            return f"{self.sketch_bucket}/{self.sketch_name}.jpg"
        except Exception as e:
            logger.warning(e)
            return None  # explicit: callers receive None on failure

    @staticmethod
    def line_art_preprocess(image):
        """Resize (cap each side at 1024), normalize, and batch an image.

        Returns (float NCHW array of shape (1, 3, H, W), original (h, w) shape).
        Note: each side is capped independently, so the aspect ratio is NOT
        preserved when only one side exceeds 1024.
        """
        img = mmcv.imread(image)
        ori_shape = img.shape[:2]  # (height, width)
        # Original code named these img_scale_w/img_scale_h with h and w
        # swapped — a known footgun here (a past h/w-swap bug is referenced in
        # the original comments). Renamed so each name holds what it says.
        target_h = min(ori_shape[0], 1024)
        target_w = min(ori_shape[1], 1024)
        if ori_shape != (target_h, target_w):
            # cv2.resize expects dsize as (width, height).
            img = cv2.resize(img, (target_w, target_h))
        # ImageNet mean/std normalization, BGR -> RGB.
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        # HWC -> CHW, then add the batch dimension.
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img, ori_shape

    @staticmethod
    def line_art_postprocess(output, ori_shape):
        """Bilinearly upsample the raw model output back to the original size.

        ``output`` is an NCHW numpy array; returns the first batch element as
        a numpy array of shape (C, h, w) where (h, w) == ``ori_shape``.
        """
        seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, mode='bilinear', align_corners=False)
        seg_pred = seg_logit.cpu().numpy()
        return seg_pred[0]
if __name__ == '__main__':
    # Manual smoke test: run the full pipeline against a fixed OSS image.
    demo_request = Image2SketchModel(
        image_url="aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png",
        default_style="4",
        sketch_bucket="test",
        sketch_name="test123.jpg",
    )
    print(LineArtService(demo_request).get_result())