1
This commit is contained in:
15
trellis/pipelines/samplers/guidance_interval_mixin.py
Normal file
15
trellis/pipelines/samplers/guidance_interval_mixin.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import *
|
||||
|
||||
|
||||
class GuidanceIntervalSamplerMixin:
|
||||
"""
|
||||
A mixin class for samplers that apply classifier-free guidance with interval.
|
||||
"""
|
||||
|
||||
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
|
||||
if cfg_interval[0] <= t <= cfg_interval[1]:
|
||||
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
|
||||
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
|
||||
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
|
||||
else:
|
||||
return super()._inference_model(model, x_t, t, cond, **kwargs)
|
||||
Reference in New Issue
Block a user