Files
AiDA_Python/app/service/lineart/service.py
zcr c03b7e263e
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
feat:
fix:  替换项目中所有mmcv的依赖
2026-02-10 11:17:31 +08:00

102 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import settings
from app.core.config import DESIGN_MODEL_URL
from app.schemas.image2sketch import Image2SketchModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
# No name is passed, so this is the root logger; output depends on whatever
# handlers/levels the hosting application has configured.
logger = logging.getLogger()

# Module-wide MinIO client, built once at import time from application settings
# and handed to ``oss_upload_image`` when results are stored.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class LineArtService:
    """Turn an input image into a line-art sketch using the Triton ``lineart`` model.

    Pipeline: fetch the image via ``oss_get_image`` -> normalise/resize ->
    remote inference -> resize back to the original size -> optional
    style blending -> upload as JPEG via ``oss_upload_image``.
    """

    # Per-axis size cap applied before inference; each axis is capped
    # independently, so the aspect ratio is NOT preserved for oversized images.
    _MAX_SIDE = 1024

    def __init__(self, request_item):
        """Capture the request parameters.

        Args:
            request_item: Object exposing ``default_style`` (int-convertible),
                ``image_url`` ("<bucket>/<object path>"), ``sketch_bucket``
                and ``sketch_name``.
        """
        self.line_style = int(request_item.default_style)
        self.image_url = request_item.image_url
        self.sketch_bucket = request_item.sketch_bucket
        self.sketch_name = request_item.sketch_name
        # (weight of raw sketch, weight of eroded sketch), indexed by line_style.
        # Index 0 is never read because style 0 skips blending entirely.
        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]

    def get_result(self):
        """Run the whole pipeline.

        Returns:
            ``"<sketch_bucket>/<sketch_name>.jpg"`` on success, or ``None`` when
            encoding/uploading the result failed (see ``put_image``).

        Raises:
            IndexError: if ``line_style`` is larger than the last ``weights`` index.
        """
        source = self.get_image()
        model_input, ori_shape = self.line_art_preprocess(source)
        model_input = model_input.astype(np.float32)

        infer_input = httpclient.InferInput("input__0", model_input.shape, datatype="FP32")
        infer_input.set_data_from_numpy(model_input, binary_data=True)
        requested = httpclient.InferRequestedOutput("output__0", binary_data=True)

        # Context manager closes the HTTP connection pool deterministically
        # instead of leaving it to garbage collection.
        with httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) as client:
            response = client.infer(model_name="lineart", inputs=[infer_input], outputs=[requested])
            raw_output = response.as_numpy("output__0")

        # Post-processing drops the batch axis; [0] then picks the first channel.
        sketch = self._to_uint8(self.line_art_postprocess(raw_output, ori_shape)[0])

        if self.line_style != 0:
            logger.info(self.line_style)
            weight_raw, weight_eroded = self.weights[self.line_style]
            kernel = np.ones((3, 3), np.uint8)
            # Erosion spreads dark pixels; presumably this thickens dark strokes
            # on a light background -- confirm against the model's output polarity.
            eroded = cv2.erode(sketch, kernel, iterations=1)
            # Blend the raw sketch with its eroded version using the style weights.
            sketch = cv2.addWeighted(sketch, weight_raw, eroded, weight_eroded, 0)

        return self.put_image(sketch)

    def get_image(self):
        """Fetch the source image and coerce it to a 3-channel BGR array.

        ``image_url`` is split at the first "/" into bucket and object name.
        """
        parts = self.image_url.split('/', 1)
        image = oss_get_image(bucket=parts[0], object_name=parts[-1], data_type="cv2")
        # Normalise to a colour image: drop alpha, or expand grayscale.
        if len(image.shape) == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image

    def put_image(self, image):
        """JPEG-encode ``image`` and upload it to ``sketch_bucket``.

        Best-effort: any failure is logged as a warning and ``None`` is returned.

        Returns:
            ``"<sketch_bucket>/<sketch_name>.jpg"`` on success, otherwise ``None``.
        """
        try:
            encoded_ok, buffer = cv2.imencode('.jpg', image)
            if not encoded_ok:
                raise ValueError("cv2.imencode failed to encode the line-art result as JPEG")
            oss_upload_image(
                oss_client=minio_client,
                bucket=self.sketch_bucket,
                object_name=f"{self.sketch_name}.jpg",
                image_bytes=buffer.tobytes(),
            )
            return f"{self.sketch_bucket}/{self.sketch_name}.jpg"
        except Exception as e:
            logger.warning(e, exc_info=True)
            return None

    @staticmethod
    def _to_uint8(prediction):
        """Scale a float map by 255, round, and cast to uint8.

        Values are clipped to [0, 255] before the cast so that out-of-range
        predictions saturate instead of wrapping around.
        """
        scaled = np.asarray(prediction) * 255.0
        return np.clip(scaled.round(), 0, 255).astype(np.uint8)

    @staticmethod
    def line_art_preprocess(image):
        """Resize (if oversized) and normalise an image for the model.

        Args:
            image: Array whose first two axes are (height, width).

        Returns:
            ``(batch, ori_shape)`` where ``batch`` is the normalised image with
            axes moved to (1, C, H, W) and ``ori_shape`` is the original
            ``(height, width)``.
        """
        ori_shape = image.shape[:2]
        height, width = ori_shape
        target_h = min(height, LineArtService._MAX_SIDE)
        target_w = min(width, LineArtService._MAX_SIDE)
        # Any axis longer than 1024 is squeezed down to 1024.
        if (target_h, target_w) != ori_shape:
            # cv2.resize takes dsize as (width, height) -- the reverse of ndarray.shape.
            image = cv2.resize(image, (target_w, target_h))
        image = my_imnormalize(
            image,
            mean=np.array([123.675, 116.28, 103.53]),
            std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True,
        )
        batch = np.expand_dims(image.transpose(2, 0, 1), axis=0)
        return batch, ori_shape

    @staticmethod
    def line_art_postprocess(output, ori_shape):
        """Bilinearly resize the model output back to ``ori_shape``.

        Args:
            output: 4-D array (batch first) returned by the model.
            ori_shape: Target ``(height, width)``.

        Returns:
            The first batch element as a NumPy array of shape (C, height, width).
        """
        logits = torch.tensor(output).float()
        resized = F.interpolate(logits, size=ori_shape, mode='bilinear', align_corners=False)
        return resized.cpu().numpy()[0]
if __name__ == '__main__':
    # Manual smoke run against real storage and the inference server:
    # builds a sample request, runs the service and prints the result path.
    demo_params = {
        "image_url": "aida-collection-element/87/Sketchboard/555a443f-fd6b-4cd7-8147-b92d55513af0.png",
        "default_style": "4",
        "sketch_bucket": "test",
        "sketch_name": "test123",
    }
    demo_service = LineArtService(Image2SketchModel(**demo_params))
    print(demo_service.get_result())