feat: enumerate all available GPUs on startup (#86)

* feat: enumerate all available GPUs on startup

* feat: extend device enumeration to all accelerator types
This commit is contained in:
michaelh
2025-12-16 13:12:15 +01:00
committed by GitHub
parent 243f821d93
commit 92d0c0d551
+22 -7
View File
@@ -161,19 +161,34 @@ def run():
     # Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
     if torch.cuda.is_available():
-        print(f"GPU type: [bold]{torch.cuda.get_device_name()}[/]")
+        count = torch.cuda.device_count()
+        print(f"Detected [bold]{count}[/] CUDA device(s):")
+        for i in range(count):
+            print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
     elif is_xpu_available():
-        print(f"XPU type: [bold]{torch.xpu.get_device_name()}[/]")
+        count = torch.xpu.device_count()
+        print(f"Detected [bold]{count}[/] XPU device(s):")
+        for i in range(count):
+            print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
     elif is_mlu_available():
-        print(f"MLU type: [bold]{torch.mlu.get_device_name()}[/]")
+        count = torch.mlu.device_count()
+        print(f"Detected [bold]{count}[/] MLU device(s):")
+        for i in range(count):
+            print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]")
     elif is_sdaa_available():
-        print(f"SDAA type: [bold]{torch.sdaa.get_device_name()}[/]")
+        count = torch.sdaa.device_count()
+        print(f"Detected [bold]{count}[/] SDAA device(s):")
+        for i in range(count):
+            print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]")
     elif is_musa_available():
-        print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
+        count = torch.musa.device_count()
+        print(f"Detected [bold]{count}[/] MUSA device(s):")
+        for i in range(count):
+            print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]")
     elif is_npu_available():
-        print(f"CANN version: [bold]{torch.version.cann}[/]")
+        print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])")
     elif torch.backends.mps.is_available():
-        print("GPU type: [bold]Apple Metal (MPS)[/]")
+        print("Detected [bold]1[/] MPS device (Apple Metal)")
     else:
         print(
             "[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"