MPS support (#5)
* MPS support * oops, added issue tracker. * Delete .beads/issues.jsonl
This commit is contained in:
@@ -88,6 +88,8 @@ def run():
|
|||||||
print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
|
print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
|
||||||
elif is_npu_available():
|
elif is_npu_available():
|
||||||
print(f"CANN version: [bold]{torch.version.cann}[/]")
|
print(f"CANN version: [bold]{torch.version.cann}[/]")
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
print(f"GPU type: [bold]Apple Metal (MPS)[/]")
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ def empty_cache():
|
|||||||
torch.sdaa.empty_cache()
|
torch.sdaa.empty_cache()
|
||||||
elif is_musa_available():
|
elif is_musa_available():
|
||||||
torch.musa.empty_cache()
|
torch.musa.empty_cache()
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user