feat image2sketch triton部署
fix
This commit is contained in:
@@ -1,12 +1,10 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
from app.schemas.image2sketch import Image2SketchModel
|
from app.schemas.image2sketch import Image2SketchModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.image2sketch_2.server import processing_pipeline
|
from app.service.lineart.service import LineArtService
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -30,16 +28,9 @@ def image2sketch(request_item: Image2SketchModel):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
service = LineArtService(request_item)
|
||||||
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
result_url = service.get_result()
|
||||||
sketch_url = processing_pipeline(
|
|
||||||
image_url=request_item.image_url,
|
|
||||||
thickness=request_item.default_style,
|
|
||||||
sketch_bucket=request_item.sketch_bucket,
|
|
||||||
sketch_name=request_item.sketch_name
|
|
||||||
)
|
|
||||||
logger.info(f"run time is : {time.time() - start_time}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"image2sketch Run Exception @@@@@@:{e}")
|
logger.warning(f"image2sketch Run Exception @@@@@@:{e}")
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
return ResponseModel(data=sketch_url)
|
return ResponseModel(data=result_url)
|
||||||
|
|||||||
94
app/service/lineart/service.py
Normal file
94
app/service/lineart/service.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import tritonclient.http as httpclient
|
||||||
|
|
||||||
|
from app.core.config import DESIGN_MODEL_URL
|
||||||
|
from app.schemas.image2sketch import Image2SketchModel
|
||||||
|
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class LineArtService:
|
||||||
|
def __init__(self, request_item):
|
||||||
|
self.line_style = int(request_item.default_style)
|
||||||
|
self.image_url = request_item.image_url
|
||||||
|
self.sketch_bucket = request_item.sketch_bucket
|
||||||
|
self.sketch_name = request_item.sketch_name
|
||||||
|
self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||||
|
input_image = self.get_image()
|
||||||
|
input_img, ori_shape = self.line_art_preprocess(input_image)
|
||||||
|
transformed_img = input_img.astype(np.float32)
|
||||||
|
|
||||||
|
inputs = [httpclient.InferInput(f"input__0", transformed_img.shape, datatype="FP32")]
|
||||||
|
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||||
|
outputs = [httpclient.InferRequestedOutput(f"output__0", binary_data=True)]
|
||||||
|
results = client.infer(model_name=f"lineart", inputs=inputs, outputs=outputs)
|
||||||
|
inference_output1 = results.as_numpy("output__0")
|
||||||
|
line_art_result = self.line_art_postprocess(inference_output1, ori_shape)
|
||||||
|
|
||||||
|
line_art_result = (line_art_result[0] * 255.0).round().astype(np.uint8)
|
||||||
|
if self.line_style != 0:
|
||||||
|
logger.info(self.line_style)
|
||||||
|
kernel = np.ones((3, 3), np.uint8)
|
||||||
|
dilated = cv2.erode(line_art_result, kernel, iterations=1)
|
||||||
|
# 将原图与膨胀后的图像进行混合,使用不同的权重
|
||||||
|
line_art_result = cv2.addWeighted(line_art_result, self.weights[self.line_style][0], dilated, self.weights[self.line_style][1], 0)
|
||||||
|
# cv2.imshow("", line_art_result)
|
||||||
|
# cv2.waitKey(0)
|
||||||
|
return self.put_image(line_art_result)
|
||||||
|
|
||||||
|
def get_image(self):
|
||||||
|
image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
|
||||||
|
return image
|
||||||
|
|
||||||
|
def put_image(self, image):
|
||||||
|
try:
|
||||||
|
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||||
|
oss_upload_image(bucket=self.sketch_bucket, object_name=self.sketch_name, image_bytes=image_bytes)
|
||||||
|
return f"{self.sketch_bucket}/{self.sketch_name}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(e)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def line_art_preprocess(image):
|
||||||
|
img = mmcv.imread(image)
|
||||||
|
ori_shape = img.shape[:2]
|
||||||
|
img_scale_w, img_scale_h = ori_shape
|
||||||
|
if ori_shape[0] > 1024:
|
||||||
|
img_scale_w = 1024
|
||||||
|
if ori_shape[1] > 1024:
|
||||||
|
img_scale_h = 1024
|
||||||
|
# 如果图片size任意一边 大于 1024, 则会resize 成1024
|
||||||
|
if ori_shape != (img_scale_w, img_scale_h):
|
||||||
|
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
|
||||||
|
img = cv2.resize(img, (img_scale_h, img_scale_w))
|
||||||
|
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||||
|
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||||
|
return preprocessed_img, ori_shape
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def line_art_postprocess(output, ori_shape):
|
||||||
|
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||||
|
seg_pred = seg_logit.cpu().numpy()
|
||||||
|
return seg_pred[0]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
request_item = Image2SketchModel(
|
||||||
|
image_url="aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png",
|
||||||
|
default_style="4",
|
||||||
|
sketch_bucket="test",
|
||||||
|
sketch_name="test123.jpg"
|
||||||
|
)
|
||||||
|
service = LineArtService(request_item)
|
||||||
|
result_url = service.get_result()
|
||||||
|
print(result_url)
|
||||||
Reference in New Issue
Block a user