1
This commit is contained in:
2
trellis/pipelines/samplers/__init__.py
Executable file
2
trellis/pipelines/samplers/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
||||
from .base import Sampler
|
||||
from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
|
||||
20
trellis/pipelines/samplers/base.py
Normal file
20
trellis/pipelines/samplers/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import *
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Sampler(ABC):
|
||||
"""
|
||||
A base class for samplers.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def sample(
|
||||
self,
|
||||
model,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Sample from a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import *
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceSamplerMixin:
|
||||
"""
|
||||
A mixin class for samplers that apply classifier-free guidance.
|
||||
"""
|
||||
|
||||
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
|
||||
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
|
||||
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
|
||||
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
|
||||
Reference in New Issue
Block a user