Simplifying Ray Train Ingest APIs

Summary

Deprecate the preprocessor and object store memory arguments for ray.train.Trainer.

General Motivation

This doc proposes to simplify Train's DatasetConfig as we move to the new streaming backend by default for Datasets. As noted in #25, Ray Datasets will have both lazy and streaming execution by default in Ray 2.4. Furthermore, DatasetPipeline will be deprecated in the future to consolidate functionality on the Dataset class.

With these changes, a few possibilities for simplification open up in the Train API:

  1. Decoupling Preprocessors from Trainers, so that data preprocessing is performed on the Dataset explicitly by the user prior to passing the Dataset into the Trainer.
  2. Using the same resource limiting API in Train as in Datasets (i.e., ExecutionResources), instead of having a separate max_object_store_memory_fraction config.

Simplification is highly desirable here, since users commonly find Dataset<>Trainer interactions difficult to understand and debug.

Should this change be within ray or outside?

Main ray project. Changes are made at the Ray Data and Ray AIR levels.

Stewardship

Required Reviewers

The proposal will be open to the public, but please suggest a few experienced Ray contributors in this technical domain whose comments will help this proposal. Ideally, the list should include Ray committers.

@amogkam, @c21, @clarkzinzow, @jianoiax

Shepherd of the Proposal (should be a senior committer)

To make the review process more productive, the owner of each proposal should identify a shepherd (should be a senior Ray committer). The shepherd is responsible for working with the owner and making sure the proposal is in good shape (with necessary information) before marking it as ready for broader review.

@stephanie-wang

Design and Architecture

API Changes

  1. Introduce a resource_limits: ExecutionResources(object_store_memory=2 * GiB) arg to ray.air.DatasetConfig. This enables streaming by default with a 2 GiB limit, and deprecates the previous max_object_store_memory_fraction argument.

  2. Introduce Dataset.get_logical_plan() (DeveloperAPI), which returns the logical plan from which the lineage of preprocessors applied to this Dataset can be extracted. If multiple preprocessors were applied, Train can return a Chain of the preprocessors. Non-preprocessor operations on the Dataset are ignored, and certain preprocessors, such as per-epoch preprocessors, can also be excluded. Train will use this function to persist fitted preprocessors with checkpoints (see the sketch after this list).

  3. Deprecate the following additional Trainer configs when streaming is enabled: global_shuffle and randomize_block_order (users should use native Dataset shuffle ops), and the preprocessing args fit, transform, preprocessor, and per_epoch_preprocessor (users should set up preprocessing explicitly prior to creating the Trainer).
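As a rough sketch of item 2, Train could walk the logical plan to recover the preprocessor lineage. Since get_logical_plan() is itself only proposed here, the plan traversal helpers below (operators(), the preprocessor attribute) are illustrative assumptions, not a final interface:

from ray.data.preprocessors import Chain

def extract_preprocessor_chain(ds):
    # Hypothetical traversal: collect every operator in the logical plan
    # that was produced by applying a Preprocessor; plain map_batches
    # functions are skipped.
    plan = ds.get_logical_plan()  # Proposed DeveloperAPI.
    preps = [
        op.preprocessor
        for op in plan.operators()
        if getattr(op, "preprocessor", None) is not None
    ]
    return preps[0] if len(preps) == 1 else Chain(*preps)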

Example:

Before

train_ds = ray.data.read_parquet("s3://bucket/etl_output")
fact_table = ray.data.read_csv("s3://bucket/my.csv")

# Create the preprocessor.
prep = StandardScaler(["f1", "f2"])

# Create the per-epoch preprocessor.
per_epoch_prep = RandomNoisePreprocessor()

# Trainer applies preprocessing internally via config.
trainer = TorchTrainer(
    model,
    datasets={
        "train_ds": train_ds,
        "fact_table": fact_table,
    },
    scaling_config=ScalingConfig(num_workers=4),
    preprocessor=prep,
    dataset_config={
        "train_ds": {
            "max_object_store_memory_fraction": 0.2,  # Enable streaming.
            "randomize_block_order": True,
            "per_epoch_preprocessor": per_epoch_prep,
        },
    },
)

# Checkpoint includes fitted preprocessor.
best_checkpoint = trainer.fit().checkpoint
assert best_checkpoint.get_preprocessor() == prep

After

base = ray.data.read_parquet("s3://bucket/etl_output")
fact_table = ray.data.read_csv("s3://bucket/my.csv")

# Fit the preprocessor.
prep = StandardScaler(["f1", "f2"])
prep.fit(base)

# Apply base preprocessing.
train_ds = base.map_batches(prep)
train_ds.cache()  # Optional: cache the base data in memory.

# Per-epoch preprocessing.
per_epoch_prep = RandomNoisePreprocessor()
per_epoch_prep.ignore_for_inference = True
train_ds = (
    train_ds
    .randomize_block_order()
    .map_batches(per_epoch_prep)
)

# Trainer doesn't know about preprocessing at all.
trainer = TorchTrainer(
    model,
    datasets={
        "train_ds": train_ds,
        "fact_table": fact_table,
    },
    scaling_config=ScalingConfig(num_workers=4),
    dataset_config={
        "train_ds": {
            "resource_limits": ExecutionResources(
                object_store_memory=20e9,  # Customized streaming memory limit.
            ),
        },
    },
)

# Checkpoint includes fitted preprocessor.
best_checkpoint = trainer.fit().checkpoint
assert best_checkpoint.get_preprocessor() == prep

While the "after" code is longer, note that all the data processing code is now cleanly separated from the Trainer, which is both a conceptual and a practical simplification. In addition, having the fitted preprocessor computed early enables user code to reference it (e.g., to get computed categories, etc.).
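For example, since the preprocessor is fitted eagerly on the driver, its learned state can be inspected before training starts. A minimal sketch, assuming fitted state is exposed via the preprocessor's stats_ attribute (the exact key names shown are illustrative):

prep = StandardScaler(["f1", "f2"])
prep.fit(base)

# Fitted state is available immediately, e.g. for sanity checks or logging.
print(prep.stats_)  # e.g., {"mean(f1)": 0.5, "std(f1)": 1.2, ...}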

FAQ

  • Q: What if I wanted to change per-trial datasets / prep with Tune?

  • A: You could prepare multiple datasets lazily on the driver (see the first sketch after this FAQ).

  • Q: Are we deprecating the preprocessor arg for Train entirely?

  • A: Yes.

  • Q: Will we still save the preprocessor in the Checkpoint?

  • A: Yes, this doesn't change.

  • Q: Should we have both Preprocessor.transform and Dataset.map_batches?

  • A: We will deprecate the former.

  • Q: What happens if you apply multiple preprocessors to a Dataset?

  • A: The checkpoint will have the full chain, including per-epoch ones. Preprocessors can be tagged as used for training only / ignored during inference by setting an ignore_for_inference (constructor) attribute.

  • Q: What happens if you apply ordinary functions to the Dataset?

  • A: You'll get a warning that these functions are not captured in the preprocessing chain, and to use BatchMapper if you want that (see the second sketch after this FAQ).

  • Q: Why not require the user to do all Data operations outside of Train, including the split?

  • A: This would break tuning, as Train needs to create a separate Data stream per trial. This is not possible post-split, since calling split is a consumption operation.
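To expand on the per-trial datasets question, here is a minimal sketch of preparing multiple lazy dataset variants on the driver. Sweeping datasets through param_space as shown is an assumption about how this could interact with Tuner, not a committed API (prep_a and prep_b are placeholder preprocessors fitted earlier):

from ray import tune
from ray.tune import Tuner

# Prepare several lazy variants of the training data on the driver.
ds_v1 = ray.data.read_parquet("s3://bucket/etl_output").map_batches(prep_a)
ds_v2 = ray.data.read_parquet("s3://bucket/etl_output").map_batches(prep_b)

# Hypothetical: sweep over the dataset variants, one per trial.
tuner = Tuner(
    trainer,
    param_space={"datasets": {"train_ds": tune.grid_search([ds_v1, ds_v2])}},
)
results = tuner.fit()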
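And for the ordinary-functions question, wrapping the function in BatchMapper keeps it in the captured preprocessing chain. A minimal sketch; add_derived_feature is a placeholder transform:

from ray.data.preprocessors import BatchMapper

def add_derived_feature(batch):
    # Placeholder transform on a pandas batch.
    batch["f3"] = batch["f1"] * batch["f2"]
    return batch

# Wrapped in BatchMapper, the transform is tracked as a preprocessor;
# a bare train_ds.map_batches(add_derived_feature) would not be.
mapper = BatchMapper(add_derived_feature, batch_format="pandas")
train_ds = train_ds.map_batches(mapper)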

Compatibility, Deprecation, and Migration Plan

An important part of the proposal is to explicitly point out any compatibility implications of the proposed change. If there are any, we should thoroughly discuss a plan to deprecate existing APIs and migrate to the new one(s).

Ray 2.4: Lay the groundwork for these new APIs

  • Streaming on by default in Datasets only (not Train).
  • API changes from the related inference REP #25.

Ray 2.5: Onboard new users onto new APIs

  • Introduce the API changes proposed above, and enable streaming by default in Train.
  • Deprecated APIs will be inaccessible in streaming mode for Train.
  • Legacy APIs will be fully supported in non-streaming mode for Train.
  • Rewrite docs and examples to use new APIs.

Ray 2.6/7: Deprecate old APIs

  • Full feature parity with global / windowed shuffles using new streaming data APIs.
  • Fully deprecate DatasetPipeline / legacy Train APIs.

Test Plan and Acceptance Criteria

The proposal should discuss how the change will be tested before it can be merged or enabled. It should also include other acceptance criteria including documentation and examples.

  • Unit and integration tests for the new APIs.
  • Documentation and examples for the new APIs.