Fix LoRA & (IA)³ implementation for Bart & MBart (#518)
Fixes a critical issue in the LoRA & (IA)³ implementation for Bart & MBart, where LoRA & (IA)³ weights were not added to the intermediate and output linear layers (fc1/fc2) of the models' decoder blocks.

In other words, adapters configured with intermediate_lora=True or output_lora=True were added incompletely to (M)Bart models. For LoRA this does not affect the default config; for (IA)³ it does, since the default (IA)³ config sets intermediate_lora=True.

To guard against regressions, the get_adapter() tests are updated to count the number of modules added per adapter.
calpt authored Mar 20, 2023
1 parent f041fbf commit 09148e0
Showing 10 changed files with 44 additions and 27 deletions.
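
For illustration, the described bug can be checked directly from the library's Python API. The following is a rough sketch only, assuming the adapter-transformers 3.x API; the checkpoint name, adapter name, and printed count are illustrative and not part of the commit:

# Illustrative sketch, not part of the commit (assumes the adapter-transformers 3.x API).
from transformers.adapters import BartAdapterModel, IA3Config

model = BartAdapterModel.from_pretrained("facebook/bart-base")

# The default (IA)^3 config sets intermediate_lora=True, so fc1 of every
# encoder *and* decoder layer should receive an (IA)^3 weight.
model.add_adapter("ia3_adapter", config=IA3Config())

# get_adapter() maps each layer index to the adapter modules found in that layer.
modules_per_layer = model.get_adapter("ia3_adapter")
num_modules = sum(len(layer_modules) for layer_modules in modules_per_layer.values())
print(num_modules)  # before this fix, the decoder fc1/fc2 modules were missing from the count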
4 changes: 2 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -415,8 +415,8 @@ def __init__(self, config: BartConfig):
             location_key="cross",
         )
         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
-        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.fc1 = LoRALinear(self.embed_dim, config.decoder_ffn_dim, "intermediate", config)
+        self.fc2 = LoRALinear(config.decoder_ffn_dim, self.embed_dim, "output", config)
         self.final_layer_norm = nn.LayerNorm(self.embed_dim)

         self._init_adapter_modules()
4 changes: 2 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -412,8 +412,8 @@ def __init__(self, config: MBartConfig):
             location_key="cross",
         )
         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
-        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.fc1 = LoRALinear(self.embed_dim, config.decoder_ffn_dim, "intermediate", config)
+        self.fc2 = LoRALinear(config.decoder_ffn_dim, self.embed_dim, "output", config)
         self.final_layer_norm = nn.LayerNorm(self.embed_dim)

         self._init_adapter_modules()
14 changes: 5 additions & 9 deletions tests_adapters/methods/base.py
@@ -77,27 +77,23 @@ def run_delete_test(self, model, adapter_config, filter_keys):
                 has_weights = True
         self.assertFalse(has_weights)

-    def run_get_test(self, model, adapter_config):
+    def run_get_test(self, model, adapter_config, num_expected_modules):
         model.eval()

         model.add_adapter("first", config=adapter_config)
-        model.add_adapter("second", config=adapter_config)
         model.set_active_adapters(["first"])
         model.to(torch_device)

         # adapter is correctly added to config
         name = "first"
         self.assert_adapter_available(model, name)

-        first_adapter = model.get_adapter("first")
-        second_adapter = model.get_adapter("second")
+        adapter = model.get_adapter("first")

-        self.assertNotEqual(len(first_adapter), 0)
-        self.assertEqual(len(first_adapter), len(second_adapter))
-        self.assertNotEqual(first_adapter, second_adapter)
+        self.assertNotEqual(len(adapter), 0)
+        num_found_modules = sum([len(layer_modules) for layer_modules in adapter.values()])
+        self.assertEqual(num_expected_modules, num_found_modules)

         model.delete_adapter("first")
-        model.delete_adapter("second")

     def run_forward_test(self, model, adapter_config):
         model.eval()
13 changes: 10 additions & 3 deletions tests_adapters/methods/test_adapter_common.py
@@ -85,10 +85,17 @@ def test_add_adapter_with_invertible(self):
     def test_get_adapter(self):
         model = self.get_model()
         model.eval()

-        for adapter_config, _ in self.adapter_configs_to_test:
+        n_layers = len(list(model.iter_layers()))
+        if model.config.is_encoder_decoder:
+            n_prefix_layers = 3
+        elif model.config.is_composition:
+            n_prefix_layers = 2
+        else:
+            n_prefix_layers = 1
+
+        for adapter_config, n_expected in [(HoulsbyConfig(), n_layers * 2), (MAMConfig(), n_layers + n_prefix_layers)]:
             with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
-                self.run_get_test(model, adapter_config)
+                self.run_get_test(model, adapter_config, n_expected)

     def test_add_adapter_multiple_reduction_factors(self):
         model = self.get_model()
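
As a side note, the expected counts used in this test can be read as follows; the concrete numbers below assume a hypothetical 6+6-layer encoder-decoder model and are an interpretation of the diff, not part of the commit:

# Hypothetical example: a BART-style model with 6 encoder + 6 decoder layers.
n_layers = 12        # model.iter_layers() yields encoder and decoder layers alike
n_prefix_layers = 3  # presumably encoder self-attn, decoder self-attn and cross-attn prefixes

# Houlsby places two bottleneck adapters in every layer (after attention and after the FFN).
houlsby_expected = n_layers * 2            # -> 24 modules
# MAM combines prefix tuning with one parallel bottleneck adapter per layer.
mam_expected = n_layers + n_prefix_layers  # -> 15 modules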
3 changes: 2 additions & 1 deletion tests_adapters/methods/test_compacter.py
@@ -16,7 +16,8 @@ def test_delete_compacter(self):

     def test_get_compacter(self):
         model = self.get_model()
-        self.run_get_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8))
+        n_layers = len(list(model.iter_layers()))
+        self.run_get_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8), n_layers + 1)

     def test_forward_compacter(self):
         model = self.get_model()
3 changes: 2 additions & 1 deletion tests_adapters/methods/test_ia3.py
@@ -19,7 +19,8 @@ def test_delete_ia3(self):

     def test_get_ia3(self):
         model = self.get_model()
-        self.run_get_test(model, IA3Config())
+        n_layers = len(list(model.iter_layers()))
+        self.run_get_test(model, IA3Config(intermediate_lora=True, output_lora=True), n_layers * 3)

     def test_forward_ia3(self):
         model = self.get_model()
3 changes: 2 additions & 1 deletion tests_adapters/methods/test_lora.py
@@ -19,7 +19,8 @@ def test_delete_lora(self):

     def test_get_lora(self):
         model = self.get_model()
-        self.run_get_test(model, LoRAConfig())
+        n_layers = len(list(model.iter_layers()))
+        self.run_get_test(model, LoRAConfig(intermediate_lora=True, output_lora=True), n_layers * 3)

     def test_forward_lora(self):
         model = self.get_model()
9 changes: 8 additions & 1 deletion tests_adapters/methods/test_prefix_tuning.py
@@ -18,7 +18,14 @@ def test_delete_prefix_tuning(self):

     def test_get_prefix_tuning(self):
         model = self.get_model()
-        self.run_get_test(model, PrefixTuningConfig(flat=True))
+        if model.config.is_encoder_decoder:
+            n_prefix_layers = 3
+        elif model.config.is_composition:
+            n_prefix_layers = 2
+        else:
+            n_prefix_layers = 1
+
+        self.run_get_test(model, PrefixTuningConfig(flat=True), n_prefix_layers)

     def test_forward_prefix_tuning(self):
         model = self.get_model()
5 changes: 4 additions & 1 deletion tests_adapters/methods/test_unipelt.py
@@ -18,7 +18,10 @@ def test_delete_unipelt(self):

     def test_get_unipelt(self):
         model = self.get_model()
-        self.run_get_test(model, UniPELTConfig())
+        n_layers = len(list(model.iter_layers()))
+        # In UniPELT, prefix tuning has gates in every layer
+        n_prefix_layers = 1.5 * n_layers if model.config.is_encoder_decoder else n_layers
+        self.run_get_test(model, UniPELTConfig(), n_layers * 2 + n_prefix_layers)

     def test_forward_unipelt(self):
         model = self.get_model()
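
Similarly, the UniPELT expectation can be unpacked with a short back-of-the-envelope calculation; the layer counts are hypothetical and the per-method breakdown is an interpretation of the diff, not taken from the commit:

# UniPELT combines LoRA, a bottleneck adapter and prefix tuning (each gated).
n_enc, n_dec = 6, 6              # hypothetical encoder/decoder depth
n_layers = n_enc + n_dec         # 12, as returned by model.iter_layers()

lora_and_adapter = n_layers * 2  # one LoRA module + one bottleneck adapter per layer -> 24
# Encoder layers carry one prefix location (self-attn), decoder layers carry two
# (self-attn and cross-attn), which is where the 1.5 factor comes from:
prefix_modules = n_enc * 1 + n_dec * 2               # 18 == 1.5 * n_layers
expected_total = lora_and_adapter + prefix_modules   # 42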
13 changes: 7 additions & 6 deletions tests_adapters/test_adapter_heads.py
@@ -122,8 +122,6 @@ def test_causal_lm_head(self):
         model1.add_causal_lm_head("dummy")

         label_dict = {}
-        # Use a different length for the seq2seq output
-        seq_output_length = self.seq_length + 30
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)

         self.run_prediction_head_test(
@@ -137,8 +135,10 @@ def test_causal_lm_head(self):
         # Finally, also check if generation works properly
         input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"]
         input_ids = input_ids.to(torch_device)
+        # Use a different length for the seq2seq output
+        seq_output_length = self.seq_length + 30
         generated = model1.generate(input_ids, max_length=seq_output_length)
         self.assertEqual(generated.shape[0], 1)
+        self.assertLessEqual(generated.shape[1], seq_output_length)

     def test_seq2seq_lm_head(self):
@@ -149,8 +149,6 @@ def test_seq2seq_lm_head(self):
         model1.add_seq2seq_lm_head("dummy")

         label_dict = {}
-        # Use a different length for the seq2seq output
-        seq_output_length = self.seq_length + 30
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)

         # prepare decoder_input_ids similar to how DataCollatorForSeq2Seq does it
@@ -169,8 +167,11 @@ def test_seq2seq_lm_head(self):
         # Finally, also check if generation works properly
         input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"]
         input_ids = input_ids.to(torch_device)
+        # Use a different length for the seq2seq output
+        seq_output_length = self.seq_length + 30
         generated = model1.generate(input_ids, max_length=seq_output_length)
-        self.assertEqual(generated.shape, (1, seq_output_length))
+        self.assertEqual(generated.shape[0], 1)
+        self.assertLessEqual(generated.shape[1], seq_output_length)

     def test_masked_lm_head(self):
         if not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_masked_lm_head"):
