Move loading best adapter to the trainer class (#487)
MaBeHen authored Feb 9, 2023
1 parent 2af89bd commit ee5f02b
Showing 1 changed file with 22 additions and 21 deletions.
src/transformers/adapters/trainer.py (43 changes: 22 additions & 21 deletions)
@@ -217,6 +217,28 @@ def _load_heads(self, resume_from_checkpoint):
             ):
                 self.model.load_head(os.path.join(resume_from_checkpoint, file_name))
 
+    def _load_best_model(self):
+        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        logger.info(
+            f"Loading best adapter(s) from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
+        )
+        # attempt to re-load all adapters from checkpoint
+        for adapter in model.config.adapters.adapters:
+            adapter_dir = os.path.join(self.state.best_model_checkpoint, adapter)
+            if os.path.exists(adapter_dir):
+                model.load_adapter(adapter_dir)
+        if self.train_adapter_fusion:
+            logger.info(
+                f"Loading best adapter fusion(s) from {self.state.best_model_checkpoint} (score:"
+                f" {self.state.best_metric})."
+            )
+            # attempt to re-load all adapter fusions from checkpoint
+            for fusion in model.config.adapters.fusions:
+                fusion_dir = os.path.join(self.state.best_model_checkpoint, fusion)
+                if os.path.exists(fusion_dir):
+                    model.load_adapter_fusion(fusion_dir)
+        model.to(self.args.device)
+
 
 class AdapterTrainerCallback(TrainerCallback):
     def __init__(self, trainer):
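A note on why this override works (based on the upstream `transformers` Trainer around the 4.26 releases this fork tracks): `Trainer.train()` calls `self._load_best_model()` at the end of the training loop whenever `load_best_model_at_end=True`. Defining the method on the trainer therefore restores adapters at the same point where the upstream trainer restores full model weights, instead of in an `on_train_end` callback that fires only after the training loop has fully completed and must receive the model through `kwargs`. The `is_sagemaker_mp_enabled()` check mirrors the upstream method, which operates on `self.model_wrapped` under SageMaker model parallelism.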
@@ -232,27 +254,6 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
                 " method"
             )
 
-    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        model = kwargs.pop("model")
-        if args.load_best_model_at_end and state.best_model_checkpoint is not None:
-
-            logger.info(f"Loading best adapter(s) from {state.best_model_checkpoint} (score: {state.best_metric}).")
-            # attempt to re-load all adapters from checkpoint
-            for adapter in model.config.adapters.adapters:
-                adapter_dir = os.path.join(state.best_model_checkpoint, adapter)
-                if os.path.exists(adapter_dir):
-                    model.load_adapter(adapter_dir)
-            if self.trainer.train_adapter_fusion:
-                logger.info(
-                    f"Loading best adapter fusion(s) from {state.best_model_checkpoint} (score: {state.best_metric})."
-                )
-                # attempt to re-load all adapter fusions from checkpoint
-                for fusion in model.config.adapters.fusions:
-                    fusion_dir = os.path.join(state.best_model_checkpoint, fusion)
-                    if os.path.exists(fusion_dir):
-                        model.load_adapter_fusion(fusion_dir)
-            model.to(args.device)
-
     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         # apply adapter fusion weight regularization on the value matrix
         model = kwargs.pop("model")
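After this change, no custom callback logic is needed to get the best adapter weights back at the end of training: the standard `load_best_model_at_end` flow handles it. A minimal usage sketch follows; the model name, adapter name, and toy dataset are illustrative placeholders, and import paths may differ slightly across adapter-transformers versions.

# Sketch only: names and data below are illustrative, not from this commit.
from datasets import Dataset
from transformers import AutoTokenizer, TrainingArguments
from transformers.adapters import AdapterTrainer, AutoAdapterModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
raw = Dataset.from_dict({"text": ["good movie", "bad movie"] * 8, "label": [1, 0] * 8})
dataset = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length", max_length=32),
    batched=True,
)

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("sentiment")                             # adapter name is arbitrary
model.add_classification_head("sentiment", num_labels=2)
model.train_adapter("sentiment")                           # freeze base model, train only the adapter

args = TrainingArguments(
    output_dir="out",
    num_train_epochs=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,                           # Trainer.train() will call
    metric_for_best_model="eval_loss",                     # AdapterTrainer._load_best_model()
)

trainer = AdapterTrainer(model=model, args=args, train_dataset=dataset, eval_dataset=dataset)
trainer.train()  # best adapter checkpoint is restored from state.best_model_checkpoint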
