diff --git a/src/heretic/main.py b/src/heretic/main.py index f178469..4a810e7 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -39,6 +39,7 @@ from .utils import ( get_trial_parameters, load_prompts, print, + empty_cache, ) @@ -195,6 +196,9 @@ def run(): p=2, dim=1, ) + # we don't need the residuals after computing refusal directions + del good_residuals, bad_residuals + empty_cache() trial_index = 0 start_time = time.perf_counter()