Skip to content

Commit

Permalink
[tune] Update Lightning examples to support PTL 1.5 (ray-project#20562)
Browse files Browse the repository at this point in the history
To helps resolve the issues users are facing with running Lightning examples with Ray Tune Lightning-AI/pytorch-lightning#10407

Co-authored-by: Amog Kamsetty <[email protected]>
  • Loading branch information
2 people authored and simonsays1980 committed Feb 27, 2022
1 parent f0050a0 commit d4d13bd
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 12 deletions.
5 changes: 3 additions & 2 deletions python/ray/tune/examples/mnist_ptl_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from filelock import FileLock
from torch.nn import functional as F
from torchmetrics import Accuracy
import pytorch_lightning as pl
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
import os
Expand All @@ -24,7 +25,7 @@ def __init__(self, config, data_dir=None):
self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
self.layer_2 = torch.nn.Linear(layer_1, layer_2)
self.layer_3 = torch.nn.Linear(layer_2, 10)
self.accuracy = pl.metrics.Accuracy()
self.accuracy = Accuracy()

def forward(self, x):
batch_size, channels, width, height = x.size()
Expand Down Expand Up @@ -75,7 +76,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0):
max_epochs=num_epochs,
# If fractional GPUs passed in, convert to int.
gpus=math.ceil(num_gpus),
progress_bar_refresh_rate=0,
enable_progress_bar=False,
callbacks=[TuneReportCallback(metrics, on="validation_end")],
)
trainer.fit(model, dm)
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/examples/mnist_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def configure_optimizers(self):

def train_mnist(config):
model = LightningMNISTClassifier(config)
trainer = pl.Trainer(max_epochs=10, show_progress_bar=False)
trainer = pl.Trainer(max_epochs=10, enable_progress_bar=False)

trainer.fit(model)
# __lightning_end__
Expand All @@ -148,7 +148,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"):
gpus=math.ceil(num_gpus),
logger=TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
progress_bar_refresh_rate=0,
enable_progress_bar=False,
callbacks=[
TuneReportCallback(
{
Expand All @@ -174,7 +174,7 @@ def train_mnist_tune_checkpoint(config,
"gpus": math.ceil(num_gpus),
"logger": TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
"progress_bar_refresh_rate": 0,
"enable_progress_bar": False,
"callbacks": [
TuneReportCheckpointCallback(
metrics={
Expand Down
4 changes: 2 additions & 2 deletions python/ray/util/ray_lightning/simple_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytorch_lightning as pl

from ray.util.ray_lightning import RayPlugin
from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_ddp_resources
from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_resources

num_cpus_per_actor = 1
num_workers = 1
Expand Down Expand Up @@ -70,7 +70,7 @@ def main():
num_samples=1,
metric="loss",
mode="min",
resources_per_trial=get_tune_ddp_resources(
resources_per_trial=get_tune_resources(
num_workers=num_workers, cpus_per_worker=num_cpus_per_actor
),
)
Expand Down
6 changes: 3 additions & 3 deletions python/ray/util/ray_lightning/tune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

TuneReportCallback = None
TuneReportCheckpointCallback = None
get_tune_ddp_resources = None
get_tune_resources = None

try:
from ray_lightning.tune import (
TuneReportCallback,
TuneReportCheckpointCallback,
get_tune_ddp_resources,
get_tune_resources,
)
except ImportError:
logger.info(
Expand All @@ -22,5 +22,5 @@
__all__ = [
"TuneReportCallback",
"TuneReportCheckpointCallback",
"get_tune_ddp_resources",
"get_tune_resources",
]
2 changes: 1 addition & 1 deletion python/requirements/ml/requirements_tune.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ nevergrad==0.4.3.post7
optuna==2.9.1
pytest-remotedata==0.3.2
lightning-bolts==0.4.0
pytorch-lightning==1.4.9
pytorch-lightning==1.5.10
shortuuid==1.0.1
scikit-learn==0.24.2
scikit-optimize==0.8.1
Expand Down
2 changes: 1 addition & 1 deletion python/requirements/ml/requirements_upstream.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Because they depend on Ray, we can't pin the subdependencies.
# So we separate its own requirements file.

ray_lightning==0.1.1
ray_lightning==0.2.0
tune-sklearn==0.4.1
xgboost_ray==0.1.4
lightgbm_ray==0.0.2
Expand Down

0 comments on commit d4d13bd

Please sign in to comment.