feat: enumerate all available GPUs on startup (#86)
* feat: enumerate all available GPUs on startup * feat: extend device enumeration to all accelerator types
This commit is contained in:
+22
-7
@@ -161,19 +161,34 @@ def run():
|
|||||||
|
|
||||||
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
|
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
print(f"GPU type: [bold]{torch.cuda.get_device_name()}[/]")
|
count = torch.cuda.device_count()
|
||||||
|
print(f"Detected [bold]{count}[/] CUDA device(s):")
|
||||||
|
for i in range(count):
|
||||||
|
print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
|
||||||
elif is_xpu_available():
|
elif is_xpu_available():
|
||||||
print(f"XPU type: [bold]{torch.xpu.get_device_name()}[/]")
|
count = torch.xpu.device_count()
|
||||||
|
print(f"Detected [bold]{count}[/] XPU device(s):")
|
||||||
|
for i in range(count):
|
||||||
|
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
|
||||||
elif is_mlu_available():
|
elif is_mlu_available():
|
||||||
print(f"MLU type: [bold]{torch.mlu.get_device_name()}[/]")
|
count = torch.mlu.device_count()
|
||||||
|
print(f"Detected [bold]{count}[/] MLU device(s):")
|
||||||
|
for i in range(count):
|
||||||
|
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]")
|
||||||
elif is_sdaa_available():
|
elif is_sdaa_available():
|
||||||
print(f"SDAA type: [bold]{torch.sdaa.get_device_name()}[/]")
|
count = torch.sdaa.device_count()
|
||||||
|
print(f"Detected [bold]{count}[/] SDAA device(s):")
|
||||||
|
for i in range(count):
|
||||||
|
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]")
|
||||||
elif is_musa_available():
|
elif is_musa_available():
|
||||||
print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
|
count = torch.musa.device_count()
|
||||||
|
print(f"Detected [bold]{count}[/] MUSA device(s):")
|
||||||
|
for i in range(count):
|
||||||
|
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]")
|
||||||
elif is_npu_available():
|
elif is_npu_available():
|
||||||
print(f"CANN version: [bold]{torch.version.cann}[/]")
|
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])")
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
print("GPU type: [bold]Apple Metal (MPS)[/]")
|
print("Detected [bold]1[/] MPS device (Apple Metal)")
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
||||||
|
|||||||
Reference in New Issue
Block a user