Print elapsed and remaining time
This commit is contained in:
+13
-1
@@ -27,7 +27,7 @@ from rich.traceback import install
|
|||||||
from .config import Settings
|
from .config import Settings
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model
|
from .model import AbliterationParameters, Model
|
||||||
from .utils import get_readme_intro, load_prompts, print
|
from .utils import format_duration, get_readme_intro, load_prompts, print
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
@@ -178,6 +178,7 @@ def run():
|
|||||||
)
|
)
|
||||||
|
|
||||||
trial_index = 0
|
trial_index = 0
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
def objective(trial: optuna.Trial):
|
def objective(trial: optuna.Trial):
|
||||||
nonlocal trial_index
|
nonlocal trial_index
|
||||||
@@ -232,6 +233,17 @@ def run():
|
|||||||
print("* Evaluating...")
|
print("* Evaluating...")
|
||||||
score, kl_divergence, refusals = evaluator.get_score()
|
score, kl_divergence, refusals = evaluator.get_score()
|
||||||
|
|
||||||
|
elapsed_time = time.perf_counter() - start_time
|
||||||
|
remaining_time = (elapsed_time / trial_index) * (
|
||||||
|
settings.n_trials - trial_index
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
print(f"[grey50]Elapsed time: [bold]{format_duration(elapsed_time)}[/][/]")
|
||||||
|
if trial_index < settings.n_trials:
|
||||||
|
print(
|
||||||
|
f"[grey50]Estimated remaining time: [bold]{format_duration(remaining_time)}[/][/]"
|
||||||
|
)
|
||||||
|
|
||||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||||
trial.set_user_attr("refusals", refusals)
|
trial.set_user_attr("refusals", refusals)
|
||||||
trial.set_user_attr("parameters", parameters)
|
trial.set_user_attr("parameters", parameters)
|
||||||
|
|||||||
@@ -21,6 +21,19 @@ from .config import DatasetSpecification, Settings
|
|||||||
print = Console(highlight=False).print
|
print = Console(highlight=False).print
|
||||||
|
|
||||||
|
|
||||||
|
def format_duration(seconds: float) -> str:
|
||||||
|
seconds = round(seconds)
|
||||||
|
hours, seconds = divmod(seconds, 3600)
|
||||||
|
minutes, seconds = divmod(seconds, 60)
|
||||||
|
|
||||||
|
if hours > 0:
|
||||||
|
return f"{hours}h {minutes}m"
|
||||||
|
elif minutes > 0:
|
||||||
|
return f"{minutes}m {seconds}s"
|
||||||
|
else:
|
||||||
|
return f"{seconds}s"
|
||||||
|
|
||||||
|
|
||||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||||
dataset = load_dataset(specification.dataset, split=specification.split)
|
dataset = load_dataset(specification.dataset, split=specification.split)
|
||||||
return list(dataset[specification.column])
|
return list(dataset[specification.column])
|
||||||
|
|||||||
Reference in New Issue
Block a user