Commit f041fbf
Fix resume_from_checkpoint (#514)
Add initialization of the `adapter_loaded` variable so invalid checkpoints throw an understandable error.
hSterz authored Mar 16, 2023
1 parent 0d95a52 commit f041fbf
Showing 2 changed files with 193 additions and 144 deletions.
1 change: 1 addition & 0 deletions src/transformers/adapters/trainer.py
@@ -182,6 +182,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint):
             # will be resumed in deepspeed_init
             pass
         else:
+            adapter_loaded = False
             if os.path.isdir(resume_from_checkpoint):
                 adapter_loaded = self._load_adapters(resume_from_checkpoint)
                 self._load_adapter_fusions(resume_from_checkpoint)
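Why this one-line fix matters: before it, `adapter_loaded` was only assigned inside the `os.path.isdir(...)` branch, so an invalid checkpoint path left the variable unbound and the later check crashed with an `UnboundLocalError`. The sketch below reconstructs that surrounding control flow; only the `adapter_loaded = False` line is confirmed by this diff, and the check and error message are hypothetical.

    import os

    def _load_from_checkpoint(self, resume_from_checkpoint):
        # Hedged sketch: only `adapter_loaded = False` is taken from the diff;
        # the check and message below are assumed reconstructions.
        adapter_loaded = False  # the fix: the flag is now bound on every path
        if os.path.isdir(resume_from_checkpoint):
            adapter_loaded = self._load_adapters(resume_from_checkpoint)
            self._load_adapter_fusions(resume_from_checkpoint)
        if not adapter_loaded:
            # An invalid checkpoint now reaches this understandable error
            # instead of an UnboundLocalError on the flag itself.
            raise Exception("Can't find a valid checkpoint at {}".format(resume_from_checkpoint))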
336 changes: 192 additions & 144 deletions tests_adapters/test_adapter_trainer.py
@@ -36,100 +36,144 @@ def test_resume_training(self):
             task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
-        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
-        model.add_adapter("adapter")
-        model.add_adapter("additional_adapter")
-        model.set_active_adapters("adapter")
-        model.train_adapter("adapter")
-
-        training_args = TrainingArguments(
-            output_dir="./output",
-            do_train=True,
-            learning_rate=0.1,
-            logging_steps=1,
-            max_steps=1,
-            save_steps=1,
-            remove_unused_columns=False,
-        )
-        trainer = AdapterTrainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-        )
-
-        trainer.train()
-        # create second model that should resume the training of the first
-        model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
-        model_resume.add_adapter("adapter")
-        model_resume.add_adapter("additional_adapter")
-        model_resume.set_active_adapters("adapter")
-        model_resume.train_adapter("adapter")
-        trainer_resume = AdapterTrainer(
-            model=model_resume,
-            args=TrainingArguments(do_train=True, max_steps=1, output_dir="./output"),
-            train_dataset=train_dataset,
-        )
-        trainer_resume.train(resume_from_checkpoint=True)
-
-        self.assertEqual(model.config.adapters.adapters, model_resume.config.adapters.adapters)
-
-        for ((k1, v1), (k2, v2)) in zip(trainer.model.state_dict().items(), trainer_resume.model.state_dict().items()):
-            self.assertEqual(k1, k2)
-            if "adapter" in k1:
-                self.assertTrue(torch.equal(v1, v2), k1)
+        with TemporaryDirectory() as tmpdirname:
+            model = AutoModelForSequenceClassification.from_config(self.get_model_config())
+            model.add_adapter("adapter")
+            model.add_adapter("additional_adapter")
+            model.set_active_adapters("adapter")
+            model.train_adapter("adapter")
+
+            training_args = TrainingArguments(
+                output_dir=tmpdirname,
+                do_train=True,
+                learning_rate=0.1,
+                logging_steps=1,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+            )
+            trainer = AdapterTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+            )
+
+            trainer.train()
+            # create second model that should resume the training of the first
+            model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
+            model_resume.add_adapter("adapter")
+            model_resume.add_adapter("additional_adapter")
+            model_resume.set_active_adapters("adapter")
+            model_resume.train_adapter("adapter")
+            trainer_resume = AdapterTrainer(
+                model=model_resume,
+                args=TrainingArguments(do_train=True, max_steps=1, output_dir=tmpdirname),
+                train_dataset=train_dataset,
+            )
+            trainer_resume.train(resume_from_checkpoint=True)
+
+            self.assertEqual(model.config.adapters.adapters, model_resume.config.adapters.adapters)
+
+            for ((k1, v1), (k2, v2)) in zip(trainer.model.state_dict().items(), trainer_resume.model.state_dict().items()):
+                self.assertEqual(k1, k2)
+                if "adapter" in k1:
+                    self.assertTrue(torch.equal(v1, v2), k1)
+
+    def test_resume_training_invalid_checkpoint(self):
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        data_args = GlueDataTrainingArguments(
+            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
+        )
+        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
+        with TemporaryDirectory() as tmpdirname:
+            model = AutoModelForSequenceClassification.from_config(self.get_model_config())
+            model.add_adapter("adapter")
+            model.add_adapter("additional_adapter")
+            model.set_active_adapters("adapter")
+            model.train_adapter("adapter")
+
+            training_args = TrainingArguments(
+                output_dir=tmpdirname,
+                do_train=True,
+                learning_rate=0.1,
+                logging_steps=1,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+            )
+            trainer = AdapterTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+            )
+
+            trainer.train()
+            # create second model that should resume the training of the first
+            model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
+            model_resume.add_adapter("adapter")
+            model_resume.add_adapter("additional_adapter")
+            model_resume.set_active_adapters("adapter")
+            model_resume.train_adapter("adapter")
+            trainer_resume = AdapterTrainer(
+                model=model_resume,
+                args=TrainingArguments(do_train=True, max_steps=1, output_dir=tmpdirname),
+                train_dataset=train_dataset,
+            )
+            with self.assertRaises(Exception):
+                trainer_resume.train(resume_from_checkpoint=tmpdirname + "_invalid")

     def test_resume_training_with_fusion(self):
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
         data_args = GlueDataTrainingArguments(
             task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
-        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
-        model.add_adapter("adapter")
-        model.add_adapter("additional_adapter")
-        model.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
-        model.set_active_adapters(Fuse("adapter", "additional_adapter"))
-        model.train_fusion(Fuse("adapter", "additional_adapter"))
-
-        training_args = TrainingArguments(
-            output_dir="./output",
-            do_train=True,
-            learning_rate=0.1,
-            logging_steps=1,
-            max_steps=1,
-            save_steps=1,
-            remove_unused_columns=False,
-        )
-        trainer = AdapterTrainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-        )
-
-        trainer.train()
-        model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
-        model_resume.add_adapter("adapter")
-        model_resume.add_adapter("additional_adapter")
-        model_resume.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
-        model_resume.set_active_adapters(Fuse("adapter", "additional_adapter"))
-        model_resume.train_fusion(Fuse("adapter", "additional_adapter"))
-        trainer_resume = AdapterTrainer(
-            model=model_resume,
-            args=TrainingArguments(do_train=True, max_steps=1, output_dir="./output"),
-            train_dataset=train_dataset,
-        )
-        trainer_resume.train(resume_from_checkpoint=True)
-
-        self.assertEqual(model.config.adapters.adapters, model_resume.config.adapters.adapters)
-
-        for ((k1, v1), (k2, v2)) in zip(
-            trainer.model.to("cpu").state_dict().items(), trainer_resume.model.to("cpu").state_dict().items()
-        ):
-            self.assertEqual(k1, k2)
-            if "adapter" in k1:
-                self.assertTrue(torch.equal(v1, v2), k1)
+        with TemporaryDirectory() as tmpdirname:
+            model = AutoModelForSequenceClassification.from_config(self.get_model_config())
+            model.add_adapter("adapter")
+            model.add_adapter("additional_adapter")
+            model.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
+            model.set_active_adapters(Fuse("adapter", "additional_adapter"))
+            model.train_fusion(Fuse("adapter", "additional_adapter"))
+
+            training_args = TrainingArguments(
+                output_dir=tmpdirname,
+                do_train=True,
+                learning_rate=0.1,
+                logging_steps=1,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+            )
+            trainer = AdapterTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+            )
+
+            trainer.train()
+            model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
+            model_resume.add_adapter("adapter")
+            model_resume.add_adapter("additional_adapter")
+            model_resume.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
+            model_resume.set_active_adapters(Fuse("adapter", "additional_adapter"))
+            model_resume.train_fusion(Fuse("adapter", "additional_adapter"))
+            trainer_resume = AdapterTrainer(
+                model=model_resume,
+                args=TrainingArguments(do_train=True, max_steps=1, output_dir=tmpdirname),
+                train_dataset=train_dataset,
+            )
+            trainer_resume.train(resume_from_checkpoint=True)
+
+            self.assertEqual(model.config.adapters.adapters, model_resume.config.adapters.adapters)
+
+            for ((k1, v1), (k2, v2)) in zip(
+                trainer.model.to("cpu").state_dict().items(), trainer_resume.model.to("cpu").state_dict().items()
+            ):
+                self.assertEqual(k1, k2)
+                if "adapter" in k1:
+                    self.assertTrue(torch.equal(v1, v2), k1)
 
     def test_auto_set_save_adapters(self):
         model = BertForSequenceClassification(
@@ -144,15 +188,16 @@ def test_auto_set_save_adapters(self):
         model.add_adapter("adapter2")
         model.add_adapter_fusion(Fuse("adapter1", "adapter2"))
         model.train_adapter_fusion(Fuse("adapter1", "adapter2"))
 
-        training_args = TrainingArguments(
-            output_dir="./output",
-        )
-        trainer = AdapterTrainer(
-            model=model,
-            args=training_args,
-        )
-        self.assertTrue(trainer.train_adapter_fusion)
+        with TemporaryDirectory() as tmpdirname:
+            training_args = TrainingArguments(
+                output_dir=tmpdirname,
+            )
+            trainer = AdapterTrainer(
+                model=model,
+                args=training_args,
+            )
+            self.assertTrue(trainer.train_adapter_fusion)
 
     @slow
     def test_training_load_best_model_at_end_full_model(self):
@@ -167,27 +212,28 @@ def test_training_load_best_model_at_end_full_model(self):
         model.add_adapter("adapter")
         model.train_adapter("adapter")
 
-        training_args = TrainingArguments(
-            output_dir="./output",
-            do_train=True,
-            learning_rate=0.001,
-            max_steps=1,
-            save_steps=1,
-            remove_unused_columns=False,
-            load_best_model_at_end=True,
-            evaluation_strategy="epoch",
-            save_strategy="epoch",
-            num_train_epochs=2,
-        )
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-        )
-
-        trainer.train()
-        self.assertIsNotNone(trainer.model.active_adapters)
+        with TemporaryDirectory() as tmpdirname:
+            training_args = TrainingArguments(
+                output_dir=tmpdirname,
+                do_train=True,
+                learning_rate=0.001,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+                load_best_model_at_end=True,
+                evaluation_strategy="epoch",
+                save_strategy="epoch",
+                num_train_epochs=2,
+            )
+            trainer = Trainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+            )
+
+            trainer.train()
+            self.assertIsNotNone(trainer.model.active_adapters)
 
     def test_training_load_best_model_at_end_adapter(self):
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
@@ -201,25 +247,26 @@ def test_training_load_best_model_at_end_adapter(self):
         model.add_adapter("adapter")
         model.train_adapter("adapter")
 
-        training_args = TrainingArguments(
-            output_dir="./output",
-            do_train=True,
-            learning_rate=0.001,
-            max_steps=1,
-            save_steps=1,
-            remove_unused_columns=False,
-            load_best_model_at_end=True,
-            evaluation_strategy="epoch",
-            save_strategy="epoch",
-            num_train_epochs=2,
-        )
-        trainer = AdapterTrainer(
-            model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset
-        )
-        with self.assertLogs(logger) as cm:
-            trainer.train()
-            self.assertTrue(any("Loading best adapter(s) from" in line for line in cm.output))
-            self.assertEqual(Stack("adapter"), trainer.model.active_adapters)
+        with TemporaryDirectory() as tmpdirname:
+            training_args = TrainingArguments(
+                output_dir=tmpdirname,
+                do_train=True,
+                learning_rate=0.001,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+                load_best_model_at_end=True,
+                evaluation_strategy="epoch",
+                save_strategy="epoch",
+                num_train_epochs=2,
+            )
+            trainer = AdapterTrainer(
+                model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset
+            )
+            with self.assertLogs(logger) as cm:
+                trainer.train()
+                self.assertTrue(any("Loading best adapter(s) from" in line for line in cm.output))
+                self.assertEqual(Stack("adapter"), trainer.model.active_adapters)
 
     def test_training_load_best_model_at_end_fusion(self):
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
@@ -235,25 +282,26 @@ def test_training_load_best_model_at_end_fusion(self):
         model.add_adapter_fusion(Fuse("fuse_adapter_1", "fuse_adapter_2"))
         model.train_adapter_fusion(Fuse("fuse_adapter_1", "fuse_adapter_2"))
 
-        training_args = TrainingArguments(
-            output_dir="./output",
-            do_train=True,
-            learning_rate=0.001,
-            max_steps=1,
-            save_steps=1,
-            remove_unused_columns=False,
-            load_best_model_at_end=True,
-            evaluation_strategy="epoch",
-            save_strategy="epoch",
-            num_train_epochs=2,
-        )
-        trainer = AdapterTrainer(
-            model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset
-        )
-        with self.assertLogs(logger) as cm:
-            trainer.train()
-            self.assertTrue(any("Loading best adapter fusion(s) from" in line for line in cm.output))
-            self.assertEqual(Fuse("fuse_adapter_1", "fuse_adapter_2"), trainer.model.active_adapters)
+        with TemporaryDirectory() as tmpdirname:
+            training_args = TrainingArguments(
+                output_dir=tmpdirname,
+                do_train=True,
+                learning_rate=0.001,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+                load_best_model_at_end=True,
+                evaluation_strategy="epoch",
+                save_strategy="epoch",
+                num_train_epochs=2,
+            )
+            trainer = AdapterTrainer(
+                model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset
+            )
+            with self.assertLogs(logger) as cm:
+                trainer.train()
+                self.assertTrue(any("Loading best adapter fusion(s) from" in line for line in cm.output))
+                self.assertEqual(Fuse("fuse_adapter_1", "fuse_adapter_2"), trainer.model.active_adapters)
 
     def test_reloading_prediction_head(self):
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
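For reference, a minimal usage sketch of the behavior these tests exercise: resuming adapter training from the latest checkpoint under `output_dir`. The model, adapter name, hyperparameters, and dummy dataset below are illustrative stand-ins, not taken from this diff.

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
    from transformers.adapters import AdapterTrainer


    class DummyDataset(torch.utils.data.Dataset):
        """Tiny stand-in for a real tokenized dataset."""

        def __init__(self, tokenizer, size=8):
            self.enc = tokenizer(["hello world"] * size, truncation=True, padding="max_length", max_length=32)
            self.labels = [0] * size

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            item = {k: torch.tensor(v[idx]) for k, v in self.enc.items()}
            item["labels"] = torch.tensor(self.labels[idx])
            return item


    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    model.add_adapter("my_adapter")
    model.train_adapter("my_adapter")

    args = TrainingArguments(
        output_dir="checkpoints", do_train=True, max_steps=2, save_steps=1, remove_unused_columns=False
    )
    trainer = AdapterTrainer(model=model, args=args, train_dataset=DummyDataset(tokenizer))
    trainer.train()

    # Resume from the latest checkpoint in output_dir. With this fix, a path
    # that holds no valid adapter checkpoint fails with a clear Exception
    # rather than an unbound `adapter_loaded` variable.
    trainer.train(resume_from_checkpoint=True)

The tests above make the same resume call inside a TemporaryDirectory, so checkpoints from one test run cannot leak into another.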
