Increased GPU usage
This commit is contained in:
@@ -82,12 +82,9 @@ class CPUPrefetcher():
|
||||
|
||||
|
||||
class CUDAPrefetcher():
|
||||
"""CUDA prefetcher.
|
||||
"""CUDA (or MPS/CPU) prefetcher.
|
||||
|
||||
Ref:
|
||||
https://github.com/NVIDIA/apex/issues/304#
|
||||
|
||||
It may consums more GPU memory.
|
||||
It may consume more GPU memory.
|
||||
|
||||
Args:
|
||||
loader: Dataloader.
|
||||
@@ -98,8 +95,18 @@ class CUDAPrefetcher():
|
||||
self.ori_loader = loader
|
||||
self.loader = iter(loader)
|
||||
self.opt = opt
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
|
||||
|
||||
# Cross-platform device detection
|
||||
if opt['num_gpu'] != 0 and torch.cuda.is_available():
|
||||
self.device = torch.device('cuda')
|
||||
self.stream = torch.cuda.Stream()
|
||||
elif torch.backends.mps.is_available():
|
||||
self.device = torch.device('mps')
|
||||
self.stream = None
|
||||
else:
|
||||
self.device = torch.device('cpu')
|
||||
self.stream = None
|
||||
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
@@ -108,18 +115,24 @@ class CUDAPrefetcher():
|
||||
except StopIteration:
|
||||
self.batch = None
|
||||
return None
|
||||
# move tensors to the selected device (CUDA, MPS, or CPU)
|
||||
with torch.cuda.stream(self.stream):
|
||||
|
||||
if self.stream is not None:
|
||||
with torch.cuda.stream(self.stream):
|
||||
for k, v in self.batch.items():
|
||||
if torch.is_tensor(v):
|
||||
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
for k, v in self.batch.items():
|
||||
if torch.is_tensor(v):
|
||||
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
|
||||
self.batch[k] = self.batch[k].to(device=self.device)
|
||||
|
||||
def next(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
batch = self.batch
|
||||
self.preload()
|
||||
return batch
|
||||
|
||||
def reset(self):
|
||||
self.loader = iter(self.ori_loader)
|
||||
self.preload()
|
||||
self.preload()
|
||||
Reference in New Issue
Block a user