1
This commit is contained in:
68
trellis/trainers/flow_matching/mixins/text_conditioned.py
Normal file
68
trellis/trainers/flow_matching/mixins/text_conditioned.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import *
|
||||
import os
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextModel
|
||||
|
||||
from ....utils import dist_utils
|
||||
|
||||
|
||||
class TextConditionedMixin:
|
||||
"""
|
||||
Mixin for text-conditioned models.
|
||||
|
||||
Args:
|
||||
text_cond_model: The text conditioning model.
|
||||
"""
|
||||
def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.text_cond_model_name = text_cond_model
|
||||
self.text_cond_model = None # the model is init lazily
|
||||
|
||||
def _init_text_cond_model(self):
|
||||
"""
|
||||
Initialize the text conditioning model.
|
||||
"""
|
||||
# load model
|
||||
with dist_utils.local_master_first():
|
||||
model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
|
||||
model.eval()
|
||||
model = model.cuda()
|
||||
self.text_cond_model = {
|
||||
'model': model,
|
||||
'tokenizer': tokenizer,
|
||||
}
|
||||
self.text_cond_model['null_cond'] = self.encode_text([''])
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_text(self, text: List[str]) -> torch.Tensor:
|
||||
"""
|
||||
Encode the text.
|
||||
"""
|
||||
assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
|
||||
if self.text_cond_model is None:
|
||||
self._init_text_cond_model()
|
||||
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
||||
tokens = encoding['input_ids'].cuda()
|
||||
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
|
||||
|
||||
return embeddings
|
||||
|
||||
def get_cond(self, cond, **kwargs):
|
||||
"""
|
||||
Get the conditioning data.
|
||||
"""
|
||||
cond = self.encode_text(cond)
|
||||
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
|
||||
cond = super().get_cond(cond, **kwargs)
|
||||
return cond
|
||||
|
||||
def get_inference_cond(self, cond, **kwargs):
|
||||
"""
|
||||
Get the conditioning data for inference.
|
||||
"""
|
||||
cond = self.encode_text(cond)
|
||||
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
|
||||
cond = super().get_inference_cond(cond, **kwargs)
|
||||
return cond
|
||||
Reference in New Issue
Block a user