Skip to content

Commit

Permalink
Adds a FlyteCallback (#23759)
Browse files Browse the repository at this point in the history
* initial flyte callback

* lint

* logs should still be saved to Flyte even if pandas isn't install (unlikely)

* cr - flyte team

* add docs for Flytecallback

* fix doc string - cr sgugger

* Apply suggestions from code review

cr - sgugger fix doc strings

Co-authored-by: Sylvain Gugger <[email protected]>

---------

Co-authored-by: Sylvain Gugger <[email protected]>
  • Loading branch information
peridotml and sgugger authored May 30, 2023
1 parent 8673166 commit 62ba64b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docs/source/en/main_classes/callback.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ By default a [`Trainer`] will use the following callbacks:
installed.
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.

The main class that implements callbacks is [`TrainerCallback`]. It gets the
[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
Expand Down Expand Up @@ -79,6 +80,8 @@ Here is the list of the available [`TrainerCallback`] in the library:

[[autodoc]] integrations.DagsHubCallback

[[autodoc]] integrations.FlyteCallback

## TrainerCallback

[[autodoc]] TrainerCallback
Expand Down
76 changes: 75 additions & 1 deletion src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import numpy as np

from . import __version__ as version
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
from .utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
from .utils.versions import importlib_metadata


Expand Down Expand Up @@ -146,6 +146,16 @@ def is_codecarbon_available():
return importlib.util.find_spec("codecarbon") is not None


def is_flytekit_available():
return importlib.util.find_spec("flytekit") is not None


def is_flyte_deck_standard_available():
if not is_flytekit_available():
return False
return importlib.util.find_spec("flytekitplugins.deck") is not None


def hp_params(trial):
if is_optuna_available():
import optuna
Expand Down Expand Up @@ -1537,6 +1547,69 @@ def on_save(self, args, state, control, **kwargs):
self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)


class FlyteCallback(TrainerCallback):
"""A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
NOTE: This callback only works within a Flyte task.
Args:
save_log_history (`bool`, *optional*, defaults to `True`):
When set to True, the training logs are saved as a Flyte Deck.
sync_checkpoints (`bool`, *optional*, defaults to `True`):
When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
interruption.
Example:
```python
# Note: This example skips over some setup steps for brevity.
from flytekit import current_context, task
@task
def train_hf_transformer():
cp = current_context().checkpoint
trainer = Trainer(..., callbacks=[FlyteCallback()])
output = trainer.train(resume_from_checkpoint=cp.restore())
```
"""

def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
super().__init__()
if not is_flytekit_available():
raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")

if not is_flyte_deck_standard_available() or not is_pandas_available():
logger.warning(
"Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
"Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
)
save_log_history = False

from flytekit import current_context

self.cp = current_context().checkpoint
self.save_log_history = save_log_history
self.sync_checkpoints = sync_checkpoints

def on_save(self, args, state, control, **kwargs):
if self.sync_checkpoints and state.is_world_process_zero:
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)

logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
self.cp.save(artifact_path)

def on_train_end(self, args, state, control, **kwargs):
if self.save_log_history:
import pandas as pd
from flytekit import Deck
from flytekitplugins.deck.renderer import TableRenderer

log_history_df = pd.DataFrame(state.log_history)
Deck("Log History", TableRenderer().to_html(log_history_df))


INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
Expand All @@ -1547,6 +1620,7 @@ def on_save(self, args, state, control, **kwargs):
"codecarbon": CodeCarbonCallback,
"clearml": ClearMLCallback,
"dagshub": DagsHubCallback,
"flyte": FlyteCallback,
}


Expand Down

0 comments on commit 62ba64b

Please sign in to comment.