This commit is contained in:
zcr
2026-03-17 11:29:47 +08:00
parent a6d9bac6d0
commit 06659057c3
26 changed files with 3742 additions and 0 deletions

View File

@@ -0,0 +1,68 @@
from typing import *
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import torch
from transformers import AutoTokenizer, CLIPTextModel
from ....utils import dist_utils
class TextConditionedMixin:
"""
Mixin for text-conditioned models.
Args:
text_cond_model: The text conditioning model.
"""
def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
super().__init__(*args, **kwargs)
self.text_cond_model_name = text_cond_model
self.text_cond_model = None # the model is init lazily
def _init_text_cond_model(self):
"""
Initialize the text conditioning model.
"""
# load model
with dist_utils.local_master_first():
model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
model.eval()
model = model.cuda()
self.text_cond_model = {
'model': model,
'tokenizer': tokenizer,
}
self.text_cond_model['null_cond'] = self.encode_text([''])
@torch.no_grad()
def encode_text(self, text: List[str]) -> torch.Tensor:
"""
Encode the text.
"""
assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
if self.text_cond_model is None:
self._init_text_cond_model()
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
tokens = encoding['input_ids'].cuda()
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
return embeddings
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
cond = self.encode_text(cond)
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
cond = super().get_cond(cond, **kwargs)
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
cond = self.encode_text(cond)
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
cond = super().get_inference_cond(cond, **kwargs)
return cond

View File

@@ -0,0 +1,286 @@
from typing import *
import os
import copy
import functools
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
from ...modules import sparse as sp
from ...utils.general_utils import dict_reduce
from ...utils.data_utils import cycle, BalancedResumableSampler
from .flow_matching import FlowMatchingTrainer
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
from .mixins.text_conditioned import TextConditionedMixin
from .mixins.image_conditioned import ImageConditionedMixin
class SparseFlowMatchingTrainer(FlowMatchingTrainer):
"""
Trainer for sparse diffusion model with flow matching objective.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
"""
def prepare_dataloader(self, **kwargs):
"""
Prepare dataloader.
"""
self.data_sampler = BalancedResumableSampler(
self.dataset,
shuffle=True,
batch_size=self.batch_size_per_gpu,
)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size_per_gpu,
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
pin_memory=True,
drop_last=True,
persistent_workers=True,
collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
sampler=self.data_sampler,
)
self.data_iterator = cycle(self.dataloader)
def training_losses(
self,
x_0: sp.SparseTensor,
cond=None,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses for a single timestep.
Args:
x_0: The [N x ... x C] sparse tensor of the inputs.
cond: The [N x ...] tensor of additional conditions.
kwargs: Additional arguments to pass to the backbone.
Returns:
a dict with the key "loss" containing a tensor of shape [N].
may also contain other keys for different terms.
"""
noise = x_0.replace(torch.randn_like(x_0.feats))
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
x_t = self.diffuse(x_0, t, noise=noise)
cond = self.get_cond(cond, **kwargs)
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
assert pred.shape == noise.shape == x_0.shape
target = self.get_v(x_0, noise, t)
terms = edict()
terms["mse"] = F.mse_loss(pred.feats, target.feats)
terms["loss"] = terms["mse"]
# log loss with time bins
mse_per_instance = np.array([
F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
for i in range(x_0.shape[0])
])
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
for i in range(10):
if (time_bin == i).sum() != 0:
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
sampler = self.get_sampler()
sample_gt = []
sample = []
cond_vis = []
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()}
noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats))
sample_gt.append(data['x_0'])
cond_vis.append(self.vis_cond(**data))
del data['x_0']
args = self.get_inference_cond(**data)
res = sampler.sample(
self.models['denoiser'],
noise=noise,
**args,
steps=50, cfg_strength=3.0, verbose=verbose,
)
sample.append(res.samples)
sample_gt = sp.sparse_cat(sample_gt)
sample = sp.sparse_cat(sample)
sample_dict = {
'sample_gt': {'value': sample_gt, 'type': 'sample'},
'sample': {'value': sample, 'type': 'sample'},
}
sample_dict.update(dict_reduce(cond_vis, None, {
'value': lambda x: torch.cat(x, dim=0),
'type': lambda x: x[0],
}))
return sample_dict
class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
"""
Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
"""
pass
class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
"""
Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
text_cond_model(str): Text conditioning model.
"""
pass
class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
"""
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass

77
trellis/trainers/utils.py Normal file
View File

@@ -0,0 +1,77 @@
import torch.nn as nn
# FP16 utils
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def make_master_params(model_params):
"""
Copy model parameters into a inflated tensor of full-precision parameters.
"""
master_params = _flatten_dense_tensors(
[param.detach().float() for param in model_params]
)
master_params = nn.Parameter(master_params)
master_params.requires_grad = True
return [master_params]
def unflatten_master_params(model_params, master_params):
"""
Unflatten the master parameters to look like model_params.
"""
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
def model_params_to_master_params(model_params, master_params):
"""
Copy the model parameter data into the master parameters.
"""
master_params[0].detach().copy_(
_flatten_dense_tensors([param.detach().float() for param in model_params])
)
def master_params_to_model_params(model_params, master_params):
"""
Copy the master parameter data back into the model parameters.
"""
for param, master_param in zip(
model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
):
param.detach().copy_(master_param)
def model_grads_to_master_grads(model_params, master_params):
"""
Copy the gradients from the model parameters into the master parameters
from make_master_params().
"""
master_params[0].grad = _flatten_dense_tensors(
[param.grad.data.detach().float() for param in model_params]
)
def zero_grad(model_params):
for param in model_params:
if param.grad is not None:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
# LR Schedulers
from torch.optim.lr_scheduler import LambdaLR
class LinearWarmupLRScheduler(LambdaLR):
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
self.warmup_steps = warmup_steps
super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, current_step):
if current_step < self.warmup_steps:
return float(current_step + 1) / self.warmup_steps
return 1.0