diff --git a/src/heretic/main.py b/src/heretic/main.py index 016c392..3e6f4a6 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -174,9 +174,15 @@ def run(): # Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py if torch.cuda.is_available(): count = torch.cuda.device_count() - print(f"Detected [bold]{count}[/] CUDA device(s):") + total_vram = sum(torch.cuda.mem_get_info(i)[1] for i in range(count)) + print( + f"Detected [bold]{count}[/] CUDA device(s) ({total_vram / (1024**3):.2f} GB total VRAM):" + ) for i in range(count): - print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]") + vram = torch.cuda.mem_get_info(i)[1] / (1024**3) + print( + f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/] ({vram:.2f} GB)" + ) elif is_xpu_available(): count = torch.xpu.device_count() print(f"Detected [bold]{count}[/] XPU device(s):") diff --git a/src/heretic/utils.py b/src/heretic/utils.py index a0d5f35..288ca0f 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -38,11 +38,17 @@ def print_memory_usage(): p("Resident system RAM", Process().memory_info().rss) if torch.cuda.is_available(): - p("Allocated GPU VRAM", torch.cuda.memory_allocated()) - p("Reserved GPU VRAM", torch.cuda.memory_reserved()) + count = torch.cuda.device_count() + allocated = sum(torch.cuda.memory_allocated(device) for device in range(count)) + reserved = sum(torch.cuda.memory_reserved(device) for device in range(count)) + p("Allocated GPU VRAM", allocated) + p("Reserved GPU VRAM", reserved) elif is_xpu_available(): - p("Allocated XPU memory", torch.xpu.memory_allocated()) - p("Reserved XPU memory", torch.xpu.memory_reserved()) + count = torch.xpu.device_count() + allocated = sum(torch.xpu.memory_allocated(device) for device in range(count)) + reserved = sum(torch.xpu.memory_reserved(device) for device in range(count)) + p("Allocated XPU memory", allocated) + p("Reserved XPU memory", reserved) elif torch.backends.mps.is_available(): p("Allocated MPS memory", torch.mps.current_allocated_memory()) p("Driver (reserved) MPS memory", torch.mps.driver_allocated_memory())