diff --git a/release/nightly_tests/dataset/ray_sgd_training.py b/release/nightly_tests/dataset/ray_sgd_training.py
index b75b927c0e4a..67c97c391dfc 100644
--- a/release/nightly_tests/dataset/ray_sgd_training.py
+++ b/release/nightly_tests/dataset/ray_sgd_training.py
@@ -18,33 +18,8 @@
 from ray import train
 from ray.data.aggregate import Mean, Std
 from ray.data.dataset_pipeline import DatasetPipeline
-from ray.train import Trainer, TrainingCallback
-from ray.train.callbacks import TBXLoggerCallback
-
-
-# TODO(amogkam): Upstream this into Ray Train.
-class MLflowCallback(TrainingCallback):
-    def __init__(self, config):
-        self.config = config
-
-    def handle_result(self, results, **info):
-        # For each result that's being reported by ``train.report()``,
-        # we get the result from the rank 0 worker (i.e. first worker) and
-        # report it to MLflow.
-        rank_zero_results = results[0]
-        mlflow.log_metrics(rank_zero_results)
-
-    # TODO: fix type hint for logdir
-    def start_training(self, logdir, **info):
-        mlflow.start_run(run_name=str(logdir.name))
-        mlflow.log_params(config)
-
-        # TODO: Update TrainCallback to provide logdir in finish_training.
-        self.logdir = logdir
-
-    def finish_training(self, error: bool = False, **info):
-        # Save the Trainer checkpoints as artifacts to mlflow.
-        mlflow.log_artifacts(self.logdir)
+from ray.train import Trainer
+from ray.train.callbacks import MLflowLoggerCallback, TBXLoggerCallback


 def read_dataset(path: str) -> ray.data.Dataset:
@@ -593,7 +568,8 @@ def train(self):
     os.makedirs(tbx_runs_dir, exist_ok=True)
     callbacks = [
         TBXLoggerCallback(logdir=tbx_runs_dir),
-        MLflowCallback(config)
+        MLflowLoggerCallback(
+            experiment_name="cuj-big-data-training", save_artifact=True)
     ]

     # Remove CPU resource so Datasets can be scheduled.
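
Below is a minimal usage sketch (not part of the diff) of the built-in MLflowLoggerCallback that replaces the hand-rolled MLflowCallback removed above. It assumes the legacy Ray Train API this file targets (ray.train.Trainer); the train_func body, backend choice, and num_workers value are illustrative assumptions, not taken from the source file.

from ray import train
from ray.train import Trainer
from ray.train.callbacks import MLflowLoggerCallback


def train_func():
    # Hypothetical training loop for illustration only.
    for epoch in range(3):
        # Each train.report() call is delivered to the attached callbacks;
        # MLflowLoggerCallback records the reported results as MLflow metrics.
        train.report(epoch=epoch, loss=1.0 / (epoch + 1))


trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(
    train_func,
    callbacks=[
        MLflowLoggerCallback(
            experiment_name="cuj-big-data-training",
            # save_artifact=True also uploads the run's log directory
            # (e.g. checkpoints) to MLflow as artifacts after training,
            # covering what the removed finish_training() hook did by hand.
            save_artifact=True,
        )
    ],
)
trainer.shutdown()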