feat: enumerate all available GPUs on startup (#86)
* feat: enumerate all available GPUs on startup * feat: extend device enumeration to all accelerator types
This commit is contained in:
+22
-7
@@ -161,19 +161,34 @@ def run():
|
||||
|
||||
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU type: [bold]{torch.cuda.get_device_name()}[/]")
|
||||
count = torch.cuda.device_count()
|
||||
print(f"Detected [bold]{count}[/] CUDA device(s):")
|
||||
for i in range(count):
|
||||
print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
|
||||
elif is_xpu_available():
|
||||
print(f"XPU type: [bold]{torch.xpu.get_device_name()}[/]")
|
||||
count = torch.xpu.device_count()
|
||||
print(f"Detected [bold]{count}[/] XPU device(s):")
|
||||
for i in range(count):
|
||||
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
|
||||
elif is_mlu_available():
|
||||
print(f"MLU type: [bold]{torch.mlu.get_device_name()}[/]")
|
||||
count = torch.mlu.device_count()
|
||||
print(f"Detected [bold]{count}[/] MLU device(s):")
|
||||
for i in range(count):
|
||||
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]")
|
||||
elif is_sdaa_available():
|
||||
print(f"SDAA type: [bold]{torch.sdaa.get_device_name()}[/]")
|
||||
count = torch.sdaa.device_count()
|
||||
print(f"Detected [bold]{count}[/] SDAA device(s):")
|
||||
for i in range(count):
|
||||
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]")
|
||||
elif is_musa_available():
|
||||
print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
|
||||
count = torch.musa.device_count()
|
||||
print(f"Detected [bold]{count}[/] MUSA device(s):")
|
||||
for i in range(count):
|
||||
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]")
|
||||
elif is_npu_available():
|
||||
print(f"CANN version: [bold]{torch.version.cann}[/]")
|
||||
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])")
|
||||
elif torch.backends.mps.is_available():
|
||||
print("GPU type: [bold]Apple Metal (MPS)[/]")
|
||||
print("Detected [bold]1[/] MPS device (Apple Metal)")
|
||||
else:
|
||||
print(
|
||||
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
||||
|
||||
Reference in New Issue
Block a user