fix: support for gemma 4 (#287)
This commit is contained in:
+13
-9
@@ -169,18 +169,19 @@ class Model:
|
|||||||
# across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers).
|
# across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers).
|
||||||
target_modules_set: set[str] = set()
|
target_modules_set: set[str] = set()
|
||||||
|
|
||||||
for layer_index, layer in enumerate(self.get_layers()):
|
module_id_to_full_name = {
|
||||||
module_id_to_leaf_name = {
|
id(module): module_name
|
||||||
id(module): module_name.split(".")[-1]
|
for module_name, module in self.model.named_modules()
|
||||||
for module_name, module in layer.named_modules()
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
for layer_index in range(len(self.get_layers())):
|
||||||
for modules in self.get_layer_modules(layer_index).values():
|
for modules in self.get_layer_modules(layer_index).values():
|
||||||
for module in modules:
|
for module in modules:
|
||||||
if id(module) in module_id_to_leaf_name:
|
full_name = module_id_to_full_name.get(id(module))
|
||||||
target_modules_set.add(module_id_to_leaf_name[id(module)])
|
if full_name is not None:
|
||||||
|
target_modules_set.add(full_name)
|
||||||
|
|
||||||
target_modules = list(target_modules_set)
|
target_modules = sorted(target_modules_set)
|
||||||
|
|
||||||
if self.settings.row_normalization != RowNormalization.FULL:
|
if self.settings.row_normalization != RowNormalization.FULL:
|
||||||
# Rank 1 is sufficient for directional ablation without renormalization.
|
# Rank 1 is sufficient for directional ablation without renormalization.
|
||||||
@@ -204,7 +205,10 @@ class Model:
|
|||||||
# so the result is a PeftModel rather than a PeftMixedModel.
|
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||||
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
||||||
|
|
||||||
print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
|
display_targets = sorted({name.rsplit(".", 1)[-1] for name in target_modules})
|
||||||
|
print(
|
||||||
|
f"* LoRA adapters initialized (target types: {', '.join(display_targets)})"
|
||||||
|
)
|
||||||
|
|
||||||
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user