FIX Use model argument consistently (#1198) (#1205)
Some methods were using model and self.model interchangeably. This was
fine, since both referred to the same object, but it was also confusing.
Now model is used consistently.
ngocbh authored Dec 11, 2023
1 parent 00b8200 commit 5c13ea3
Showing 6 changed files with 20 additions and 18 deletions.
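
In short, the _mark_only_adapters_as_trainable hook (and its adaption-prompt counterpart) now receives the model to operate on as an explicit argument, and every call site passes it in. A minimal before/after sketch of the pattern (the TunerSketch class and the "lora_" prefix are made up for illustration; this is not the actual PEFT source):

from torch import nn


class TunerSketch:
    """Illustration of the pattern only; not one of the real PEFT tuner classes."""

    prefix = "lora_"  # hypothetical adapter-parameter prefix

    def __init__(self, model: nn.Module):
        self.model = model
        # Before this commit the hook took no argument and read self.model internally:
        #     self._mark_only_adapters_as_trainable()
        # After the commit, the model is passed explicitly:
        self._mark_only_adapters_as_trainable(self.model)

    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
        # Freeze every parameter that does not look like an adapter parameter.
        for n, p in model.named_parameters():
            if self.prefix not in n:
                p.requires_grad = False
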
6 changes: 3 additions & 3 deletions src/peft/tuners/adaption_prompt/model.py
@@ -57,7 +57,7 @@ def __init__(self, model, configs: Dict, adapter_name: str):
         self._enabled = True
         self.forward = self.model.forward
         self.add_adapter(adapter_name, configs[adapter_name])
-        self._mark_only_adaption_prompts_as_trainable()
+        self._mark_only_adaption_prompts_as_trainable(self.model)

     def add_adapter(self, adapter_name: str, config: AdaptionPromptConfig) -> None:
         """Add an adapter with the given name and config."""
@@ -146,9 +146,9 @@ def _remove_adapted_attentions(self, adapter_name: str) -> None:
             setattr(par, config.target_modules, attn.model)
         self._cached_adapters[adapter_name] = adapted_attentions

-    def _mark_only_adaption_prompts_as_trainable(self) -> None:
+    def _mark_only_adaption_prompts_as_trainable(self, model: nn.Module) -> None:
         """Freeze all parameters of the model except the adaption prompts."""
-        for n, p in self.model.named_parameters():
+        for n, p in model.named_parameters():
             if not is_adaption_prompt_trainable(n):
                 p.requires_grad = False

5 changes: 3 additions & 2 deletions src/peft/tuners/ia3/model.py
@@ -21,6 +21,7 @@
 from typing import List, Optional

 import torch
+from torch import nn
 from transformers.pytorch_utils import Conv1D

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
@@ -147,8 +148,8 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs):
     def _check_target_module_exists(ia3_config, key):
         return check_target_module_exists(ia3_config, key)

-    def _mark_only_adapters_as_trainable(self) -> None:
-        for n, p in self.model.named_parameters():
+    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+        for n, p in model.named_parameters():
             if self.prefix not in n:
                 p.requires_grad = False

9 changes: 5 additions & 4 deletions src/peft/tuners/lora/model.py
@@ -26,6 +26,7 @@
 from typing import List, Optional

 import torch
+from torch import nn
 from tqdm import tqdm
 from transformers.pytorch_utils import Conv1D

@@ -221,8 +222,8 @@ def _replace_module(self, parent, child_name, new_module, child):
                 weight = child.qweight if hasattr(child, "qweight") else child.weight
                 module.to(weight.device)

-    def _mark_only_adapters_as_trainable(self) -> None:
-        for n, p in self.model.named_parameters():
+    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+        for n, p in model.named_parameters():
             if self.prefix not in n:
                 p.requires_grad = False

@@ -232,11 +233,11 @@ def _mark_only_adapters_as_trainable(self) -> None:
                 continue

            if bias == "all":
-                for n, p in self.model.named_parameters():
+                for n, p in model.named_parameters():
                     if "bias" in n:
                         p.requires_grad = True
            elif bias == "lora_only":
-                for m in self.model.modules():
+                for m in model.modules():
                     if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
                         m.bias.requires_grad = True
            else:

4 changes: 2 additions & 2 deletions src/peft/tuners/lycoris_utils.py
@@ -272,8 +272,8 @@ def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn

         return new_module

-    def _mark_only_adapters_as_trainable(self) -> None:
-        for n, p in self.model.named_parameters():
+    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+        for n, p in model.named_parameters():
             if self.prefix not in n:
                 p.requires_grad = False

8 changes: 4 additions & 4 deletions src/peft/tuners/mixed/model.py
@@ -131,8 +131,8 @@ def _replace_module(self, parent, child_name, new_module, child) -> None:
             if "ranknum" in name:
                 module.to(child.weight.device)

-    def _mark_only_adapters_as_trainable(self) -> None:
-        for n, p in self.model.named_parameters():
+    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+        for n, p in model.named_parameters():
             if not any(prefix in n for prefix in PREFIXES):
                 p.requires_grad = False

@@ -142,12 +142,12 @@ def _mark_only_adapters_as_trainable(self) -> None:
                 continue

            if bias == "all":
-                for n, p in self.model.named_parameters():
+                for n, p in model.named_parameters():
                     if "bias" in n:
                         p.requires_grad = True
            elif bias == "lora_only":
                 # TODO: check if this is needed for other supported types
-                for m in self.model.modules():
+                for m in model.modules():
                     if isinstance(m, Layers) and hasattr(m, "bias") and m.bias is not None:
                         m.bias.requires_grad = True
            else:

6 changes: 3 additions & 3 deletions src/peft/tuners/tuners_utils.py
@@ -167,7 +167,7 @@ def _create_and_replace(
         ...

     @abstractmethod
-    def _mark_only_adapters_as_trainable(self):
+    def _mark_only_adapters_as_trainable(self, model: nn.Module):
         r"""
         A helper method to mark only the adapter layers as trainable (i.e. module.requires_grad = False) This needs to
         be overriden for all tuner classes to match the correct key names.
@@ -252,10 +252,10 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
                 f"Please check the target modules and try again."
             )

-        self._mark_only_adapters_as_trainable()
+        self._mark_only_adapters_as_trainable(model)

         if self.peft_config[adapter_name].inference_mode:
-            for n, p in self.model.named_parameters():
+            for n, p in model.named_parameters():
                 if adapter_name in n:
                     p.requires_grad = False

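
To make the effect of the hook concrete, here is a small self-contained usage sketch. The free function, the ToyBlock module, and the "adapter_" prefix are all invented for the demo; in PEFT itself the logic lives on the tuner classes shown in the diffs above, and the prefix comes from each tuner.

import torch
from torch import nn


def mark_only_adapters_as_trainable(model: nn.Module, prefix: str = "adapter_") -> None:
    # Same freezing logic as the hook in the diffs, written as a free function for the demo;
    # "prefix" stands in for the tuner's self.prefix.
    for name, param in model.named_parameters():
        if prefix not in name:
            param.requires_grad = False


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_linear = nn.Linear(8, 8)     # base weight: should end up frozen
        self.adapter_linear = nn.Linear(8, 8)  # stands in for an adapter layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.adapter_linear(self.base_linear(x))


toy = ToyBlock()
mark_only_adapters_as_trainable(toy, prefix="adapter_")
print({name: p.requires_grad for name, p in toy.named_parameters()})
# base_linear.* -> False, adapter_linear.* -> True

Passing the model in explicitly, as the commit does, also makes a helper like this easier to test in isolation, since it no longer depends on self.model already being set.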
