diff --git a/app/api/api_image2sketch.py b/app/api/api_image2sketch.py index d562bee..f630194 100644 --- a/app/api/api_image2sketch.py +++ b/app/api/api_image2sketch.py @@ -1,12 +1,10 @@ -import json import logging -import time from fastapi import APIRouter, HTTPException from app.schemas.image2sketch import Image2SketchModel from app.schemas.response_template import ResponseModel -from app.service.image2sketch_2.server import processing_pipeline +from app.service.lineart.service import LineArtService router = APIRouter() logger = logging.getLogger() @@ -30,16 +28,9 @@ def image2sketch(request_item: Image2SketchModel): } """ try: - start_time = time.time() - logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}") - sketch_url = processing_pipeline( - image_url=request_item.image_url, - thickness=request_item.default_style, - sketch_bucket=request_item.sketch_bucket, - sketch_name=request_item.sketch_name - ) - logger.info(f"run time is : {time.time() - start_time}") + service = LineArtService(request_item) + result_url = service.get_result() except Exception as e: logger.warning(f"image2sketch Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) - return ResponseModel(data=sketch_url) + return ResponseModel(data=result_url) diff --git a/app/service/lineart/service.py b/app/service/lineart/service.py new file mode 100644 index 0000000..e8fc78f --- /dev/null +++ b/app/service/lineart/service.py @@ -0,0 +1,94 @@ +import logging + +import cv2 +import mmcv +import numpy as np +import torch +import torch.nn.functional as F +import tritonclient.http as httpclient + +from app.core.config import DESIGN_MODEL_URL +from app.schemas.image2sketch import Image2SketchModel +from app.service.utils.oss_client import oss_get_image, oss_upload_image + +logger = logging.getLogger() + + +class LineArtService: + def __init__(self, request_item): + self.line_style = int(request_item.default_style) + self.image_url = request_item.image_url + self.sketch_bucket = request_item.sketch_bucket + self.sketch_name = request_item.sketch_name + self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)] + + def get_result(self): + client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) + input_image = self.get_image() + input_img, ori_shape = self.line_art_preprocess(input_image) + transformed_img = input_img.astype(np.float32) + + inputs = [httpclient.InferInput(f"input__0", transformed_img.shape, datatype="FP32")] + inputs[0].set_data_from_numpy(transformed_img, binary_data=True) + outputs = [httpclient.InferRequestedOutput(f"output__0", binary_data=True)] + results = client.infer(model_name=f"lineart", inputs=inputs, outputs=outputs) + inference_output1 = results.as_numpy("output__0") + line_art_result = self.line_art_postprocess(inference_output1, ori_shape) + + line_art_result = (line_art_result[0] * 255.0).round().astype(np.uint8) + if self.line_style != 0: + logger.info(self.line_style) + kernel = np.ones((3, 3), np.uint8) + dilated = cv2.erode(line_art_result, kernel, iterations=1) + # 将原图与膨胀后的图像进行混合,使用不同的权重 + line_art_result = cv2.addWeighted(line_art_result, self.weights[self.line_style][0], dilated, self.weights[self.line_style][1], 0) + # cv2.imshow("", line_art_result) + # cv2.waitKey(0) + return self.put_image(line_art_result) + + def get_image(self): + image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") + return image + + def put_image(self, image): + try: + image_bytes = cv2.imencode('.jpg', image)[1].tobytes() + oss_upload_image(bucket=self.sketch_bucket, object_name=self.sketch_name, image_bytes=image_bytes) + return f"{self.sketch_bucket}/{self.sketch_name}" + except Exception as e: + logger.warning(e) + + @staticmethod + def line_art_preprocess(image): + img = mmcv.imread(image) + ori_shape = img.shape[:2] + img_scale_w, img_scale_h = ori_shape + if ori_shape[0] > 1024: + img_scale_w = 1024 + if ori_shape[1] > 1024: + img_scale_h = 1024 + # 如果图片size任意一边 大于 1024, 则会resize 成1024 + if ori_shape != (img_scale_w, img_scale_h): + # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 + img = cv2.resize(img, (img_scale_h, img_scale_w)) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) + return preprocessed_img, ori_shape + + @staticmethod + def line_art_postprocess(output, ori_shape): + seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) + seg_pred = seg_logit.cpu().numpy() + return seg_pred[0] + + +if __name__ == '__main__': + request_item = Image2SketchModel( + image_url="aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png", + default_style="4", + sketch_bucket="test", + sketch_name="test123.jpg" + ) + service = LineArtService(request_item) + result_url = service.get_result() + print(result_url)