diff --git a/src/heretic/main.py b/src/heretic/main.py index 1ead245..924adbb 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -161,19 +161,34 @@ def run(): # Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py if torch.cuda.is_available(): - print(f"GPU type: [bold]{torch.cuda.get_device_name()}[/]") + count = torch.cuda.device_count() + print(f"Detected [bold]{count}[/] CUDA device(s):") + for i in range(count): + print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]") elif is_xpu_available(): - print(f"XPU type: [bold]{torch.xpu.get_device_name()}[/]") + count = torch.xpu.device_count() + print(f"Detected [bold]{count}[/] XPU device(s):") + for i in range(count): + print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]") elif is_mlu_available(): - print(f"MLU type: [bold]{torch.mlu.get_device_name()}[/]") + count = torch.mlu.device_count() + print(f"Detected [bold]{count}[/] MLU device(s):") + for i in range(count): + print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") elif is_sdaa_available(): - print(f"SDAA type: [bold]{torch.sdaa.get_device_name()}[/]") + count = torch.sdaa.device_count() + print(f"Detected [bold]{count}[/] SDAA device(s):") + for i in range(count): + print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") elif is_musa_available(): - print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]") + count = torch.musa.device_count() + print(f"Detected [bold]{count}[/] MUSA device(s):") + for i in range(count): + print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") elif is_npu_available(): - print(f"CANN version: [bold]{torch.version.cann}[/]") + print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") elif torch.backends.mps.is_available(): - print("GPU type: [bold]Apple Metal (MPS)[/]") + print("Detected [bold]1[/] MPS device (Apple Metal)") else: print( "[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"