[train+data] Remove preprocessor argument from trainers #43146

Merged: 9 commits merged on Feb 14, 2024
6 changes: 0 additions & 6 deletions python/ray/air/tests/test_api.py
@@ -3,7 +3,6 @@
import ray
from ray.train import Checkpoint, CheckpointConfig, ScalingConfig
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
from ray.data.preprocessor import Preprocessor
from ray.train.trainer import BaseTrainer


@@ -201,11 +200,6 @@ def test_datasets():
DummyTrainer(datasets={"test": DummyDataset()})


def test_preprocessor_deprecated():
with pytest.raises(DeprecationWarning):
DummyTrainer(preprocessor=Preprocessor())


def test_resume_from_checkpoint(tmpdir):
with pytest.raises(ValueError):
DummyTrainer(resume_from_checkpoint="invalid")
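Since the keyword is removed from the constructor signature rather than deprecated, the old warning path no longer exists. A minimal sketch of how the removed behavior surfaces now, reusing the `DummyTrainer` from this test module (the test name is illustrative):

```python
import pytest

def test_preprocessor_removed():
    # The keyword is gone from the signature entirely, so passing it
    # now fails as an unexpected keyword argument rather than raising
    # a DeprecationWarning from __init__.
    with pytest.raises(TypeError):
        DummyTrainer(preprocessor=object())
```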
10 changes: 0 additions & 10 deletions python/ray/train/_internal/data_config.py
@@ -5,7 +5,6 @@
from ray.actor import ActorHandle
from ray.data import DataIterator, Dataset, ExecutionOptions, NodeIdStr
from ray.data._internal.execution.interfaces.execution_options import ExecutionResources
from ray.data.preprocessor import Preprocessor
from ray.util.annotations import DeveloperAPI, PublicAPI


@@ -132,12 +131,3 @@ def default_ingest_options() -> ExecutionOptions:
preserve_order=ctx.execution_options.preserve_order,
verbose_progress=ctx.execution_options.verbose_progress,
)

def _legacy_preprocessing(
self, datasets: Dict[str, Dataset], preprocessor: Optional[Preprocessor]
) -> Dict[str, Dataset]:
"""Legacy hook for backwards compatibility.

This will be removed in the future.
"""
return datasets # No-op for non-legacy configs.
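With this hook gone, preprocessing moves out of the trainer and into Ray Data, applied before the datasets are passed in. A minimal migration sketch, assuming a `StandardScaler` preprocessor and the XGBoost trainer; the columns, parameters, and dataset are illustrative:

```python
import ray
from ray.data.preprocessors import StandardScaler
from ray.train import ScalingConfig
from ray.train.xgboost import XGBoostTrainer

# Illustrative dataset with a numeric feature column and a label column.
train_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(100)])

# Fit and apply the preprocessor up front with Ray Data,
# instead of passing it to the trainer.
scaler = StandardScaler(columns=["x"])
train_ds = scaler.fit_transform(train_ds)

trainer = XGBoostTrainer(
    label_column="y",
    params={"objective": "binary:logistic"},
    datasets={"train": train_ds},  # already preprocessed
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```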
75 changes: 12 additions & 63 deletions python/ray/train/base_trainer.py
@@ -21,13 +21,11 @@
from ray.train import Checkpoint
from ray.train._internal.session import _get_session
from ray.train._internal.storage import _exists_at_fs_path, get_fs_and_path
from ray.train.constants import TRAIN_DATASET_KEY
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
from ray.data import Dataset
from ray.data.preprocessor import Preprocessor
from ray.tune import Trainable

_TRAINER_PKL = "trainer.pkl"
@@ -89,9 +87,7 @@ class BaseTrainer(abc.ABC):
called in sequence on the remote actor.
- ``trainer.setup()``: Any heavyweight Trainer setup should be
specified here.
- ``trainer.preprocess_datasets()``: The datasets passed to the Trainer will be
setup here.
- ``trainer.train_loop()``: Executes the main training logic.
- ``trainer.training_loop()``: Executes the main training logic.
- Calling ``trainer.fit()`` will return a ``ray.result.Result``
object where you can access metrics from your training run, as well
as any checkpoints that may have been saved.
@@ -191,16 +187,13 @@ def __init__(
datasets: Optional[Dict[str, GenDataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
# Deprecated.
preprocessor: Optional["Preprocessor"] = None,
):
self.scaling_config = (
scaling_config if scaling_config is not None else ScalingConfig()
)
self.run_config = run_config if run_config is not None else RunConfig()
self.metadata = metadata
self.datasets = datasets if datasets is not None else {}
self.preprocessor = preprocessor
self.starting_checkpoint = resume_from_checkpoint

# These attributes should only be set through `BaseTrainer.restore`
@@ -211,17 +204,13 @@

air_usage.tag_air_trainer(self)

if preprocessor is not None:
raise DeprecationWarning(PREPROCESSOR_DEPRECATION_MESSAGE)

@PublicAPI(stability="alpha")
@classmethod
def restore(
cls: Type["BaseTrainer"],
path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
**kwargs,
) -> "BaseTrainer":
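With `preprocessor` dropped from `restore`, only the datasets (and optionally the scaling config) are re-specified when resuming. A minimal sketch of the updated call, assuming an interrupted run saved under an illustrative path:

```python
import ray
from ray.train.torch import TorchTrainer  # any BaseTrainer subclass restores the same way

trainer = TorchTrainer.restore(
    path="~/ray_results/my_run",               # illustrative path to the saved run
    datasets={"train": ray.data.range(1000)},  # datasets are not saved, so pass them again
)
result = trainer.fit()
```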
@@ -351,10 +340,6 @@ def training_loop(self):
)
param_dict["datasets"] = datasets

# If no preprocessor is re-specified, then it will be set to None
# here and loaded from the latest checkpoint
param_dict["preprocessor"] = preprocessor

if scaling_config:
param_dict["scaling_config"] = scaling_config

@@ -470,15 +455,6 @@ def _validate_attributes(self):
f"{self.metadata}: {e}"
)

# Preprocessor
if self.preprocessor is not None and not isinstance(
self.preprocessor, ray.data.Preprocessor
):
raise ValueError(
f"`preprocessor` should be an instance of `ray.data.Preprocessor`, "
f"found {type(self.preprocessor)} with value `{self.preprocessor}`."
)

if self.starting_checkpoint is not None and not isinstance(
self.starting_checkpoint, Checkpoint
):
@@ -511,50 +487,19 @@ def setup(self) -> None:
pass

def preprocess_datasets(self) -> None:
"""Called during fit() to preprocess dataset attributes with preprocessor.

.. note:: This method is run on a remote process.

This method is called prior to entering the training_loop.

If the ``Trainer`` has both a datasets dict and
a preprocessor, the datasets dict contains a training dataset (denoted by
the "train" key), and the preprocessor has not yet
been fit, then it will be fit on the train dataset.

Then, all Trainer's datasets will be transformed by the preprocessor.

The transformed datasets will be set back in the ``self.datasets`` attribute
of the Trainer to be used when overriding ``training_loop``.
"""
# Evaluate all datasets.
self.datasets = {k: d() if callable(d) else d for k, d in self.datasets.items()}

if self.preprocessor:
train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
if train_dataset and self.preprocessor.fit_status() in (
ray.data.Preprocessor.FitStatus.NOT_FITTED,
ray.data.Preprocessor.FitStatus.PARTIALLY_FITTED,
):
self.preprocessor.fit(train_dataset)

# Execute dataset transformations serially for now.
# Cannot execute them in remote tasks due to dataset ownership model:
# if datasets are created on a remote node, then if that node fails,
# we cannot recover the dataset.
new_datasets = {}
for key, dataset in self.datasets.items():
new_datasets[key] = self.preprocessor.transform(dataset)

self.datasets = new_datasets
"""Deprecated."""
raise DeprecationWarning(
"`preprocess_datasets` is no longer used, since preprocessors "
f"are no longer accepted by Trainers.\n{PREPROCESSOR_DEPRECATION_MESSAGE}"
)

@abc.abstractmethod
def training_loop(self) -> None:
"""Loop called by fit() to run training and report results to Tune.

.. note:: This method runs on a remote process.

``self.datasets`` have already been preprocessed by ``self.preprocessor``.
``self.datasets`` have already been evaluated if they were wrapped in a factory.

You can use the :ref:`Ray Train utilities <train-loop-api>`
(:func:`train.report() <ray.train.report>` and
@@ -729,8 +674,12 @@ def train_func(config):
# else: Train will restore from the user-provided
# `resume_from_checkpoint` == `starting_checkpoint`.

# Evaluate datasets if they are wrapped in a factory.
trainer.datasets = {
k: d() if callable(d) else d for k, d in self.datasets.items()
}

trainer.setup()
trainer.preprocess_datasets()
trainer.training_loop()
Comment on lines +677 to 683

@justinvyu (Contributor, Author) commented on Feb 13, 2024:

`preprocess_datasets` used to do this "evaluation" of the dataset factory function, plus any dataset preprocessing from `Trainer.preprocessor`. That interface is no longer needed, so I removed it and moved the evaluation logic out here.

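For reference, a dataset factory here is just a zero-argument callable, and the dict comprehension above resolves it in place. A minimal sketch with illustrative datasets:

```python
import ray

datasets = {
    "train": lambda: ray.data.range(1000),  # factory: constructed lazily on the train actor
    "val": ray.data.range(100),             # plain dataset: passed through unchanged
}

# Same evaluation idiom as in the diff above.
datasets = {k: d() if callable(d) else d for k, d in datasets.items()}
```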
# Change the name of the training function to match the name of the Trainer
26 changes: 7 additions & 19 deletions python/ray/train/data_parallel_trainer.py
@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

import ray
from ray._private.thirdparty.tabulate.tabulate import tabulate
@@ -15,9 +15,6 @@
from ray.widgets import Template
from ray.widgets.util import repr_with_fallback

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor

logger = logging.getLogger(__name__)


@@ -191,11 +188,12 @@ def __init__(self, train_loop_per_worker, my_backend_config:
dataset_config: Configuration for dataset ingest. This is merged with the
default dataset config for the given trainer (`cls._dataset_config`).
run_config: Configuration for the execution of the training run.
datasets: Any Datasets to use for training. Use
the key "train" to denote which dataset is the training
dataset. If a ``preprocessor`` is provided and has not already been fit,
it will be fit on the training dataset. All datasets will be transformed
by the ``preprocessor`` if one is provided.
datasets: Ray Datasets to use for training and evaluation.
This is a dict where the key is the name of the dataset, which
can be accessed from within the ``train_loop_per_worker`` by calling
``train.get_dataset_shard(dataset_key)``.
By default, all datasets are sharded equally across workers.
This can be configured via ``dataset_config``.
metadata: Dict that should be made available via
`train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
for checkpoints saved from this Trainer. Must be JSON-serializable.
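A minimal sketch of the `get_dataset_shard` access pattern documented above; the dataset key and batch size are illustrative:

```python
from ray import train

def train_loop_per_worker(config):
    # Each worker receives only its shard of the "train" dataset.
    train_shard = train.get_dataset_shard("train")
    for batch in train_shard.iter_batches(batch_size=32):
        ...  # one training step per batch
```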
@@ -233,8 +231,6 @@ def __init__(
datasets: Optional[Dict[str, GenDataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
# Deprecated.
preprocessor: Optional["Preprocessor"] = None,
):
self._train_loop_per_worker = train_loop_per_worker
self._train_loop_config = train_loop_config
@@ -259,7 +255,6 @@ def __init__(
run_config=run_config,
datasets=datasets,
metadata=metadata,
preprocessor=preprocessor,
resume_from_checkpoint=resume_from_checkpoint,
)

@@ -314,13 +309,6 @@ def _validate_attributes(self):
self._train_loop_per_worker, "train_loop_per_worker"
)

def preprocess_datasets(self) -> None:
# Evaluate all datasets.
self.datasets = {k: d() if callable(d) else d for k, d in self.datasets.items()}
self.datasets = self._data_config._legacy_preprocessing(
self.datasets, self.preprocessor
)

def _validate_train_loop_per_worker(
self, train_loop_per_worker: Callable, fn_name: str
) -> None:
14 changes: 11 additions & 3 deletions python/ray/train/gbdt_trainer.py
@@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from packaging.version import Version

from ray.train import Checkpoint, RunConfig, ScalingConfig
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.trainer import BaseTrainer, GenDataset
@@ -13,7 +15,6 @@
if TYPE_CHECKING:
import xgboost_ray

from ray.data.preprocessor import Preprocessor

_WARN_REPARTITION_THRESHOLD = 10 * 1024**3
_DEFAULT_NUM_ITERATIONS = 10
@@ -149,7 +150,6 @@ def __init__(
num_boost_round: int = _DEFAULT_NUM_ITERATIONS,
scaling_config: Optional[ScalingConfig] = None,
run_config: Optional[RunConfig] = None,
preprocessor: Optional["Preprocessor"] = None, # Deprecated
resume_from_checkpoint: Optional[Checkpoint] = None,
metadata: Optional[Dict[str, Any]] = None,
**train_kwargs,
@@ -165,7 +165,6 @@ def __init__(
scaling_config=scaling_config,
run_config=run_config,
datasets=datasets,
preprocessor=preprocessor,
resume_from_checkpoint=resume_from_checkpoint,
metadata=metadata,
)
@@ -276,6 +275,15 @@ def _repartition_datasets_to_match_num_actors(self):
self._ray_params.num_actors
)

def setup(self) -> None:
import xgboost_ray

# XGBoost/LightGBM-Ray requires each dataset to have at least as many
# blocks as there are workers.
# This is only applicable for xgboost-ray<0.1.16
if Version(xgboost_ray.__version__) < Version("0.1.16"):
self._repartition_datasets_to_match_num_actors()

def training_loop(self) -> None:
config = self.train_kwargs.copy()
config[self._num_iterations_argument] = self.num_boost_round
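The helper's body is mostly elided in this view; the following is a plausible sketch of what it does, assuming the goal is at least one block per xgboost-ray actor (this body is an assumption, not part of the diff):

```python
def _repartition_datasets_to_match_num_actors(self):
    # Assumed body: make sure every xgboost-ray actor can receive at
    # least one block, since xgboost-ray < 0.1.16 requires that each
    # dataset has at least as many blocks as there are actors.
    for key, dataset in self.datasets.items():
        if dataset.num_blocks() < self._ray_params.num_actors:
            self.datasets[key] = dataset.repartition(
                self._ray_params.num_actors
            )
```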
8 changes: 1 addition & 7 deletions python/ray/train/horovod/horovod_trainer.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from ray.air.config import RunConfig, ScalingConfig
from ray.train import Checkpoint, DataConfig
@@ -7,9 +7,6 @@
from ray.train.trainer import GenDataset
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class HorovodTrainer(DataParallelTrainer):
@@ -191,8 +188,6 @@ def __init__(
datasets: Optional[Dict[str, GenDataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
# Deprecated.
preprocessor: Optional["Preprocessor"] = None,
):
super().__init__(
train_loop_per_worker=train_loop_per_worker,
@@ -202,7 +197,6 @@ def __init__(
dataset_config=dataset_config,
run_config=run_config,
datasets=datasets,
preprocessor=preprocessor,
resume_from_checkpoint=resume_from_checkpoint,
metadata=metadata,
)
16 changes: 0 additions & 16 deletions python/ray/train/lightgbm/lightgbm_trainer.py
@@ -7,11 +7,6 @@
from ray.train.lightgbm import RayTrainReportCallback
from ray.util.annotations import PublicAPI

try:
from packaging.version import Version
except ImportError:
from distutils.version import LooseVersion as Version


@PublicAPI(stability="beta")
class LightGBMTrainer(GBDTTrainer):
@@ -123,14 +118,3 @@ def _model_iteration(
if isinstance(model, lightgbm.Booster):
return model.current_iteration()
return model.booster_.current_iteration()

def preprocess_datasets(self) -> None:
super().preprocess_datasets()

# XGBoost/LightGBM-Ray requires each dataset to have at least as many
# blocks as there are workers.
# This is only applicable for xgboost-ray<0.1.16
import xgboost_ray

if Version(xgboost_ray.__version__) < Version("0.1.16"):
self._repartition_datasets_to_match_num_actors()