Skip to content

Commit

Permalink
[BigDataTraining] Fix test script introduced by API change (#21347)
Browse files (browse the repository at this point in the history)
* fix

* fix test failure

* Update release/nightly_tests/dataset/ray_sgd_training.py

Co-authored-by: matthewdeng <[email protected]>
  • Loading branch information
scv119 and matthewdeng authored Jan 3, 2022
1 parent 4581baa commit 704404d
Showing 1 changed file with 4 additions and 28 deletions.
32 changes: 4 additions & 28 deletions release/nightly_tests/dataset/ray_sgd_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,8 @@
from ray import train
from ray.data.aggregate import Mean, Std
from ray.data.dataset_pipeline import DatasetPipeline
from ray.train import Trainer, TrainingCallback
from ray.train.callbacks import TBXLoggerCallback


# TODO(amogkam): Upstream this into Ray Train.
class MLflowCallback(TrainingCallback):
    """Training callback that forwards Ray Train results to MLflow.

    When training starts it opens an MLflow run and logs the stored config
    as run params; for every reported result it logs the first worker's
    metrics; when training finishes it logs the contents of the training
    log directory as run artifacts.
    """

    def __init__(self, config):
        """Remember the config that is logged as MLflow params.

        Args:
            config: Passed unchanged to ``mlflow.log_params`` in
                ``start_training``.
        """
        self.config = config

    def handle_result(self, results, **info):
        """Log the first worker's reported result as MLflow metrics.

        Args:
            results: Indexable collection of per-worker results; only
                ``results[0]`` is used.
            **info: Extra callback information; unused here.
        """
        # For each result that's being reported by ``train.report()``,
        # we get the result from the rank 0 worker (i.e. first worker) and
        # report it to MLflow.
        rank_zero_results = results[0]
        mlflow.log_metrics(rank_zero_results)

    # TODO: fix type hint for logdir
    def start_training(self, logdir, **info):
        """Start an MLflow run named after ``logdir`` and log the config.

        Args:
            logdir: Training log directory. Must expose a ``.name``
                attribute (presumably a ``pathlib.Path`` -- confirm
                against the caller).
            **info: Extra callback information; unused here.
        """
        mlflow.start_run(run_name=str(logdir.name))
        # Log the config captured in __init__. A bare ``config`` name is not
        # defined in this scope and would raise NameError (or pick up an
        # unrelated module-level variable).
        mlflow.log_params(self.config)

        # TODO: Update TrainCallback to provide logdir in finish_training.
        self.logdir = logdir

    def finish_training(self, error: bool = False, **info):
        """Upload the training log directory to MLflow as artifacts.

        Args:
            error: Accepted for interface compatibility; not consulted.
            **info: Extra callback information; unused here.
        """
        # Save the Trainer checkpoints as artifacts to mlflow.
        mlflow.log_artifacts(self.logdir)
from ray.train import Trainer
from ray.train.callbacks import MLflowLoggerCallback, TBXLoggerCallback


def read_dataset(path: str) -> ray.data.Dataset:
Expand Down Expand Up @@ -593,7 +568,8 @@ def train(self):
os.makedirs(tbx_runs_dir, exist_ok=True)
callbacks = [
TBXLoggerCallback(logdir=tbx_runs_dir),
MLflowCallback(config)
MLflowLoggerCallback(
experiment_name="cuj-big-data-training", save_artifact=True)
]

# Remove CPU resource so Datasets can be scheduled.
Expand Down

0 comments on commit 704404d

Please sign in to comment.