Commit bf35e3c

Revert "[train/horovod] Fix horovod long running release test (#27179)" (#28476)

This reverts commit 4b59dfb.

Looks like this breaks linux://python/ray/air:horovod_cifar_pbt_example

Signed-off-by: Stephanie Wang <[email protected]>
stephanie-wang authored Sep 13, 2022
1 parent 63ce640 commit bf35e3c
Showing 1 changed file with 17 additions and 17 deletions.
release/air_tests/horovod/workloads/horovod_tune_test.py (17 additions, 17 deletions)
@@ -4,7 +4,7 @@
 import torchvision
 from ray.air import RunConfig, session
 from ray.train.horovod import HorovodTrainer
-from ray.air.config import ScalingConfig, FailureConfig, CheckpointConfig
+from ray.air.config import ScalingConfig
 from ray.tune.tune_config import TuneConfig
 from ray.tune.tuner import Tuner
 from torch.utils.data import DataLoader
@@ -40,10 +40,9 @@ def train_loop_per_worker(config):

     checkpoint = session.get_checkpoint()
     if checkpoint:
-        checkpoint_dict = checkpoint.to_dict()
-        model_state = checkpoint_dict["model_state"]
-        optimizer_state = checkpoint_dict["optimizer_state"]
-        epoch = checkpoint_dict["epoch"]
+        model_state = checkpoint["model_state"]
+        optimizer_state = checkpoint["optimizer_state"]
+        epoch = checkpoint["epoch"]

         net.load_state_dict(model_state)
         optimizer.load_state_dict(optimizer_state)
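
A note on the hunk above: the revert restores dict-style field access on the object returned by session.get_checkpoint(), while the reverted fix read fields through Checkpoint.to_dict(). A minimal sketch of that from_dict()/to_dict() round trip, assuming the Ray 2.0-era ray.air.checkpoint API:

from ray.air.checkpoint import Checkpoint

# Pack training state into an AIR checkpoint, then read it back as a
# plain dict, as the reverted fix did. The payload here is dummy data.
ckpt = Checkpoint.from_dict(dict(model_state={}, optimizer_state={}, epoch=3))
state = ckpt.to_dict()
assert state["epoch"] == 3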
@@ -60,6 +59,7 @@ def train_loop_per_worker(config):
     trainloader = DataLoader(
         trainset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=4
     )
+    trainloader_len = len(trainloader)

     for epoch in range(epoch, 40):  # loop over the dataset multiple times
         running_loss = 0.0
@@ -81,22 +81,23 @@ def train_loop_per_worker(config):
             # print statistics
             running_loss += loss.item()
             epoch_steps += 1
-
+            if i == trainloader_len - 1:
+                checkpoint = Checkpoint.from_dict(
+                    dict(
+                        model_state=net.state_dict(),
+                        optimizer_state=optimizer.state_dict(),
+                        epoch=epoch,
+                    )
+                )
+            else:
+                checkpoint = None
+            session.report(dict(loss=running_loss / epoch_steps), checkpoint=checkpoint)
             if i % 2000 == 1999:  # print every 2000 mini-batches
                 print(
                     "[%d, %5d] loss: %.3f"
                     % (epoch + 1, i + 1, running_loss / epoch_steps)
                 )
-
-        checkpoint = Checkpoint.from_dict(
-            dict(
-                model_state=net.state_dict(),
-                optimizer_state=optimizer.state_dict(),
-                epoch=epoch,
-            )
-        )
-        session.report(dict(loss=running_loss / epoch_steps), checkpoint=checkpoint)


 if __name__ == "__main__":
     import argparse
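
For context on the hunk above: the restored loop calls session.report() on every batch and attaches a checkpoint only on the epoch's last batch (i == trainloader_len - 1), whereas the reverted fix reported once per epoch after the inner loop finished. A stripped-down sketch of the restored per-batch pattern; the helper name report_batch is hypothetical, and it assumes it is called from inside an AIR training loop:

from ray.air import session
from ray.air.checkpoint import Checkpoint

def report_batch(i, num_batches, mean_loss, state):
    # Attach a checkpoint only on the final batch of the epoch; metrics
    # go out on every batch, so each worker makes the same number of
    # report() calls per epoch.
    checkpoint = Checkpoint.from_dict(state) if i == num_batches - 1 else None
    session.report(dict(loss=mean_loss), checkpoint=checkpoint)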
@@ -160,10 +161,9 @@ def train_loop_per_worker(config):
         ),
         run_config=RunConfig(
             stop={"training_iteration": 1} if args.smoke_test else None,
-            failure_config=FailureConfig(fail_fast=False),
-            checkpoint_config=CheckpointConfig(num_to_keep=1),
             callbacks=[ProgressCallback()],
         ),
+        _tuner_kwargs={"fail_fast": False, "keep_checkpoints_num": 1},
     )

     result_grid = tuner.fit()
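
The final hunk trades the structured AIR config objects for the legacy private _tuner_kwargs passthrough; both spellings request the same behavior (do not fail fast, retain a single checkpoint). A side-by-side sketch of the two forms, assuming the Ray 2.0-era APIs imported in the first hunk:

from ray.air.config import RunConfig, FailureConfig, CheckpointConfig

# Reverted fix: first-class config objects on RunConfig.
run_config = RunConfig(
    failure_config=FailureConfig(fail_fast=False),
    checkpoint_config=CheckpointConfig(num_to_keep=1),
)

# Restored code: the same intent expressed as legacy Tune kwargs,
# forwarded through Tuner's private _tuner_kwargs hook.
tuner_kwargs = {"fail_fast": False, "keep_checkpoints_num": 1}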
