diff --git a/src/heretic/main.py b/src/heretic/main.py index 334f1a1..4446059 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -46,8 +46,11 @@ from .utils import ( def run(): # Enable expandable segments to reduce memory fragmentation on multi-GPU setups. - if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ: - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + if ( + "PYTORCH_ALLOC_CONF" not in os.environ + and "PYTORCH_CUDA_ALLOC_CONF" not in os.environ + ): + os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" # Modified "Pagga" font from https://budavariam.github.io/asciiart-text/ print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}")