|
|
@ -328,7 +328,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|
|
|
mem_reserved = stats['reserved_bytes.all.current']
|
|
|
|
mem_reserved = stats['reserved_bytes.all.current']
|
|
|
|
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
|
|
|
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
|
|
|
mem_total_torch = mem_reserved
|
|
|
|
mem_total_torch = mem_reserved
|
|
|
|
mem_total = mem_total_cuda + mem_total_torch
|
|
|
|
mem_total = mem_total_cuda
|
|
|
|
|
|
|
|
|
|
|
|
if torch_total_too:
|
|
|
|
if torch_total_too:
|
|
|
|
return (mem_total, mem_total_torch)
|
|
|
|
return (mem_total, mem_total_torch)
|
|
|
|