Skip to content

Commit

Permalink
[CI/air] Fix lightning_gpu_tune_.* release test (ray-project#35193)
Browse files Browse the repository at this point in the history
Temporarily fix the release test failures described in ray-project#35187. TODO: Come up with a holistic solution for metric-dict flattening.



Signed-off-by: woshiyyya <[email protected]>
  • Loading branch information
woshiyyya authored and architkulkarni committed May 16, 2023
1 parent c8eb82f commit b16eec4
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
7 changes: 5 additions & 2 deletions release/lightning_tests/workloads/lightning_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ def validation_step(self, val_batch, batch_idx):
def validation_epoch_end(self, outputs):
    """Aggregate per-batch validation outputs and log epoch-level metrics.

    Averages the ``val_loss`` / ``val_accuracy`` tensors collected from each
    ``validation_step`` and logs them once per epoch. ``sync_dist=True``
    reduces the values across distributed workers before logging.

    NOTE(review): the scraped diff showed both the old ``ptl/``-prefixed log
    calls and their replacements; per the commit (ray-project#35193) only the
    un-prefixed keys are kept, matching the monitors/metrics used in
    test_trainer.py and test_tuner.py.
    """
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()

    # TODO(yunxuanx): change this back to ptl/val_loss after
    # we resolved the metric unpacking issue
    self.log("val_loss", avg_loss, sync_dist=True)
    self.log("val_accuracy", avg_acc, sync_dist=True)

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
Expand Down
4 changes: 2 additions & 2 deletions release/lightning_tests/workloads/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger=CSVLogger("logs", name="my_exp_name"),
)
.fit_params(datamodule=MNISTDataModule(batch_size=128))
.checkpointing(monitor="ptl/val_accuracy", mode="max", save_last=True)
.checkpointing(monitor="val_accuracy", mode="max", save_last=True)
.build()
)

Expand All @@ -41,7 +41,7 @@
taken = time.time() - start
result = {
"time_taken": taken,
"ptl/val_accuracy": result.metrics["ptl/val_accuracy"],
"val_accuracy": result.metrics["val_accuracy"],
}
test_output_json = os.environ.get(
"TEST_OUTPUT_JSON", "/tmp/lightning_trainer_test.json"
Expand Down
10 changes: 5 additions & 5 deletions release/lightning_tests/workloads/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
logger=CSVLogger("logs", name="my_exp_name"),
)
.fit_params(datamodule=MNISTDataModule(batch_size=200))
.checkpointing(monitor="ptl/val_accuracy", mode="max")
.checkpointing(monitor="val_accuracy", mode="max")
.build()
)

Expand Down Expand Up @@ -57,12 +57,12 @@
verbose=2,
checkpoint_config=CheckpointConfig(
num_to_keep=2,
checkpoint_score_attribute="ptl/val_accuracy",
checkpoint_score_attribute="val_accuracy",
checkpoint_score_order="max",
),
),
tune_config=tune.TuneConfig(
metric="ptl/val_accuracy",
metric="val_accuracy",
mode="max",
num_samples=2,
scheduler=PopulationBasedTraining(
Expand All @@ -73,7 +73,7 @@
),
)
results = tuner.fit()
best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")
best_result = results.get_best_result(metric="val_accuracy", mode="max")
best_result

assert len(results.errors) == 0
Expand All @@ -83,7 +83,7 @@
# Report experiment results
result = {
"time_taken": taken,
"ptl/val_accuracy": best_result.metrics["ptl/val_accuracy"],
"val_accuracy": best_result.metrics["val_accuracy"],
}

test_output_json = os.environ.get(
Expand Down

0 comments on commit b16eec4

Please sign in to comment.