From e2c74bfb3ca41cbb3123565867776221150a41e8 Mon Sep 17 00:00:00 2001 From: MoonRide303 <130458190+MoonRide303@users.noreply.github.com> Date: Sun, 12 Apr 2026 09:17:32 +0200 Subject: [PATCH] fix: support for gemma 4 (#287) --- src/heretic/model.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index 3c5c025..52e6add 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -169,18 +169,19 @@ class Model: # across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers). target_modules_set: set[str] = set() - for layer_index, layer in enumerate(self.get_layers()): - module_id_to_leaf_name = { - id(module): module_name.split(".")[-1] - for module_name, module in layer.named_modules() - } + module_id_to_full_name = { + id(module): module_name + for module_name, module in self.model.named_modules() + } + for layer_index in range(len(self.get_layers())): for modules in self.get_layer_modules(layer_index).values(): for module in modules: - if id(module) in module_id_to_leaf_name: - target_modules_set.add(module_id_to_leaf_name[id(module)]) + full_name = module_id_to_full_name.get(id(module)) + if full_name is not None: + target_modules_set.add(full_name) - target_modules = list(target_modules_set) + target_modules = sorted(target_modules_set) if self.settings.row_normalization != RowNormalization.FULL: # Rank 1 is sufficient for directional ablation without renormalization. @@ -204,7 +205,10 @@ class Model: # so the result is a PeftModel rather than a PeftMixedModel. self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config)) - print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})") + display_targets = sorted({name.rsplit(".", 1)[-1] for name in target_modules}) + print( + f"* LoRA adapters initialized (target types: {', '.join(display_targets)})" + ) def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None: """