Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix model argument issue (#1198) #1205

Merged
merged 3 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/peft/tuners/adaption_prompt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, model, configs: Dict, adapter_name: str):
self._enabled = True
self.forward = self.model.forward
self.add_adapter(adapter_name, configs[adapter_name])
self._mark_only_adaption_prompts_as_trainable()
self._mark_only_adaption_prompts_as_trainable(self.model)

def add_adapter(self, adapter_name: str, config: AdaptionPromptConfig) -> None:
"""Add an adapter with the given name and config."""
Expand Down Expand Up @@ -146,9 +146,10 @@ def _remove_adapted_attentions(self, adapter_name: str) -> None:
setattr(par, config.target_modules, attn.model)
self._cached_adapters[adapter_name] = adapted_attentions

def _mark_only_adaption_prompts_as_trainable(self, model: nn.Module = None) -> None:
    """Freeze all parameters of the model except the adaption prompts.

    Args:
        model: The model whose parameters should be (un)frozen. Defaults to
            ``self.model`` when not given, keeping backward compatibility
            with callers that pass no argument.
    """
    # Compare against None explicitly: ``model or self.model`` would silently
    # fall back to ``self.model`` for any explicitly-passed module that
    # evaluates falsy (e.g. an empty ``nn.ModuleList``).
    if model is None:
        model = self.model
    for n, p in model.named_parameters():
        if not is_adaption_prompt_trainable(n):
            p.requires_grad = False

Expand Down
6 changes: 4 additions & 2 deletions src/peft/tuners/ia3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import List, Optional

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
Expand Down Expand Up @@ -146,8 +147,9 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs):
def _check_target_module_exists(ia3_config, key):
return check_target_module_exists(ia3_config, key)

def _mark_only_adapters_as_trainable(self) -> None:
for n, p in self.model.named_parameters():
def _mark_only_adapters_as_trainable(self, model: nn.Module = None) -> None:
model = model or self.model
for n, p in model.named_parameters():
if self.prefix not in n:
p.requires_grad = False

Expand Down
10 changes: 6 additions & 4 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import List, Optional

import torch
from torch import nn
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D

Expand Down Expand Up @@ -226,8 +227,9 @@ def _replace_module(self, parent, child_name, new_module, child):
weight = child.qweight if hasattr(child, "qweight") else child.weight
module.to(weight.device)

def _mark_only_adapters_as_trainable(self) -> None:
for n, p in self.model.named_parameters():
def _mark_only_adapters_as_trainable(self, model: nn.Module = None) -> None:
model = model or self.model
for n, p in model.named_parameters():
if self.prefix not in n:
p.requires_grad = False

Expand All @@ -237,11 +239,11 @@ def _mark_only_adapters_as_trainable(self) -> None:
continue

if bias == "all":
for n, p in self.model.named_parameters():
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "lora_only":
for m in self.model.modules():
for m in model.modules():
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
Expand Down
5 changes: 3 additions & 2 deletions src/peft/tuners/lycoris_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn

return new_module

def _mark_only_adapters_as_trainable(self) -> None:
for n, p in self.model.named_parameters():
def _mark_only_adapters_as_trainable(self, model: nn.Module = None) -> None:
model = model or self.model
for n, p in model.named_parameters():
if self.prefix not in n:
p.requires_grad = False

Expand Down
6 changes: 3 additions & 3 deletions src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _create_and_replace(
...

@abstractmethod
def _mark_only_adapters_as_trainable(self):
def _mark_only_adapters_as_trainable(self, model: nn.Module = None):
r"""
A helper method to mark only the adapter layers as trainable (i.e. module.requires_grad = False) This needs to
be overridden for all tuner classes to match the correct key names.
Expand Down Expand Up @@ -257,10 +257,10 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
f"Please check the target modules and try again."
)

self._mark_only_adapters_as_trainable()
self._mark_only_adapters_as_trainable(model)

if self.peft_config[adapter_name].inference_mode:
for n, p in self.model.named_parameters():
for n, p in model.named_parameters():
if adapter_name in n:
p.requires_grad = False

Expand Down
Loading