13 lines
454 B
Python
13 lines
454 B
Python
from typing import *
|
|
|
|
|
|
class ClassifierFreeGuidanceSamplerMixin:
    """
    A mixin class for samplers that apply classifier-free guidance.

    The mixin overrides ``_inference_model`` and delegates the actual model
    evaluation to the next ``_inference_model`` in the MRO via ``super()``.
    It must therefore be listed *before* a class that provides that method,
    e.g. ``class MySampler(ClassifierFreeGuidanceSamplerMixin, BaseSampler)``.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
        """
        Evaluate the model with classifier-free guidance.

        Args:
            model: Forwarded unchanged to the base ``_inference_model``.
            x_t: Forwarded unchanged to the base ``_inference_model``.
            t: Forwarded unchanged to the base ``_inference_model``.
            cond: Conditioning used for the guided (positive) prediction.
            neg_cond: Conditioning used for the negative prediction.
            cfg_strength: Guidance weight ``w``; ``0`` means no guidance.
            **kwargs: Extra keyword arguments forwarded to every base call.

        Returns:
            ``(1 + w) * pred - w * neg_pred`` (equivalently
            ``pred + w * (pred - neg_pred)``), where ``pred`` and ``neg_pred``
            are the base predictions for ``cond`` and ``neg_cond``. When
            ``cfg_strength`` is a plain number equal to 0, ``pred`` is returned
            directly and the negative-conditioned evaluation is skipped.
        """
        pred = super()._inference_model(model, x_t, t, cond, **kwargs)
        # With zero guidance the combination reduces to ``pred``. Skip the
        # second model evaluation, which would double the cost for no effect.
        # Only short-circuit for plain Python numbers so that non-scalar
        # (e.g. tensor-valued) strengths keep the original element-wise math.
        if isinstance(cfg_strength, (int, float)) and cfg_strength == 0:
            return pred
        neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
        return (1 + cfg_strength) * pred - cfg_strength * neg_pred