2024-10-03 14:51:21 +08:00
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
import mmcv
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
import tritonclient.http as httpclient
|
2025-12-30 16:49:08 +08:00
|
|
|
|
from minio import Minio
|
|
|
|
|
|
from app.core.config import settings
|
2024-10-03 14:51:21 +08:00
|
|
|
|
from app.core.config import DESIGN_MODEL_URL
|
|
|
|
|
|
from app.schemas.image2sketch import Image2SketchModel
|
2025-12-30 16:49:08 +08:00
|
|
|
|
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
2024-10-03 14:51:21 +08:00
|
|
|
|
|
|
|
|
|
|
# Module-level logger. It is named after the module so its records can be identified
# and filtered; records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)

# Shared MinIO client, created once at import time from app settings;
# LineArtService.put_image uploads the finished sketches through it.
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
2024-10-03 14:51:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LineArtService:
    """Convert an image stored in object storage into a line-art sketch.

    Pipeline (see ``get_result``):

    1. Download the source image.
    2. Preprocess it into a normalised NCHW batch.
    3. Run the Triton ``lineart`` model at ``DESIGN_MODEL_URL``.
    4. Resize the output back to the source resolution.
    5. Optionally thicken the strokes.
    6. Upload the result to MinIO as a JPEG.
    """

    # Any image side longer than this many pixels is clamped before inference.
    _MAX_SIDE = 1024
    # Normalisation statistics handed to mmcv.imnormalize in line_art_preprocess.
    _MEAN = np.array([123.675, 116.28, 103.53])
    _STD = np.array([58.395, 57.12, 57.375])

    def __init__(self, request_item):
        """Capture the request parameters.

        Args:
            request_item: request object (e.g. ``Image2SketchModel``) exposing
                ``default_style``, ``image_url``, ``sketch_bucket`` and
                ``sketch_name``.
        """
        # default_style may arrive as a string such as "4"; int() normalises it.
        self.line_style = int(request_item.default_style)
        self.image_url = request_item.image_url
        self.sketch_bucket = request_item.sketch_bucket
        self.sketch_name = request_item.sketch_name
        # (original weight, eroded weight) blend pairs, indexed by line_style.
        # Index 0 is never used because style 0 skips blending entirely.
        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]

    def get_result(self):
        """Run the whole pipeline and return the uploaded sketch's path.

        Returns:
            ``"<sketch_bucket>/<sketch_name>.jpg"`` on success, or ``None`` when
            encoding/uploading fails (``put_image`` logs the error).
        """
        source = self.get_image()
        model_input, ori_shape = self.line_art_preprocess(source)
        raw_output = self._infer(model_input.astype(np.float32))
        resized = self.line_art_postprocess(raw_output, ori_shape)
        # resized is (C, H, W); only channel 0 is used as the line-art map.
        sketch = self._to_uint8(resized[0])
        if self.line_style != 0:
            sketch = self._apply_line_style(sketch)
        return self.put_image(sketch)

    @staticmethod
    def _infer(model_input):
        """Send *model_input* (float32) to the Triton ``lineart`` model; return ``output__0``."""
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        try:
            inputs = [httpclient.InferInput("input__0", model_input.shape, datatype="FP32")]
            inputs[0].set_data_from_numpy(model_input, binary_data=True)
            outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
            results = client.infer(model_name="lineart", inputs=inputs, outputs=outputs)
            return results.as_numpy("output__0")
        finally:
            # Release the client's connections even if inference raises.
            client.close()

    @staticmethod
    def _to_uint8(line_art):
        """Scale a float map expected in [0, 1] to uint8 in [0, 255].

        Values are clipped before the cast. Casting out-of-range floats to
        uint8 does not saturate (the result is platform-dependent, typically
        wrapping), so a value of 1.01 could otherwise become a near-black pixel.
        """
        return np.clip(np.round(line_art * 255.0), 0, 255).astype(np.uint8)

    def _apply_line_style(self, line_art):
        """Blend *line_art* with an eroded copy using ``self.weights[self.line_style]``.

        NOTE(review): there is no range check. Styles >= len(self.weights)
        raise IndexError, and negative styles silently index from the end.
        """
        logger.info("Applying line style %s", self.line_style)
        # cv2.erode with a 3x3 ones kernel is a local minimum filter. It grows
        # dark strokes (assuming dark lines on a light background; TODO confirm).
        eroded = cv2.erode(line_art, np.ones((3, 3), np.uint8), iterations=1)
        original_weight, eroded_weight = self.weights[self.line_style]
        return cv2.addWeighted(line_art, original_weight, eroded, eroded_weight, 0)

    def get_image(self):
        """Download the source image and return it as a 3-channel BGR array.

        ``image_url`` is split at its first ``/`` into bucket and object name.
        BGRA images lose their alpha channel and grayscale images are expanded
        to BGR; other layouts are returned unchanged.

        Raises:
            ValueError: if ``image_url`` has no ``/`` or no image was returned.
        """
        bucket, separator, object_name = self.image_url.partition('/')
        if not separator:
            raise ValueError(f"image_url must look like '<bucket>/<object>', got {self.image_url!r}")
        image = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        # Presumably oss_get_image yields None when the download/decode fails; TODO confirm.
        if image is None:
            raise ValueError(f"could not load image {self.image_url!r}")
        # Normalise to a 3-channel colour image.
        if image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image

    def put_image(self, image):
        """Encode *image* as JPEG and upload it to ``sketch_bucket``.

        This is best-effort: failures are logged with a traceback instead of
        being raised.

        Returns:
            ``"<sketch_bucket>/<sketch_name>.jpg"`` on success, otherwise ``None``.
        """
        object_name = f"{self.sketch_name}.jpg"
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                logger.warning("cv2.imencode failed to encode sketch %s/%s", self.sketch_bucket, object_name)
                return None
            oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=object_name, image_bytes=encoded.tobytes())
        except Exception as e:
            logger.warning("Failed to upload sketch %s/%s: %s", self.sketch_bucket, object_name, e, exc_info=True)
            return None
        return f"{self.sketch_bucket}/{object_name}"

    @staticmethod
    def line_art_preprocess(image):
        """Resize and normalise *image* into a single-item NCHW batch.

        Each side longer than ``_MAX_SIDE`` pixels is clamped to ``_MAX_SIDE``
        independently, so the aspect ratio is not preserved.
        ``line_art_postprocess`` later resizes the model output back to the
        original shape.

        Returns:
            ``(batch, ori_shape)``. ``batch`` has shape ``(1, C, H', W')``;
            ``ori_shape`` is the original ``(height, width)``.
        """
        img = mmcv.imread(image)
        ori_shape = img.shape[:2]
        height, width = ori_shape
        target_h = min(height, LineArtService._MAX_SIDE)
        target_w = min(width, LineArtService._MAX_SIDE)
        if (target_h, target_w) != ori_shape:
            # cv2.resize takes dsize as (width, height), not (height, width).
            img = cv2.resize(img, (target_w, target_h))
        img = mmcv.imnormalize(img, mean=LineArtService._MEAN, std=LineArtService._STD, to_rgb=True)
        # HWC -> CHW, then add the batch axis.
        batch = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return batch, ori_shape

    @staticmethod
    def line_art_postprocess(output, ori_shape):
        """Bilinearly resize the raw model output to *ori_shape*.

        Args:
            output: model output array. Bilinear interpolation requires it to
                be 4-D ``(N, C, h, w)``.
            ori_shape: target ``(height, width)``.

        Returns:
            Numpy array ``(C, height, width)`` for the first batch item.
        """
        # torch.tensor copies, which avoids warnings on read-only buffers.
        logits = torch.tensor(output, dtype=torch.float32)
        resized = F.interpolate(logits, size=tuple(ori_shape), mode='bilinear', align_corners=False)
        return resized.numpy()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: sketch one stored image with line style 4 and
    # upload the result to the "test" bucket, printing the returned path.
    demo_request = Image2SketchModel(
        image_url="aida-collection-element/87/Sketchboard/555a443f-fd6b-4cd7-8147-b92d55513af0.png",
        default_style="4",
        sketch_bucket="test",
        sketch_name="test123",
    )
    sketch_path = LineArtService(demo_request).get_result()
    print(sketch_path)
|