69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
import torch
|
|
from .base_model import BaseModel
|
|
from . import networks
|
|
from util.image_pool import ImagePool
|
|
|
|
|
|
class TripletModel(BaseModel):
    """Train a generator network with a triplet margin loss.

    ``netG`` is called with three inputs (``real_A``, ``real_B``, ``real_C``)
    and its three outputs are passed to ``torch.nn.TripletMarginLoss`` in the
    order (anchor, positive, negative). The triplet loss is the only term
    that is back-propagated.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add model-specific options and override framework defaults.

        Parameters:
            parser: option parser; ``set_defaults`` / ``add_argument`` are
                called on it in place.
            is_train (bool): whether training-phase options are added.

        Returns:
            The same ``parser`` object.
        """
        parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            # Kept for command-line compatibility; backward_G does not use it.
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
            parser.add_argument('--triplet_margin', type=float, default=3.0,
                                help='margin of the triplet margin loss')
        return parser

    def __init__(self, opt):
        """Build the generator and, in training mode, losses and optimizer.

        Parameters:
            opt: option namespace. Always reads ngf, netG, norm, no_dropout,
                init_type and init_gain; in training mode also pool_size,
                gan_mode, lr, beta1 and (optionally) triplet_margin.
        """
        BaseModel.__init__(self, opt)

        # NOTE(review): these name lists are consumed by BaseModel helpers
        # (presumably via 'loss_' + name / attribute lookup) -- confirm there.
        self.loss_names = ['G_triplet']
        self.visual_names = ['x', 'y']
        # Only the generator exists, in both training and test mode.
        self.model_names = ['G']

        # NOTE(review): input/output channel counts are fixed at 1, presumably
        # for single-channel inputs -- confirm against networks.define_G.
        self.netG = networks.define_G(1, 1, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain,
                                      self.gpu_ids)

        if self.isTrain:
            # The image pools and the GAN / L1 criteria are created but never
            # used by backward_G; they are kept so existing attributes remain.
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images

            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # getattr keeps option namespaces built without --triplet_margin
            # working with the original hard-coded margin of 3.0.
            margin = getattr(opt, 'triplet_margin', 3.0)
            self.triplet = torch.nn.TripletMarginLoss(margin=margin)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Unpack a batch dict and move its tensors to ``self.device``.

        ``opt.direction == 'AtoB'`` maps key 'A' to ``real_A`` and 'B' to
        ``real_B``; any other direction swaps them. Key 'C' always becomes
        ``real_C``. ``image_paths`` is taken from the paths entry of whichever
        key became ``real_A``.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.real_C = input['C'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run ``netG`` on the three inputs and store its outputs as x, y, z."""
        self.x, self.y, self.z = self.netG(self.real_A, self.real_B, self.real_C)

    def backward_G(self):
        """Compute the triplet loss on (x, y, z) and back-propagate it."""
        # TripletMarginLoss argument order is (anchor, positive, negative).
        self.loss_G_triplet_1 = self.triplet(self.x, self.y, self.z)
        self.loss_G_triplet = self.loss_G_triplet_1
        self.loss_G = self.loss_G_triplet
        self.loss_G.backward()

    def optimize_parameters(self):
        """Perform one generator update: zero grads, backward, optimizer step.

        Uses the outputs stored by the most recent ``forward`` call; this
        method does not call ``forward`` itself.
        """
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()