diff --git a/src/heretic/model.py b/src/heretic/model.py index 3c5c025..52e6add 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -169,18 +169,19 @@ class Model: # across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers). target_modules_set: set[str] = set() - for layer_index, layer in enumerate(self.get_layers()): - module_id_to_leaf_name = { - id(module): module_name.split(".")[-1] - for module_name, module in layer.named_modules() - } + module_id_to_full_name = { + id(module): module_name + for module_name, module in self.model.named_modules() + } + for layer_index in range(len(self.get_layers())): for modules in self.get_layer_modules(layer_index).values(): for module in modules: - if id(module) in module_id_to_leaf_name: - target_modules_set.add(module_id_to_leaf_name[id(module)]) + full_name = module_id_to_full_name.get(id(module)) + if full_name is not None: + target_modules_set.add(full_name) - target_modules = list(target_modules_set) + target_modules = sorted(target_modules_set) if self.settings.row_normalization != RowNormalization.FULL: # Rank 1 is sufficient for directional ablation without renormalization. @@ -204,7 +205,10 @@ class Model: # so the result is a PeftModel rather than a PeftMixedModel. self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config)) - print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})") + display_targets = sorted({name.rsplit(".", 1)[-1] for name in target_modules}) + print( + f"* LoRA adapters initialized (target types: {', '.join(display_targets)})" + ) def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None: """