Increased GPU usage

This commit is contained in:
Felipe Daragon
2025-05-15 22:30:23 +01:00
parent 59b6233882
commit badbcc6edf
9 changed files with 239 additions and 223 deletions

View File

@@ -82,12 +82,9 @@ class CPUPrefetcher():
class CUDAPrefetcher():
"""CUDA prefetcher.
"""CUDA (or MPS/CPU) prefetcher.
Ref:
https://github.com/NVIDIA/apex/issues/304#
It may consums more GPU memory.
It may consume more GPU memory.
Args:
loader: Dataloader.
@@ -98,8 +95,18 @@ class CUDAPrefetcher():
self.ori_loader = loader
self.loader = iter(loader)
self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
# Cross-platform device detection
if opt['num_gpu'] != 0 and torch.cuda.is_available():
self.device = torch.device('cuda')
self.stream = torch.cuda.Stream()
elif torch.backends.mps.is_available():
self.device = torch.device('mps')
self.stream = None
else:
self.device = torch.device('cpu')
self.stream = None
self.preload()
def preload(self):
@@ -108,18 +115,24 @@ class CUDAPrefetcher():
except StopIteration:
self.batch = None
return None
# put tensors to gpu
with torch.cuda.stream(self.stream):
if self.stream is not None:
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
else:
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
self.batch[k] = self.batch[k].to(device=self.device)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
if self.stream is not None:
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
self.preload()
return batch
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()
self.preload()