Files
AiDA_Python/app/service/image2sketch/server.py
zhouchengrong d11beb0107 feat sketch 提取接口
fix
2024-08-14 16:45:34 +08:00

80 lines
3.6 KiB
Python

import logging
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from app.schemas.image2sketch import Image2SketchModel
from app.service.image2sketch.infer import tensor2im
from app.service.image2sketch.models import create_model
from app.service.image2sketch.opt import Config
from app.service.utils.oss_client import oss_get_image, oss_upload_image
# Module-scoped logger (rather than the root logger) so records are attributed
# to this module and can be filtered/configured independently.
logger = logging.getLogger(__name__)
def tensor2im(input_image, imtype=np.uint8):
    """Convert a model output tensor into an image array.

    Args:
        input_image: a ``torch.Tensor``, a numpy array, or anything else.
            For tensors only the first item along dim 0 is used, and the
            remaining dims are treated as (C, H, W) — i.e. the input is
            assumed to be (N, C, H, W) with values in [-1, 1] (TODO confirm
            against the model's output range).
        imtype: dtype of the returned array.

    Returns:
        * tensor input: an (H, W, C) array of ``imtype``, with values mapped
          from [-1, 1] to [0, 255]; a single-channel input is repeated to 3
          channels. For integer ``imtype`` values are clipped to that type's
          range before the cast, so out-of-range outputs saturate instead of
          wrapping around.
        * numpy array input: the array cast to ``imtype`` (no rescaling).
        * any other input: returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        # Already an array: only the dtype is adjusted.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    image_numpy = input_image.data[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:  # grayscale -> RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # Post-processing: (C, H, W) -> (H, W, C) and [-1, 1] -> [0, 255].
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0

    target = np.dtype(imtype)
    if np.issubdtype(target, np.integer):
        # Without clipping, e.g. 256.0 -> uint8 overflows instead of saturating.
        info = np.iinfo(target)
        image_numpy = np.clip(image_numpy, info.min, info.max)
    return image_numpy.astype(imtype)
class Image2SketchServer:
    """Turn one OSS-hosted image into a sketch and upload the result to OSS.

    Construction builds and sets up the model, loads the reference ("style")
    image as input ``'B'`` and downloads the content image as input ``'A'``.
    ``get_result`` runs inference and uploads the JPEG-encoded output.
    """

    # Reference image fed to the model as input 'B'. This is a
    # machine-specific path kept as the default for backward compatibility;
    # pass ``style_image_path`` to the constructor to override it.
    DEFAULT_STYLE_IMAGE_PATH = (
        r"E:\workspace\trinity_client_aida\app\service\image2sketch"
        r"\datasets\ref_unpair\testC\20180422151845_stEe4.jpeg"
    )

    def __init__(self, request_data, *, style_image_path=None):
        """Prepare the model and both input tensors for a single request.

        Args:
            request_data: request object exposing ``image_url`` (of the form
                ``"bucket/object"``), ``sketch_bucket`` and ``sketch_name``.
            style_image_path: optional path of the reference image; defaults
                to ``DEFAULT_STYLE_IMAGE_PATH``.

        Raises:
            ValueError: if ``request_data.image_url`` is not ``"bucket/object"``.
        """
        self.image_url = request_data.image_url
        self.sketch_bucket = request_data.sketch_bucket
        self.sketch_name = request_data.sketch_name
        self.opt = self._build_options()
        self.data = {}
        # NOTE(review): inputs are pinned to the first GPU; assumes the model
        # built by create_model also runs on cuda:0 — confirm against opt.
        device = torch.device("cuda:0")
        self.model = create_model(self.opt)
        self.model.setup(self.opt)

        # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )
        if style_image_path is None:
            style_image_path = self.DEFAULT_STYLE_IMAGE_PATH
        with Image.open(style_image_path) as style_file:
            style_img = style_file.convert('L')
        self.data['B'] = transform(style_img).unsqueeze(0).to(device)

        content_img, self.width, self.height = self.get_image(self.image_url)
        self.data['A'] = transform(content_img).unsqueeze(0).to(device)

    @staticmethod
    def _build_options():
        """Return a Config set up for single-image test-time inference."""
        opt = Config()
        opt.num_threads = 0  # test code only supports num_threads = 0
        opt.batch_size = 1  # test code only supports batch_size = 1
        opt.serial_batches = True  # disable data shuffling
        opt.no_flip = True  # no flipping of inputs
        opt.display_id = -1  # no visdom display
        return opt

    def get_result(self):
        """Run inference, JPEG-encode the output and upload it to OSS.

        Returns:
            ``"<bucket>/<object>"`` of the uploaded sketch, as reported by
            ``oss_upload_image``.

        Raises:
            RuntimeError: if OpenCV fails to encode the image.
        """
        self.model.set_input(self.data)
        self.model.test()  # run inference
        visuals = self.model.get_current_visuals()
        image_numpy = tensor2im(visuals['content_output'].cpu())
        # tensor2im yields RGB, while cv2.imencode interprets 3-channel input
        # as BGR; swap so colour outputs are stored correctly (a no-op for
        # grayscale output, whose channels are identical).
        if image_numpy.ndim == 3 and image_numpy.shape[2] == 3:
            image_numpy = cv2.cvtColor(image_numpy, cv2.COLOR_RGB2BGR)
        ok, encoded = cv2.imencode(".jpg", image_numpy)
        if not ok:
            raise RuntimeError("failed to JPEG-encode the sketch image")
        req = oss_upload_image(
            bucket=self.sketch_bucket,
            object_name=self.sketch_name,
            image_bytes=encoded.tobytes(),
        )
        return f"{req.bucket_name}/{req.object_name}"

    def get_image(self, image_url):
        """Download ``"bucket/object"`` from OSS as an RGB PIL image.

        Returns:
            ``(image, width, height)``.

        Raises:
            ValueError: if ``image_url`` is not of the form ``"bucket/object"``.
        """
        bucket, object_name = self._split_image_url(image_url)
        image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
        image = image.convert('RGB')
        width, height = image.size
        return image, width, height

    @staticmethod
    def _split_image_url(image_url):
        """Split ``"bucket/object/key"`` at the first ``/`` into (bucket, object)."""
        bucket, sep, object_name = image_url.partition('/')
        if not sep or not bucket or not object_name:
            raise ValueError(
                f"image_url must look like 'bucket/object', got {image_url!r}"
            )
        return bucket, object_name
if __name__ == "__main__":
    # Manual smoke test: sketch one image from the "test" bucket and print
    # where the result was uploaded.
    request = Image2SketchModel(
        image_url="test/real_Dress_790b2c6e370644e134df7abdfe7e54d9.jpg_Img.jpg",
        sketch_bucket="test",
        sketch_name="test123.jpg",
    )
    print(Image2SketchServer(request).get_result())