1
This commit is contained in:
63
trellis/trainers/__init__.py
Normal file
63
trellis/trainers/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import importlib
|
||||
|
||||
__attributes = {
|
||||
'BasicTrainer': 'basic',
|
||||
|
||||
'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
|
||||
|
||||
'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
|
||||
'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
|
||||
'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
|
||||
|
||||
'FlowMatchingTrainer': 'flow_matching.flow_matching',
|
||||
'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
||||
'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
||||
'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
||||
|
||||
'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
|
||||
'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
||||
'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
||||
'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
||||
}
|
||||
|
||||
__submodules = []
|
||||
|
||||
__all__ = list(__attributes.keys()) + __submodules
|
||||
|
||||
def __getattr__(name):
|
||||
if name not in globals():
|
||||
if name in __attributes:
|
||||
module_name = __attributes[name]
|
||||
module = importlib.import_module(f".{module_name}", __name__)
|
||||
globals()[name] = getattr(module, name)
|
||||
elif name in __submodules:
|
||||
module = importlib.import_module(f".{name}", __name__)
|
||||
globals()[name] = module
|
||||
else:
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
return globals()[name]
|
||||
|
||||
|
||||
# For Pylance
|
||||
if __name__ == '__main__':
|
||||
from .basic import BasicTrainer
|
||||
|
||||
from .vae.sparse_structure_vae import SparseStructureVaeTrainer
|
||||
|
||||
from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
|
||||
from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
|
||||
from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
|
||||
|
||||
from .flow_matching.flow_matching import (
|
||||
FlowMatchingTrainer,
|
||||
FlowMatchingCFGTrainer,
|
||||
TextConditionedFlowMatchingCFGTrainer,
|
||||
ImageConditionedFlowMatchingCFGTrainer,
|
||||
)
|
||||
|
||||
from .flow_matching.sparse_flow_matching import (
|
||||
SparseFlowMatchingTrainer,
|
||||
SparseFlowMatchingCFGTrainer,
|
||||
TextConditionedSparseFlowMatchingCFGTrainer,
|
||||
ImageConditionedSparseFlowMatchingCFGTrainer,
|
||||
)
|
||||
Reference in New Issue
Block a user