Deprecate the preprocessor and object store memory arguments for ray.train.Trainer
This doc proposes to simplify Train's `DatasetConfig` as we move to the new streaming backend by default for Datasets. As noted in #25, Ray Datasets will have both lazy and streaming execution by default in Ray 2.4. Furthermore, `DatasetPipeline` will be deprecated in the future to consolidate functionality on the `Dataset` class.
With these changes, a few possibilities for simplification open up in the Train API:
- Decoupling Preprocessors from Trainers, so that Data preprocessing is performed on the Dataset explicitly by the user prior to passing the Dataset into the Trainer.
- Using the same resource-limiting API in Train as in Datasets (i.e., `ExecutionResources`), instead of having a separate `max_object_store_memory_fraction` config.
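For reference, a minimal sketch of what sharing the resource-limiting type between the two libraries could look like. The Ray Data side below uses the `DataContext` / `ExecutionResources` interface from recent Ray releases (named `DatasetContext` in older versions; exact import paths vary), while the `dataset_config` usage is the API proposed in this doc, not an existing one:

```python
import ray
from ray.data import ExecutionResources  # import path varies across Ray versions

# Ray Data side (existing): cap the streaming executor's object store usage.
ctx = ray.data.DataContext.get_current()
ctx.execution_options.resource_limits = ExecutionResources(object_store_memory=2e9)

# Ray Train side (proposed): the same type is passed per dataset via DatasetConfig.
dataset_config = {
    "train_ds": {"resource_limits": ExecutionResources(object_store_memory=2e9)},
}
```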
Simplification is greatly desirable here, since users commonly find Dataset<>Trainer interactions difficult to understand and debug.
Main `ray` project. Changes are made at the Ray Data and Ray AIR level.
The proposal will be open to the public, but please suggest a few experienced Ray contributors in this technical domain whose comments will help this proposal. Ideally, the list should include Ray committers.
@amogkam, @c21, @clarkzinzow, @jianoaix
To make the review process more productive, the owner of each proposal should identify a shepherd (should be a senior Ray committer). The shepherd is responsible for working with the owner and making sure the proposal is in good shape (with necessary information) before marking it as ready for broader review.
@stephanie-wang
- Introduce a `resource_limits: ExecutionResources(object_store_memory=2 * GiB)` arg to `ray.air.DatasetConfig`. This enables streaming by default, with a limit of 2 GiB, and deprecates the previous `max_object_store_memory_fraction` argument.
- Introduce `Dataset.get_logical_plan()` (DeveloperAPI), which returns the logical plan that can be used to extract the lineage of preprocessors applied to this Dataset. If multiple preprocessors are applied, Train can return a `Chain` of the preprocessors. Non-preprocessor operations on the Dataset are ignored, and we can also allow ignoring preprocessors such as per-epoch preprocessors. This function will be used by Train to persist fitted preprocessors with checkpoints.
- Deprecate the following additional Trainer configs when streaming is enabled: `global_shuffle` and `randomize_block_order` (users should use native Dataset shuffle ops), and the preprocessing args `fit`, `transform`, `preprocessor`, and `per_epoch_preprocessor` (users should set up preprocessing explicitly prior to creating the Trainer).
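Since `Dataset.get_logical_plan()` is a proposed DeveloperAPI rather than an existing one, the following is only a minimal sketch of how Train could consume it to recover the preprocessor lineage; the plan attributes used here (`dag`, `input_dependencies`, `fn`) are assumptions about its eventual shape.

```python
from ray.data.preprocessor import Preprocessor
from ray.data.preprocessors import Chain


def extract_preprocessor_lineage(ds):
    """Walk the proposed logical plan backwards, collecting preprocessors."""
    preprocessors = []
    op = ds.get_logical_plan().dag  # proposed DeveloperAPI
    while op is not None:
        fn = getattr(op, "fn", None)
        # Only map_batches() calls whose UDF is a Preprocessor count toward the
        # lineage; other operations (filters, shuffles, ...) are ignored.
        if isinstance(fn, Preprocessor):
            preprocessors.append(fn)
        deps = op.input_dependencies
        op = deps[0] if deps else None
    preprocessors.reverse()  # restore application order
    if not preprocessors:
        return None
    return preprocessors[0] if len(preprocessors) == 1 else Chain(*preprocessors)
```

Train would run something like this when writing a checkpoint, optionally dropping preprocessors tagged as per-epoch / training-only. The before/after examples below show how user code changes under this proposal.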
```python
# Before (current API): preprocessing is configured through the Trainer.
train_ds = ray.data.read_parquet("s3://bucket/etl_output")
fact_table = ray.data.read_csv("s3://bucket/my.csv")
# Create the preprocessor.
prep = StandardScaler(["f1", "f2"])
# Create the per-epoch preprocessor.
per_epoch_prep = RandomNoisePreprocessor()
# Trainer applies preprocessing internally via config.
trainer = TorchTrainer(
model,
datasets={
"train_ds": train_ds,
"fact_table": fact_table,
},
scaling_config=ScalingConfig(num_workers=4),
preprocessor=prep,
dataset_config={
"train_ds": {
"max_object_store_memory_fraction": 0.2, # Enable streaming.
"randomize_block_order": True,
"per_epoch_preprocessor": per_epoch_prep,
},
},
)
# Checkpoint includes fitted preprocessor.
best_checkpoint = trainer.fit().checkpoint
assert best_checkpoint.get_preprocessor() == prep
```
```python
# After (proposed API): preprocessing is applied explicitly on the Dataset.
base = ray.data.read_parquet("s3://bucket/etl_output")
fact_table = ray.data.read_csv("s3://bucket/my.csv")
# Fit the preprocessor.
prep = StandardScaler(["f1", "f2"])
prep.fit(base)
# Apply base preprocessing.
train_ds = base.map_batches(prep)
train_ds.cache() # Optional: cache the base data in memory.
# Per-epoch preprocessing.
per_epoch_prep = RandomNoisePreprocessor()
per_epoch_prep.ignore_for_inference = True
train_ds = train_ds \
.randomize_block_order() \
.map_batches(per_epoch_prep)
# Trainer doesn't know about preprocessing at all.
trainer = TorchTrainer(
model,
datasets={
"train_ds": train_ds,
"fact_table": fact_table,
},
scaling_config=ScalingConfig(num_workers=4),
dataset_config={
"train_ds": {
"resource_limits": ExecutionResources(
object_store_memory=20e9, # Customized streaming memory limit.
),
},
},
)
# Checkpoint includes fitted preprocessor.
best_checkpoint = trainer.fit().checkpoint
assert best_checkpoint.get_preprocessor() == prep
```
While the "after" code is longer, note that all the data processing code is now cleanly separated from the Trainer, which is both a conceptual and a practical simplification. In addition, having the fitted preprocessor computed early enables the user code to reference it (e.g., to get computed categories, etc.).
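For example, the fitted state of a built-in preprocessor can be inspected before the Trainer is even constructed. A minimal sketch, assuming the `stats_` attribute that built-in AIR preprocessors currently use to store fitted state (an implementation detail that may change):

```python
from ray.data.preprocessors import StandardScaler

prep = StandardScaler(["f1", "f2"])
prep.fit(base)  # `base` as in the example above

# The computed per-column statistics are available to user code immediately,
# rather than only after training inside the Trainer.
print(prep.stats_)
```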
- Q: What if I wanted to change per-trial datasets / preprocessing with Tune?
- A: You could prepare multiple datasets lazily on the driver.
- Q: Are we deprecating the `preprocessor` arg for Train entirely?
- A: Yes.
- Q: Will we still save the preprocessor in the Checkpoint?
- A: Yes, this doesn't change.
- Q: Should we have both `Preprocessor.transform` and `Dataset.map_batches`?
- A: We will deprecate the former.
- Q: What happens if you apply multiple preprocessors to a Dataset?
- A: The checkpoint will have the full chain, including per-epoch ones. Preprocessors can be tagged as used for training only / ignored during inference by setting an `ignore_for_inference` (constructor) attribute; see the sketch after this list.
- Q: What happens if you apply ordinary functions to the Dataset?
- A: You'll get a warning that these functions are not captured in the preprocessing chain, and to use `BatchMapper` if you want that.
- Q: Why not require the user to do all Data operations outside of Train, including the split?
- A: This would break tuning, as Train needs to create a separate Data stream per trial. This is not possible post-split, since calling `split` is a consumption operation.
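To make the chaining and `ignore_for_inference` behavior concrete, here is a sketch that continues the earlier example; `RandomNoisePreprocessor` is the hypothetical user-defined augmentation step from above, and the tagging/filtering semantics described in the comments are the proposed behavior, not an existing API:

```python
from ray.data.preprocessors import StandardScaler

scaler = StandardScaler(["f1", "f2"]).fit(base)  # `base` as in the example above
noise = RandomNoisePreprocessor()                # hypothetical augmentation step
noise.ignore_for_inference = True                # proposed training-only tag

train_ds = base.map_batches(scaler).map_batches(noise)

# Proposed semantics: the checkpoint persists the full Chain(scaler, noise),
# but inference-time consumers skip steps tagged with ignore_for_inference,
# so batch/online prediction effectively applies only `scaler`.
```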
An important part of the proposal is to explicitly point out any compatibility implications of the proposed change. If there are any, we should thoroughly discuss a plan to deprecate existing APIs and a migration path to the new one(s).
Ray 2.4: Lay the groundwork for these new APIs
- Streaming on by default in Datasets only (not Train).
- API changes from the related inference REP #25.
Ray 2.5: Onboard new users onto new APIs
- Introduce the API changes proposed above, and enable streaming by default in Train.
- Deprecated APIs will be inaccessible in streaming mode for Train.
- Legacy APIs will be fully supported in non-streaming mode for Train.
- Rewrite docs and examples to use new APIs.
Ray 2.6/2.7: Deprecate old APIs
- Full feature parity with global / windowed shuffles using new streaming data APIs.
- Fully deprecate DatasetPipeline / legacy Train APIs.
The proposal should discuss how the change will be tested before it can be merged or enabled. It should also include other acceptance criteria including documentation and examples.
- Unit and integration tests for the new APIs.
- Documentation and examples for the new APIs.