Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LoRA & (IA)³ implementation for Bart & MBart #518

Merged
merged 3 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,8 @@ def __init__(self, config: BartConfig):
location_key="cross",
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.fc1 = LoRALinear(self.embed_dim, config.encoder_ffn_dim, "intermediate", config)
self.fc2 = LoRALinear(config.encoder_ffn_dim, self.embed_dim, "output", config)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)

self._init_adapter_modules()
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def __init__(self, config: MBartConfig):
location_key="cross",
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.fc1 = LoRALinear(self.embed_dim, config.encoder_ffn_dim, "intermediate", config)
self.fc2 = LoRALinear(config.encoder_ffn_dim, self.embed_dim, "output", config)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)

self._init_adapter_modules()
Expand Down
14 changes: 5 additions & 9 deletions tests_adapters/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,23 @@ def run_delete_test(self, model, adapter_config, filter_keys):
has_weights = True
self.assertFalse(has_weights)

def run_get_test(self, model, adapter_config):
def run_get_test(self, model, adapter_config, num_expected_modules):
model.eval()

model.add_adapter("first", config=adapter_config)
model.add_adapter("second", config=adapter_config)
model.set_active_adapters(["first"])
model.to(torch_device)

# adapter is correctly added to config
name = "first"
self.assert_adapter_available(model, name)

first_adapter = model.get_adapter("first")
second_adapter = model.get_adapter("second")
adapter = model.get_adapter("first")

self.assertNotEqual(len(first_adapter), 0)
self.assertEqual(len(first_adapter), len(second_adapter))
self.assertNotEqual(first_adapter, second_adapter)
self.assertNotEqual(len(adapter), 0)
num_found_modules = sum([len(layer_modules) for layer_modules in adapter.values()])
self.assertEqual(num_expected_modules, num_found_modules)

model.delete_adapter("first")
model.delete_adapter("second")

def run_forward_test(self, model, adapter_config):
model.eval()
Expand Down
13 changes: 10 additions & 3 deletions tests_adapters/methods/test_adapter_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,17 @@ def test_add_adapter_with_invertible(self):
def test_get_adapter(self):
model = self.get_model()
model.eval()

for adapter_config, _ in self.adapter_configs_to_test:
n_layers = len(list(model.iter_layers()))
if model.config.is_encoder_decoder:
n_prefix_layers = 3
elif model.config.is_composition:
n_prefix_layers = 2
else:
n_prefix_layers = 1

for adapter_config, n_expected in [(HoulsbyConfig(), n_layers * 2), (MAMConfig(), n_layers + n_prefix_layers)]:
with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
self.run_get_test(model, adapter_config)
self.run_get_test(model, adapter_config, n_expected)

def test_add_adapter_multiple_reduction_factors(self):
model = self.get_model()
Expand Down
3 changes: 2 additions & 1 deletion tests_adapters/methods/test_compacter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def test_delete_compacter(self):

def test_get_compacter(self):
model = self.get_model()
self.run_get_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8))
n_layers = len(list(model.iter_layers()))
self.run_get_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8), n_layers + 1)

def test_forward_compacter(self):
model = self.get_model()
Expand Down
3 changes: 2 additions & 1 deletion tests_adapters/methods/test_ia3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def test_delete_ia3(self):

def test_get_ia3(self):
model = self.get_model()
self.run_get_test(model, IA3Config())
n_layers = len(list(model.iter_layers()))
self.run_get_test(model, IA3Config(intermediate_lora=True, output_lora=True), n_layers * 3)

def test_forward_ia3(self):
model = self.get_model()
Expand Down
3 changes: 2 additions & 1 deletion tests_adapters/methods/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def test_delete_lora(self):

def test_get_lora(self):
model = self.get_model()
self.run_get_test(model, LoRAConfig())
n_layers = len(list(model.iter_layers()))
self.run_get_test(model, LoRAConfig(intermediate_lora=True, output_lora=True), n_layers * 3)

def test_forward_lora(self):
model = self.get_model()
Expand Down
9 changes: 8 additions & 1 deletion tests_adapters/methods/test_prefix_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ def test_delete_prefix_tuning(self):

def test_get_prefix_tuning(self):
model = self.get_model()
self.run_get_test(model, PrefixTuningConfig(flat=True))
if model.config.is_encoder_decoder:
n_prefix_layers = 3
elif model.config.is_composition:
n_prefix_layers = 2
else:
n_prefix_layers = 1

self.run_get_test(model, PrefixTuningConfig(flat=True), n_prefix_layers)

def test_forward_prefix_tuning(self):
model = self.get_model()
Expand Down
5 changes: 4 additions & 1 deletion tests_adapters/methods/test_unipelt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def test_delete_unipelt(self):

def test_get_unipelt(self):
model = self.get_model()
self.run_get_test(model, UniPELTConfig())
n_layers = len(list(model.iter_layers()))
# In UniPELT, prefix tuning has gates in every layer
n_prefix_layers = 1.5 * n_layers if model.config.is_encoder_decoder else n_layers
self.run_get_test(model, UniPELTConfig(), n_layers * 2 + n_prefix_layers)

def test_forward_unipelt(self):
model = self.get_model()
Expand Down
13 changes: 7 additions & 6 deletions tests_adapters/test_adapter_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ def test_causal_lm_head(self):
model1.add_causal_lm_head("dummy")

label_dict = {}
# Use a different length for the seq2seq output
seq_output_length = self.seq_length + 30
label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)

self.run_prediction_head_test(
Expand All @@ -137,8 +135,10 @@ def test_causal_lm_head(self):
# Finally, also check if generation works properly
input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"]
input_ids = input_ids.to(torch_device)
# Use a different length for the seq2seq output
seq_output_length = self.seq_length + 30
generated = model1.generate(input_ids, max_length=seq_output_length)
self.assertEqual(generated.shape[0], 1)
self.assertEqual(generated.shape[0], 1)
self.assertLessEqual(generated.shape[1], seq_output_length)

def test_seq2seq_lm_head(self):
Expand All @@ -149,8 +149,6 @@ def test_seq2seq_lm_head(self):
model1.add_seq2seq_lm_head("dummy")

label_dict = {}
# Use a different length for the seq2seq output
seq_output_length = self.seq_length + 30
label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)

# prepare decoder_input_ids similar to how DataCollatorForSeq2Seq does it
Expand All @@ -169,8 +167,11 @@ def test_seq2seq_lm_head(self):
# Finally, also check if generation works properly
input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"]
input_ids = input_ids.to(torch_device)
# Use a different length for the seq2seq output
seq_output_length = self.seq_length + 30
generated = model1.generate(input_ids, max_length=seq_output_length)
self.assertEqual(generated.shape, (1, seq_output_length))
self.assertEqual(generated.shape[0], 1)
self.assertLessEqual(generated.shape[1], seq_output_length)

def test_masked_lm_head(self):
if not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_masked_lm_head"):
Expand Down