import logging

import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient

from app.core.config import DESIGN_MODEL_URL
from app.schemas.image2sketch import Image2SketchModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image

# Named module logger (records still propagate to root handlers configured by the app).
logger = logging.getLogger(__name__)


class LineArtService:
    """Convert an OSS-hosted image into a line-art sketch via a Triton-served model.

    Pipeline: fetch the source image from OSS, preprocess (resize/normalize),
    run the ``lineart`` model on the Triton inference server, optionally
    thicken the strokes according to ``default_style``, then upload the JPEG
    sketch back to OSS.
    """

    def __init__(self, request_item):
        """Initialize from a request object (e.g. ``Image2SketchModel``).

        Args:
            request_item: carries ``default_style`` (int-like string; valid
                range 0-4, where 0 means no stroke thickening), ``image_url``
                ("bucket/object" path of the source image), ``sketch_bucket``
                and ``sketch_name`` (destination of the generated sketch).
        """
        self.line_style = int(request_item.default_style)
        self.image_url = request_item.image_url
        self.sketch_bucket = request_item.sketch_bucket
        self.sketch_name = request_item.sketch_name
        # (raw, eroded) blend weights indexed by line_style; a larger second
        # weight gives thicker strokes. Index 0 exists but is never used
        # because get_result() skips blending when line_style == 0.
        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]

    def get_result(self):
        """Run the full pipeline and return the OSS path of the uploaded sketch.

        Returns:
            str | None: ``"bucket/name.jpg"`` on success, ``None`` if the
            upload failed (see :meth:`put_image`).
        """
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        input_image = self.get_image()
        input_img, ori_shape = self.line_art_preprocess(input_image)
        transformed_img = input_img.astype(np.float32)

        # Tensor names are fixed strings — no f-string placeholders needed.
        inputs = [httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
        results = client.infer(model_name="lineart", inputs=inputs, outputs=outputs)
        inference_output1 = results.as_numpy("output__0")

        line_art_result = self.line_art_postprocess(inference_output1, ori_shape)
        line_art_result = (line_art_result[0] * 255.0).round().astype(np.uint8)

        if self.line_style != 0:
            logger.info(self.line_style)
            kernel = np.ones((3, 3), np.uint8)
            # Erosion shrinks bright regions, which thickens the dark strokes
            # of the sketch. (The original code misleadingly named this
            # variable "dilated".)
            eroded = cv2.erode(line_art_result, kernel, iterations=1)
            # Blend the raw sketch with the thickened one; the per-style
            # weights control how pronounced the thickening is.
            line_art_result = cv2.addWeighted(
                line_art_result, self.weights[self.line_style][0],
                eroded, self.weights[self.line_style][1], 0,
            )
        return self.put_image(line_art_result)

    def get_image(self):
        """Fetch the source image from OSS and normalize it to 3-channel BGR.

        ``image_url`` is split at its first '/' into bucket and object name.
        """
        image = oss_get_image(
            bucket=self.image_url.split('/')[0],
            object_name=self.image_url[self.image_url.find('/') + 1:],
            data_type="cv2",
        )
        # Normalize to plain BGR: drop an alpha channel or expand grayscale.
        if len(image.shape) == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image

    def put_image(self, image):
        """Encode ``image`` as JPEG and upload it to OSS (best-effort).

        Returns:
            str | None: the ``"bucket/name.jpg"`` path on success, ``None``
            when encoding or upload failed (the error is logged, not raised).
        """
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                # Original code ignored the success flag and crashed into the
                # except below via AttributeError; fail explicitly instead.
                raise ValueError("cv2.imencode failed to encode image as .jpg")
            oss_upload_image(bucket=self.sketch_bucket,
                             object_name=f"{self.sketch_name}.jpg",
                             image_bytes=encoded.tobytes())
            return f"{self.sketch_bucket}/{self.sketch_name}.jpg"
        except Exception as e:
            # Best-effort upload: log with traceback and signal failure.
            logger.warning(e, exc_info=True)
            return None

    @staticmethod
    def line_art_preprocess(image):
        """Resize (cap each side at 1024 px), normalize and batch the image.

        Args:
            image: BGR ``np.ndarray`` (or anything ``mmcv.imread`` accepts).

        Returns:
            tuple: ``(preprocessed, ori_shape)`` — ``preprocessed`` is a
            float NCHW array of shape (1, 3, H', W'); ``ori_shape`` is the
            original (height, width), used to restore size in postprocess.
        """
        img = mmcv.imread(image)
        ori_shape = img.shape[:2]  # (height, width)
        # Cap each side at 1024 px. The original code bound these with h/w
        # names swapped (a source of an earlier real resize bug noted in the
        # old comments); the names below match what they hold.
        target_h = min(ori_shape[0], 1024)
        target_w = min(ori_shape[1], 1024)
        if ori_shape != (target_h, target_w):
            # cv2.resize expects (width, height).
            img = cv2.resize(img, (target_w, target_h))
        img = mmcv.imnormalize(img,
                               mean=np.array([123.675, 116.28, 103.53]),
                               std=np.array([58.395, 57.12, 57.375]),
                               to_rgb=True)
        # HWC -> CHW, then add the batch dimension.
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img, ori_shape

    @staticmethod
    def line_art_postprocess(output, ori_shape):
        """Bilinearly upsample the model output back to the original size.

        Args:
            output: raw model output as an NCHW numpy array.
            ori_shape: target (height, width) tuple.

        Returns:
            np.ndarray: the first batch element, shape (C, H, W).
        """
        seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape,
                                  mode='bilinear', align_corners=False)
        seg_pred = seg_logit.cpu().numpy()
        return seg_pred[0]


if __name__ == '__main__':
    # Manual smoke test; requires reachable OSS and Triton services.
    request_item = Image2SketchModel(
        image_url="aida-collection-element/87/Sketchboard/555a443f-fd6b-4cd7-8147-b92d55513af0.png",
        default_style="4",
        sketch_bucket="test",
        sketch_name="test123",
    )
    service = LineArtService(request_item)
    result_url = service.get_result()
    print(result_url)