diff --git a/src/heretic/config.py b/src/heretic/config.py index e217bba..8aa21c7 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -23,6 +23,11 @@ class DatasetSpecification(BaseModel): class Settings(BaseSettings): model: str = Field(description="Hugging Face model ID, or path to model on disk") + evaluate_model: str | None = Field( + default=None, + description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model", + ) + dtypes: list[str] = Field( description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried." ) diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index 37f25b3..307632c 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -73,7 +73,10 @@ class Evaluator: ** self.settings.kl_score_shape ) - if kl_divergence > self.settings.max_kl_divergence: + if ( + self.settings.evaluate_model is None + and kl_divergence > self.settings.max_kl_divergence + ): print(" [yellow](constraint violation; aborting trial)[/]") return kl_score, kl_divergence, self.base_refusals else: diff --git a/src/heretic/main.py b/src/heretic/main.py index 788e50b..0a7bbc7 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -156,6 +156,17 @@ def run(): settings.batch_size = best_batch_size print(f"* Chosen batch size: [bold]{settings.batch_size}[/]") + evaluator = Evaluator(settings, model) + + if settings.evaluate_model is not None: + print() + print(f"Loading model [bold]{settings.evaluate_model}[/]...") + settings.model = settings.evaluate_model + model.reload_model() + print("* Evaluating...") + evaluator.get_score() + return + print() print("Calculating per-layer refusal directions...") print("* Obtaining residuals for good prompts...") @@ -166,8 +177,6 @@ def run(): bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1 ) - evaluator = Evaluator(settings, model) - trial_index = 0 def objective(trial: optuna.Trial):