def format_duration(seconds: float) -> str:
    """Format a duration given in seconds as a short human-readable string.

    The value is rounded to the nearest whole second and rendered with at
    most two units:

    * ``"<h>h <m>m"`` when the duration is at least one hour (leftover
      seconds are dropped, not rounded into the minutes),
    * ``"<m>m <s>s"`` when it is at least one minute,
    * ``"<s>s"`` otherwise.

    Negative durations are clamped to zero and rendered as ``"0s"``.

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted duration string.
    """
    # Clamp before splitting into units: divmod floors toward negative
    # infinity, so a negative value would otherwise yield hours == -1 with
    # positive minutes/seconds and be rendered as a bogus positive duration.
    total_seconds = max(0, round(seconds))

    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)

    if hours > 0:
        return f"{hours}h {minutes}m"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"