1
This commit is contained in:
130
trellis/trainers/vae/sparse_structure_vae.py
Normal file
130
trellis/trainers/vae/sparse_structure_vae.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from typing import *
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
from ..basic import BasicTrainer
|
||||
|
||||
|
||||
class SparseStructureVaeTrainer(BasicTrainer):
|
||||
"""
|
||||
Trainer for Sparse Structure VAE.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
|
||||
lambda_kl (float): KL divergence loss weight.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
loss_type='bce',
|
||||
lambda_kl=1e-6,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss_type = loss_type
|
||||
self.lambda_kl = lambda_kl
|
||||
|
||||
def training_losses(
|
||||
self,
|
||||
ss: torch.Tensor,
|
||||
**kwargs
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Compute training losses.
|
||||
|
||||
Args:
|
||||
ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
|
||||
|
||||
Returns:
|
||||
a dict with the key "loss" containing a scalar tensor.
|
||||
may also contain other keys for different terms.
|
||||
"""
|
||||
z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
|
||||
logits = self.training_models['decoder'](z)
|
||||
|
||||
terms = edict(loss = 0.0)
|
||||
if self.loss_type == 'bce':
|
||||
terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
|
||||
terms["loss"] = terms["loss"] + terms["bce"]
|
||||
elif self.loss_type == 'l1':
|
||||
terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
|
||||
terms["loss"] = terms["loss"] + terms["l1"]
|
||||
elif self.loss_type == 'dice':
|
||||
logits = F.sigmoid(logits)
|
||||
terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
|
||||
terms["loss"] = terms["loss"] + terms["dice"]
|
||||
else:
|
||||
raise ValueError(f'Invalid loss type {self.loss_type}')
|
||||
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
|
||||
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
|
||||
|
||||
return terms, {}
|
||||
|
||||
@torch.no_grad()
|
||||
def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
|
||||
super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_snapshot(
|
||||
self,
|
||||
num_samples: int,
|
||||
batch_size: int,
|
||||
verbose: bool = False,
|
||||
) -> Dict:
|
||||
dataloader = DataLoader(
|
||||
copy.deepcopy(self.dataset),
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||
)
|
||||
|
||||
# inference
|
||||
gts = []
|
||||
recons = []
|
||||
for i in range(0, num_samples, batch_size):
|
||||
batch = min(batch_size, num_samples - i)
|
||||
data = next(iter(dataloader))
|
||||
args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
|
||||
z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
|
||||
logits = self.models['decoder'](z)
|
||||
recon = (logits > 0).long()
|
||||
gts.append(args['ss'])
|
||||
recons.append(recon)
|
||||
|
||||
sample_dict = {
|
||||
'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
|
||||
'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
|
||||
}
|
||||
return sample_dict
|
||||
Reference in New Issue
Block a user