1
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from ....utils.general_utils import dict_foreach
|
||||
from ....pipelines import samplers
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceMixin:
|
||||
def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.p_uncond = p_uncond
|
||||
|
||||
def get_cond(self, cond, neg_cond=None, **kwargs):
|
||||
"""
|
||||
Get the conditioning data.
|
||||
"""
|
||||
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
||||
|
||||
if self.p_uncond > 0:
|
||||
# randomly drop the class label
|
||||
def get_batch_size(cond):
|
||||
if isinstance(cond, torch.Tensor):
|
||||
return cond.shape[0]
|
||||
elif isinstance(cond, list):
|
||||
return len(cond)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
||||
|
||||
ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
|
||||
B = get_batch_size(ref_cond)
|
||||
|
||||
def select(cond, neg_cond, mask):
|
||||
if isinstance(cond, torch.Tensor):
|
||||
mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
|
||||
return torch.where(mask, neg_cond, cond)
|
||||
elif isinstance(cond, list):
|
||||
return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
|
||||
else:
|
||||
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
||||
|
||||
mask = list(np.random.rand(B) < self.p_uncond)
|
||||
if not isinstance(cond, dict):
|
||||
cond = select(cond, neg_cond, mask)
|
||||
else:
|
||||
cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
|
||||
|
||||
return cond
|
||||
|
||||
def get_inference_cond(self, cond, neg_cond=None, **kwargs):
|
||||
"""
|
||||
Get the conditioning data for inference.
|
||||
"""
|
||||
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
||||
return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
|
||||
|
||||
def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
|
||||
"""
|
||||
Get the sampler for the diffusion process.
|
||||
"""
|
||||
return samplers.FlowEulerCfgSampler(self.sigma_min)
|
||||
Reference in New Issue
Block a user