-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AIR] Added Ray Logging to MosaicTrainer (#29620)
Added RayLogger to MosaicTrainer to relay all reported information. RayLogger is a subclass of LoggerDestination, just like all other native composer loggers. The information to be logged is given via log_metrics call, which is saved in the RayLogger object. The logger reports the logged information every batch checkpoint and epoch checkpoint. All other composer loggers besides RayLogger loggers are removed from the trainer. Note that because at the moment, the result metrics_dataframe will only include the keys that are reported in the very first report call, to have metrics that are not reported every batch in the final metrics dataframe, the keys should be passed in via 'log_keys' in the trainer_init_config. Co-authored-by: Amog Kamsetty <[email protected]> Signed-off-by: ilee300a <[email protected]>
- Loading branch information
Showing
5 changed files
with
210 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import Any, Dict, Optional, List | ||
import torch | ||
|
||
from composer.loggers import Logger | ||
from composer.loggers.logger_destination import LoggerDestination | ||
from composer.core.state import State | ||
|
||
from ray.air import session | ||
|
||
|
||
class RayLogger(LoggerDestination): | ||
"""A logger to relay information logged by composer models to ray. | ||
This logger allows utilizing all necessary logging and logged data handling provided | ||
by the Composer library. All the logged information is saved in the data dictionary | ||
every time a new information is logged, but to reduce unnecessary reporting, the | ||
most up-to-date logged information is reported as metrics every batch checkpoint and | ||
epoch checkpoint (see Composer's Event module for more details). | ||
Because ray's metric dataframe will not include new keys that is reported after the | ||
very first report call, any logged information with the keys not included in the | ||
first batch checkpoint would not be retrievable after training. In other words, if | ||
the log level is greater than `LogLevel.BATCH` for some data, they would not be | ||
present in `Result.metrics_dataframe`. To allow preserving those information, the | ||
user can provide keys to be always included in the reported data by using `keys` | ||
argument in the constructor. For `MosaicTrainer`, use | ||
`trainer_init_config['log_keys']` to populate these keys. | ||
Note that in the Event callback functions, we remove unused variables, as this is | ||
practiced in Mosaic's composer library. | ||
Args: | ||
keys: the key values that will be included in the reported metrics. | ||
""" | ||
|
||
def __init__(self, keys: Optional[List[str]] = None) -> None: | ||
self.data = {} | ||
# report at fit end only if there are additional training batches run after the | ||
# last epoch checkpoint report | ||
self.should_report_fit_end = False | ||
if keys: | ||
for key in keys: | ||
self.data[key] = None | ||
|
||
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: | ||
self.data.update(metrics.items()) | ||
for key, val in self.data.items(): | ||
if isinstance(val, torch.Tensor): | ||
self.data[key] = val.item() | ||
|
||
def batch_checkpoint(self, state: State, logger: Logger) -> None: | ||
del logger # unused | ||
self.should_report_fit_end = True | ||
|
||
def epoch_checkpoint(self, state: State, logger: Logger) -> None: | ||
del logger # unused | ||
self.should_report_fit_end = False | ||
session.report(self.data) | ||
|
||
# flush the data | ||
self.data = {} | ||
|
||
def fit_end(self, state: State, logger: Logger) -> None: | ||
# report at close in case the trainer stops in the middle of an epoch. | ||
# this may be double counted with epoch checkpoint. | ||
del logger # unused | ||
if self.should_report_fit_end: | ||
session.report(self.data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters