From bf35e3cc69e88bc14d886b73cdfd3ad6cb1ba3dc Mon Sep 17 00:00:00 2001
From: Stephanie Wang <swang@cs.berkeley.edu>
Date: Tue, 13 Sep 2022 13:58:49 -0400
Subject: [PATCH] Revert "[train/horovod] Fix horovod long running release
 test (#27179)" (#28476)

This reverts commit 4b59dfbe59a143ab8dcc505dad860b4c330b6426.

Looks like this breaks linux://python/ray/air:horovod_cifar_pbt_example

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
---
 .../horovod/workloads/horovod_tune_test.py | 34 +++++++++----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/release/air_tests/horovod/workloads/horovod_tune_test.py b/release/air_tests/horovod/workloads/horovod_tune_test.py
index f6906d6e6260..faa75c2ee1a3 100755
--- a/release/air_tests/horovod/workloads/horovod_tune_test.py
+++ b/release/air_tests/horovod/workloads/horovod_tune_test.py
@@ -4,7 +4,7 @@
 import torchvision
 from ray.air import RunConfig, session
 from ray.train.horovod import HorovodTrainer
-from ray.air.config import ScalingConfig, FailureConfig, CheckpointConfig
+from ray.air.config import ScalingConfig
 from ray.tune.tune_config import TuneConfig
 from ray.tune.tuner import Tuner
 from torch.utils.data import DataLoader
@@ -40,10 +40,9 @@ def train_loop_per_worker(config):
 
     checkpoint = session.get_checkpoint()
     if checkpoint:
-        checkpoint_dict = checkpoint.to_dict()
-        model_state = checkpoint_dict["model_state"]
-        optimizer_state = checkpoint_dict["optimizer_state"]
-        epoch = checkpoint_dict["epoch"]
+        model_state = checkpoint["model_state"]
+        optimizer_state = checkpoint["optimizer_state"]
+        epoch = checkpoint["epoch"]
 
         net.load_state_dict(model_state)
         optimizer.load_state_dict(optimizer_state)
@@ -60,6 +59,7 @@ def train_loop_per_worker(config):
     trainloader = DataLoader(
         trainset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=4
     )
+    trainloader_len = len(trainloader)
 
     for epoch in range(epoch, 40):  # loop over the dataset multiple times
         running_loss = 0.0
@@ -81,22 +81,23 @@ def train_loop_per_worker(config):
             # print statistics
             running_loss += loss.item()
             epoch_steps += 1
-
+            if i == trainloader_len - 1:
+                checkpoint = Checkpoint.from_dict(
+                    dict(
+                        model_state=net.state_dict(),
+                        optimizer_state=optimizer.state_dict(),
+                        epoch=epoch,
+                    )
+                )
+            else:
+                checkpoint = None
+            session.report(dict(loss=running_loss / epoch_steps), checkpoint=checkpoint)
             if i % 2000 == 1999:  # print every 2000 mini-batches
                 print(
                     "[%d, %5d] loss: %.3f"
                     % (epoch + 1, i + 1, running_loss / epoch_steps)
                 )
-                checkpoint = Checkpoint.from_dict(
-                    dict(
-                        model_state=net.state_dict(),
-                        optimizer_state=optimizer.state_dict(),
-                        epoch=epoch,
-                    )
-                )
-            session.report(dict(loss=running_loss / epoch_steps), checkpoint=checkpoint)
-
 
 
 if __name__ == "__main__":
     import argparse
@@ -160,10 +161,9 @@ def train_loop_per_worker(config):
         ),
         run_config=RunConfig(
             stop={"training_iteration": 1} if args.smoke_test else None,
-            failure_config=FailureConfig(fail_fast=False),
-            checkpoint_config=CheckpointConfig(num_to_keep=1),
             callbacks=[ProgressCallback()],
         ),
+        _tuner_kwargs={"fail_fast": False, "keep_checkpoints_num": 1},
     )
     result_grid = tuner.fit()
 
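
Note: the hunks above restore a dict-based checkpoint flow, where a Checkpoint built with
Checkpoint.from_dict is attached to session.report and later retrieved with
session.get_checkpoint. A minimal, self-contained sketch of that pattern follows, assuming
Ray 2.x AIR APIs; the tiny linear model and the config keys "lr" and "epochs" are
placeholders rather than the CIFAR setup in the test, and the sketch reads the state back
with Checkpoint.to_dict, whereas the restored script indexes the checkpoint directly.

    # Sketch of dict-checkpoint save/restore with Ray AIR (Ray 2.x assumed).
    import torch
    from ray.air import session
    from ray.air.checkpoint import Checkpoint


    def train_loop_per_worker(config):
        # Placeholder model/optimizer; the release test uses a CIFAR ConvNet.
        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
        start_epoch = 0

        # Resume from the last reported checkpoint, if any.
        checkpoint = session.get_checkpoint()
        if checkpoint:
            state = checkpoint.to_dict()
            model.load_state_dict(state["model_state"])
            optimizer.load_state_dict(state["optimizer_state"])
            start_epoch = state["epoch"] + 1

        for epoch in range(start_epoch, config["epochs"]):
            loss = model(torch.randn(8, 4)).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Report metrics and attach a dict checkpoint for this epoch.
            session.report(
                dict(loss=loss.item()),
                checkpoint=Checkpoint.from_dict(
                    dict(
                        model_state=model.state_dict(),
                        optimizer_state=optimizer.state_dict(),
                        epoch=epoch,
                    )
                ),
            )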