trellis/utils/dist_utils.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import os
import io
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Set up the environment for distributed training and initialize the
    default NCCL process group. `master_port` must be a string, since it
    is written directly into the environment.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
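

# Hedged usage sketch (not part of the original file): with `torchrun`, the
# launcher already sets these environment variables, so an entry point only
# needs to forward them. The variable names below are illustrative.
#
#   rank = int(os.environ['RANK'])
#   local_rank = int(os.environ['LOCAL_RANK'])
#   world_size = int(os.environ['WORLD_SIZE'])
#   setup_dist(rank, local_rank, world_size,
#              os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])

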
def read_file_dist(path):
    """
    Read a binary file in a distributed setting.

    The file is read only once, by the rank 0 process, and broadcast to all
    other processes, so shared storage is hit once per job rather than once
    per rank.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # read the file on rank 0 and move the bytes onto the GPU for NCCL
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # broadcast the payload size so the other ranks can allocate a buffer
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # broadcast the data itself
        dist.broadcast(data, src=0)
        # convert back to io.BytesIO
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        with open(path, 'rb') as f:
            data = f.read()
            data = io.BytesIO(data)
            return data
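

# Hedged usage sketch (not part of the original file): since read_file_dist
# returns an io.BytesIO, it can feed any reader that accepts a file-like
# object, e.g. loading a checkpoint from shared storage once per job instead
# of once per rank. The path below is hypothetical.
#
#   ckpt = torch.load(read_file_dist('checkpoints/model.pt'), map_location='cpu')

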
def unwrap_dist(model):
    """
    Unwrap a model from its DistributedDataParallel wrapper, if any.
    """
    if isinstance(model, DDP):
        return model.module
    return model
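

# Hedged usage sketch (not part of the original file): saving the unwrapped
# module keeps checkpoint keys free of the 'module.' prefix that DDP adds,
# so the same file loads in both single- and multi-GPU runs. The path is
# hypothetical.
#
#   if dist.get_rank() == 0:
#       torch.save(unwrap_dist(model).state_dict(), 'model.pt')

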
@contextmanager
def master_first():
    """
    A context manager that makes the master (rank 0) process execute the
    wrapped block first; all other ranks wait at a barrier until it is done.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
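

# Hedged usage sketch (not part of the original file): wrap work that must
# happen exactly once globally, such as a download into shared storage; rank 0
# runs first and the barrier holds the other ranks until it finishes.
# `download_if_missing` is a hypothetical helper, not part of this module.
#
#   with master_first():
#       download_if_missing('https://example.com/data.tar', '/shared/data.tar')

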
@contextmanager
def local_master_first():
    """
    A context manager that makes the local master (the first rank on each
    node) execute the wrapped block first; the other ranks on the node wait
    at a barrier. Assumes one process per GPU and the same number of GPUs
    on every node, so that `rank % torch.cuda.device_count() == 0` selects
    exactly one process per node.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
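

# Hedged usage sketch (not part of the original file): use this variant when
# the resource is per-node rather than global, e.g. unpacking a dataset onto
# node-local disk; one rank per node does the work while its node peers wait
# at the barrier. `unpack_to_local_disk` is a hypothetical helper.
#
#   with local_master_first():
#       unpack_to_local_disk('/shared/data.tar', '/tmp/data')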