Skip to content

Commit

Permalink
[train] fix tensorflow example by using ScalingConfig (#46565)
Browse files Browse the repository at this point in the history
Signed-off-by: MJ <[email protected]>
  • Loading branch information
mjovanovic9999 authored Jul 14, 2024
1 parent 34d5be8 commit 4e75921
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ray.air.integrations.keras import ReportCheckpointCallback
from ray.data.datasource import SimpleTensorFlowDatasource
from ray.data.extensions import TensorArray
from ray.train import Result
from ray.train import Result, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard


Expand Down Expand Up @@ -121,7 +121,7 @@ def train_tensorflow_mnist(
) -> Result:
train_dataset = get_dataset(split_type="train")
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
train_loop_config=config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ray
from ray import tune
from ray.train import ScalingConfig
from ray.train.examples.tf.tensorflow_mnist_example import train_func
from ray.train.tensorflow import TensorflowTrainer
from ray.tune.tune_config import TuneConfig
Expand All @@ -11,7 +12,7 @@
def tune_tensorflow_mnist(
num_workers: int = 2, num_samples: int = 2, use_gpu: bool = False
):
scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
Expand Down

0 comments on commit 4e75921

Please sign in to comment.