60 lines
2.3 KiB
Python
60 lines
2.3 KiB
Python
|
|
import torch
|
||
|
|
import numpy as np
|
||
|
|
from ....utils.general_utils import dict_foreach
|
||
|
|
from ....pipelines import samplers
|
||
|
|
|
||
|
|
|
||
|
|
class ClassifierFreeGuidanceMixin:
    """
    Mixin adding classifier-free guidance (CFG) support to a diffusion trainer.

    During training, each sample's conditioning is replaced by the negative
    (unconditional) conditioning with probability ``p_uncond`` so the model
    also learns the unconditional branch; at inference both branches are
    passed along for CFG sampling.
    """

    def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)
        # per-sample probability of dropping the conditioning during training
        self.p_uncond = p_uncond

    def get_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"

        # Dropout disabled: conditioning passes through untouched.
        if self.p_uncond <= 0:
            return cond

        # randomly drop the class label
        def _num_items(cond):
            # Batch size of a single conditioning entry (tensor dim 0 or list length).
            if isinstance(cond, torch.Tensor):
                return cond.shape[0]
            if isinstance(cond, list):
                return len(cond)
            raise ValueError(f"Unsupported type of cond: {type(cond)}")

        def _mix(cond, neg_cond, mask):
            # Per-element pick: mask[i] True -> neg_cond[i], else cond[i].
            if isinstance(cond, torch.Tensor):
                # Reshape the boolean mask to (B, 1, ..., 1) so it broadcasts
                # over the non-batch dimensions of cond.
                mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
                return torch.where(mask, neg_cond, cond)
            if isinstance(cond, list):
                return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
            raise ValueError(f"Unsupported type of cond: {type(cond)}")

        # For a dict-valued cond, any single entry determines the batch size.
        ref_cond = next(iter(cond.values())) if isinstance(cond, dict) else cond
        B = _num_items(ref_cond)

        # One shared drop decision per batch element, applied to every entry.
        mask = list(np.random.rand(B) < self.p_uncond)

        if isinstance(cond, dict):
            # dict_foreach applies _mix to each (cond, neg_cond) value pair.
            return dict_foreach([cond, neg_cond], lambda x: _mix(x[0], x[1], mask))
        return _mix(cond, neg_cond, mask)

    def get_inference_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for inference.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
        # Both branches are forwarded so the sampler can blend them (CFG).
        packed = {'cond': cond, 'neg_cond': neg_cond}
        packed.update(kwargs)
        return packed

    def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
        """
        Get the sampler for the diffusion process.
        """
        # NOTE(review): relies on self.sigma_min being provided by the host
        # trainer class — confirm against the class this mixin composes with.
        return samplers.FlowEulerCfgSampler(self.sigma_min)