This commit is contained in:
zcr
2026-03-17 11:38:02 +08:00
parent 046be2c797
commit 0571f65793
8 changed files with 1413 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
from typing import *
class GuidanceIntervalSamplerMixin:
"""
A mixin class for samplers that apply classifier-free guidance with interval.
"""
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
if cfg_interval[0] <= t <= cfg_interval[1]:
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
else:
return super()._inference_model(model, x_t, t, cond, **kwargs)