fix: report VRAM usage across all GPUs instead of only the default device (#169)

memory_allocated() and memory_reserved() without a device argument only
report GPU 0. Sum across all devices for correct multi-GPU totals and
add total VRAM reporting.
This commit is contained in:
cpagac
2026-02-17 01:23:41 -06:00
committed by GitHub
parent 3a115e280c
commit 4c80c4beb9
2 changed files with 18 additions and 6 deletions
+8 -2
View File
@@ -174,9 +174,15 @@ def run():
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
if torch.cuda.is_available():
count = torch.cuda.device_count()
print(f"Detected [bold]{count}[/] CUDA device(s):")
total_vram = sum(torch.cuda.mem_get_info(i)[1] for i in range(count))
print(
f"Detected [bold]{count}[/] CUDA device(s) ({total_vram / (1024**3):.2f} GB total VRAM):"
)
for i in range(count):
print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
vram = torch.cuda.mem_get_info(i)[1] / (1024**3)
print(
f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/] ({vram:.2f} GB)"
)
elif is_xpu_available():
count = torch.xpu.device_count()
print(f"Detected [bold]{count}[/] XPU device(s):")
+10 -4
View File
@@ -38,11 +38,17 @@ def print_memory_usage():
p("Resident system RAM", Process().memory_info().rss)
if torch.cuda.is_available():
p("Allocated GPU VRAM", torch.cuda.memory_allocated())
p("Reserved GPU VRAM", torch.cuda.memory_reserved())
count = torch.cuda.device_count()
allocated = sum(torch.cuda.memory_allocated(device) for device in range(count))
reserved = sum(torch.cuda.memory_reserved(device) for device in range(count))
p("Allocated GPU VRAM", allocated)
p("Reserved GPU VRAM", reserved)
elif is_xpu_available():
p("Allocated XPU memory", torch.xpu.memory_allocated())
p("Reserved XPU memory", torch.xpu.memory_reserved())
count = torch.xpu.device_count()
allocated = sum(torch.xpu.memory_allocated(device) for device in range(count))
reserved = sum(torch.xpu.memory_reserved(device) for device in range(count))
p("Allocated XPU memory", allocated)
p("Reserved XPU memory", reserved)
elif torch.backends.mps.is_available():
p("Allocated MPS memory", torch.mps.current_allocated_memory())
p("Driver (reserved) MPS memory", torch.mps.driver_allocated_memory())