1
This commit is contained in:
25
trellis/pipelines/__init__.py
Normal file
25
trellis/pipelines/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from . import samplers
|
||||
from .trellis_image_to_3d import TrellisImageTo3DPipeline
|
||||
from .trellis_text_to_3d import TrellisTextTo3DPipeline
|
||||
|
||||
|
||||
def from_pretrained(path: str):
|
||||
"""
|
||||
Load a pipeline from a model folder or a Hugging Face model hub.
|
||||
|
||||
Args:
|
||||
path: The path to the model. Can be either local path or a Hugging Face model name.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
is_local = os.path.exists(f"{path}/pipeline.json")
|
||||
|
||||
if is_local:
|
||||
config_file = f"{path}/pipeline.json"
|
||||
else:
|
||||
from huggingface_hub import hf_hub_download
|
||||
config_file = hf_hub_download(path, "pipeline.json")
|
||||
|
||||
with open(config_file, 'r') as f:
|
||||
config = json.load(f)
|
||||
return globals()[config['name']].from_pretrained(path)
|
||||
68
trellis/pipelines/base.py
Normal file
68
trellis/pipelines/base.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .. import models
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
A base class for pipelines.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
models: dict[str, nn.Module] = None,
|
||||
):
|
||||
if models is None:
|
||||
return
|
||||
self.models = models
|
||||
for model in self.models.values():
|
||||
model.eval()
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path: str) -> "Pipeline":
|
||||
"""
|
||||
Load a pretrained model.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
is_local = os.path.exists(f"{path}/pipeline.json")
|
||||
|
||||
if is_local:
|
||||
config_file = f"{path}/pipeline.json"
|
||||
else:
|
||||
from huggingface_hub import hf_hub_download
|
||||
config_file = hf_hub_download(path, "pipeline.json")
|
||||
|
||||
with open(config_file, 'r') as f:
|
||||
args = json.load(f)['args']
|
||||
|
||||
_models = {}
|
||||
for k, v in args['models'].items():
|
||||
try:
|
||||
_models[k] = models.from_pretrained(f"{path}/{v}")
|
||||
except:
|
||||
_models[k] = models.from_pretrained(v)
|
||||
|
||||
new_pipeline = Pipeline(_models)
|
||||
new_pipeline._pretrained_args = args
|
||||
return new_pipeline
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
for model in self.models.values():
|
||||
if hasattr(model, 'device'):
|
||||
return model.device
|
||||
for model in self.models.values():
|
||||
if hasattr(model, 'parameters'):
|
||||
return next(model.parameters()).device
|
||||
raise RuntimeError("No device found.")
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
for model in self.models.values():
|
||||
model.to(device)
|
||||
|
||||
def cuda(self) -> None:
|
||||
self.to(torch.device("cuda"))
|
||||
|
||||
def cpu(self) -> None:
|
||||
self.to(torch.device("cpu"))
|
||||
2
trellis/pipelines/samplers/__init__.py
Executable file
2
trellis/pipelines/samplers/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
||||
from .base import Sampler
|
||||
from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
|
||||
20
trellis/pipelines/samplers/base.py
Normal file
20
trellis/pipelines/samplers/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import *
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Sampler(ABC):
|
||||
"""
|
||||
A base class for samplers.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def sample(
|
||||
self,
|
||||
model,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Sample from a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import *
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceSamplerMixin:
|
||||
"""
|
||||
A mixin class for samplers that apply classifier-free guidance.
|
||||
"""
|
||||
|
||||
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
|
||||
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
|
||||
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
|
||||
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
|
||||
Reference in New Issue
Block a user