1
This commit is contained in:
286
trellis/trainers/flow_matching/sparse_flow_matching.py
Normal file
286
trellis/trainers/flow_matching/sparse_flow_matching.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from typing import *
|
||||
import os
|
||||
import copy
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
from ...modules import sparse as sp
|
||||
from ...utils.general_utils import dict_reduce
|
||||
from ...utils.data_utils import cycle, BalancedResumableSampler
|
||||
from .flow_matching import FlowMatchingTrainer
|
||||
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
|
||||
from .mixins.text_conditioned import TextConditionedMixin
|
||||
from .mixins.image_conditioned import ImageConditionedMixin
|
||||
|
||||
|
||||
class SparseFlowMatchingTrainer(FlowMatchingTrainer):
|
||||
"""
|
||||
Trainer for sparse diffusion model with flow matching objective.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
"""
|
||||
|
||||
def prepare_dataloader(self, **kwargs):
|
||||
"""
|
||||
Prepare dataloader.
|
||||
"""
|
||||
self.data_sampler = BalancedResumableSampler(
|
||||
self.dataset,
|
||||
shuffle=True,
|
||||
batch_size=self.batch_size_per_gpu,
|
||||
)
|
||||
self.dataloader = DataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.batch_size_per_gpu,
|
||||
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
persistent_workers=True,
|
||||
collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
|
||||
sampler=self.data_sampler,
|
||||
)
|
||||
self.data_iterator = cycle(self.dataloader)
|
||||
|
||||
def training_losses(
|
||||
self,
|
||||
x_0: sp.SparseTensor,
|
||||
cond=None,
|
||||
**kwargs
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
|
||||
Args:
|
||||
x_0: The [N x ... x C] sparse tensor of the inputs.
|
||||
cond: The [N x ...] tensor of additional conditions.
|
||||
kwargs: Additional arguments to pass to the backbone.
|
||||
|
||||
Returns:
|
||||
a dict with the key "loss" containing a tensor of shape [N].
|
||||
may also contain other keys for different terms.
|
||||
"""
|
||||
noise = x_0.replace(torch.randn_like(x_0.feats))
|
||||
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
|
||||
x_t = self.diffuse(x_0, t, noise=noise)
|
||||
cond = self.get_cond(cond, **kwargs)
|
||||
|
||||
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
|
||||
assert pred.shape == noise.shape == x_0.shape
|
||||
target = self.get_v(x_0, noise, t)
|
||||
terms = edict()
|
||||
terms["mse"] = F.mse_loss(pred.feats, target.feats)
|
||||
terms["loss"] = terms["mse"]
|
||||
|
||||
# log loss with time bins
|
||||
mse_per_instance = np.array([
|
||||
F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
|
||||
for i in range(x_0.shape[0])
|
||||
])
|
||||
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
|
||||
for i in range(10):
|
||||
if (time_bin == i).sum() != 0:
|
||||
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
|
||||
|
||||
return terms, {}
|
||||
|
||||
@torch.no_grad()
|
||||
def run_snapshot(
|
||||
self,
|
||||
num_samples: int,
|
||||
batch_size: int,
|
||||
verbose: bool = False,
|
||||
) -> Dict:
|
||||
dataloader = DataLoader(
|
||||
copy.deepcopy(self.dataset),
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||
)
|
||||
|
||||
# inference
|
||||
sampler = self.get_sampler()
|
||||
sample_gt = []
|
||||
sample = []
|
||||
cond_vis = []
|
||||
for i in range(0, num_samples, batch_size):
|
||||
batch = min(batch_size, num_samples - i)
|
||||
data = next(iter(dataloader))
|
||||
data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()}
|
||||
noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats))
|
||||
sample_gt.append(data['x_0'])
|
||||
cond_vis.append(self.vis_cond(**data))
|
||||
del data['x_0']
|
||||
args = self.get_inference_cond(**data)
|
||||
res = sampler.sample(
|
||||
self.models['denoiser'],
|
||||
noise=noise,
|
||||
**args,
|
||||
steps=50, cfg_strength=3.0, verbose=verbose,
|
||||
)
|
||||
sample.append(res.samples)
|
||||
|
||||
sample_gt = sp.sparse_cat(sample_gt)
|
||||
sample = sp.sparse_cat(sample)
|
||||
sample_dict = {
|
||||
'sample_gt': {'value': sample_gt, 'type': 'sample'},
|
||||
'sample': {'value': sample, 'type': 'sample'},
|
||||
}
|
||||
sample_dict.update(dict_reduce(cond_vis, None, {
|
||||
'value': lambda x: torch.cat(x, dim=0),
|
||||
'type': lambda x: x[0],
|
||||
}))
|
||||
|
||||
return sample_dict
|
||||
|
||||
|
||||
class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
|
||||
"""
|
||||
Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
p_uncond (float): Probability of dropping conditions.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
|
||||
"""
|
||||
Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
p_uncond (float): Probability of dropping conditions.
|
||||
text_cond_model(str): Text conditioning model.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
|
||||
"""
|
||||
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
p_uncond (float): Probability of dropping conditions.
|
||||
image_cond_model (str): Image conditioning model.
|
||||
"""
|
||||
pass
|
||||
Reference in New Issue
Block a user