feat sketch 提取接口
fix
This commit is contained in:
68
app/service/image2sketch/models/triplet_model.py
Normal file
68
app/service/image2sketch/models/triplet_model.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from util.image_pool import ImagePool
|
||||
|
||||
|
||||
class TripletModel(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet')
|
||||
if is_train:
|
||||
parser.set_defaults(pool_size=0, gan_mode='vanilla')
|
||||
parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
self.loss_names = ['G_triplet']
|
||||
self.visual_names = ['x','y']
|
||||
|
||||
if self.isTrain:
|
||||
self.model_names = ['G']
|
||||
else:
|
||||
self.model_names = ['G']
|
||||
self.netG = networks.define_G(1, 1, opt.ngf, opt.netG, opt.norm,
|
||||
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
|
||||
self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
|
||||
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||
self.criterionL1 = torch.nn.L1Loss()
|
||||
|
||||
self.triplet = torch.nn.TripletMarginLoss(margin=3.0)
|
||||
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
|
||||
def set_input(self, input):
|
||||
AtoB = self.opt.direction == 'AtoB'
|
||||
self.real_A = input['A' if AtoB else 'B'].to(self.device)
|
||||
self.real_B = input['B' if AtoB else 'A'].to(self.device)
|
||||
self.real_C = input['C'].to(self.device)
|
||||
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
|
||||
|
||||
def forward(self):
|
||||
self.x,self.y,self.z = self.netG(self.real_A,self.real_B,self.real_C)
|
||||
|
||||
|
||||
def backward_G(self):
|
||||
self.loss_G_triplet_1 = self.triplet(self.x,self.y,self.z)
|
||||
self.loss_G_triplet = self.loss_G_triplet_1
|
||||
|
||||
self.loss_G = self.loss_G_triplet
|
||||
self.loss_G.backward()
|
||||
|
||||
def optimize_parameters(self):
|
||||
self.optimizer_G.zero_grad()
|
||||
self.backward_G()
|
||||
self.optimizer_G.step()
|
||||
Reference in New Issue
Block a user