
[air - preprocessor] Add BatchMapper. #23700

Merged
merged 14 commits on Apr 14, 2022
55 changes: 50 additions & 5 deletions python/ray/ml/preprocessor.py
@@ -1,4 +1,5 @@
import abc
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
@@ -29,21 +30,50 @@ class Preprocessor(abc.ABC):
fitting, and uses these attributes to implement its normalization transform.
"""

class FitStatus(str, Enum):
"""The fit status of preprocessor."""

NOT_FITTABLE = "NOT_FITTABLE"
NOT_FITTED = "NOT_FITTED"
# Only meaningful for Chain preprocessors
PARTIALLY_FITTED = "PARTIALLY_FITTED"
FITTED = "FITTED"

# Preprocessors that do not need to be fitted must override this.
_is_fittable = True

def fit_status(self) -> "Preprocessor.FitStatus":
if not self._is_fittable:
return Preprocessor.FitStatus.NOT_FITTABLE
elif self._check_is_fitted():
return Preprocessor.FitStatus.FITTED
else:
return Preprocessor.FitStatus.NOT_FITTED

def fit(self, dataset: Dataset) -> "Preprocessor":
"""Fit this Preprocessor to the Dataset.

Fitted state attributes will be directly set in the Preprocessor.
Only meant to be called at most once if the preprocessor is fittable.

Args:
dataset: Input dataset.

Returns:
Preprocessor: The fitted Preprocessor with state attributes.

Raises:
PreprocessorAlreadyFittedException: If already fitted once.
"""
if self.check_is_fitted():
fit_status = self.fit_status()
if fit_status == Preprocessor.FitStatus.NOT_FITTABLE:
# Just return. This makes Chain Preprocessor easier.
return self

if fit_status in (
Preprocessor.FitStatus.FITTED,
Preprocessor.FitStatus.PARTIALLY_FITTED,
):
raise PreprocessorAlreadyFittedException(
"`fit` cannot be called multiple times. "
"Create a new Preprocessor to fit a new Dataset."
@@ -59,6 +89,9 @@ def fit_transform(self, dataset: Dataset) -> Dataset:

Returns:
ray.data.Dataset: The transformed Dataset.

Raises:
PreprocessorAlreadyFittedException: If already fitted once.
"""
self.fit(dataset)
return self.transform(dataset)
@@ -71,10 +104,18 @@ def transform(self, dataset: Dataset) -> Dataset:

Returns:
ray.data.Dataset: The transformed Dataset.

Raises:
PreprocessorNotFittedException: If ``fit`` has not been called yet.
"""
if self._is_fittable and not self.check_is_fitted():
fit_status = self.fit_status()
if fit_status in (
Preprocessor.FitStatus.PARTIALLY_FITTED,
Preprocessor.FitStatus.NOT_FITTED,
):
raise PreprocessorNotFittedException(
"`fit` must be called before `transform`."
"`fit` must be called before `transform`. "
"Or simply use fit_transform() to run both steps"
)
return self._transform(dataset)

@@ -87,13 +128,17 @@ def transform_batch(self, df: DataBatchType) -> DataBatchType:
Returns:
DataBatchType: The transformed data batch.
"""
if self._is_fittable and not self.check_is_fitted():
fit_status = self.fit_status()
if fit_status in (
Preprocessor.FitStatus.PARTIALLY_FITTED,
Preprocessor.FitStatus.NOT_FITTED,
):
raise PreprocessorNotFittedException(
"`fit` must be called before `transform_batch`."
)
return self._transform_batch(df)

def check_is_fitted(self) -> bool:
def _check_is_fitted(self) -> bool:
"""Returns whether this preprocessor is fitted.

We use the convention that attributes with a trailing ``_`` are set after
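For reference, the lifecycle introduced above can be exercised as follows. This is a minimal sketch, not part of the diff, assuming the StandardScaler preprocessor and ray.data API used elsewhere in this PR:

import pandas as pd

import ray
from ray.ml.preprocessor import Preprocessor, PreprocessorAlreadyFittedException
from ray.ml.preprocessors import StandardScaler

ds = ray.data.from_pandas(pd.DataFrame({"A": [1.0, 2.0, 3.0, 4.0]}))

scaler = StandardScaler(["A"])
assert scaler.fit_status() == Preprocessor.FitStatus.NOT_FITTED

scaler.fit(ds)
assert scaler.fit_status() == Preprocessor.FitStatus.FITTED

# A second fit() now raises instead of silently refitting.
try:
    scaler.fit(ds)
except PreprocessorAlreadyFittedException:
    pass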
2 changes: 2 additions & 0 deletions python/ray/ml/preprocessors/__init__.py
@@ -1,9 +1,11 @@
from ray.ml.preprocessors.batch_mapper import BatchMapper
from ray.ml.preprocessors.chain import Chain
from ray.ml.preprocessors.encoder import OrdinalEncoder, OneHotEncoder, LabelEncoder
from ray.ml.preprocessors.imputer import SimpleImputer
from ray.ml.preprocessors.scaler import StandardScaler, MinMaxScaler

__all__ = [
"BatchMapper",
"Chain",
"LabelEncoder",
"MinMaxScaler",
29 changes: 29 additions & 0 deletions python/ray/ml/preprocessors/batch_mapper.py
@@ -0,0 +1,29 @@
from typing import Callable, TYPE_CHECKING

from ray.ml.preprocessor import Preprocessor

if TYPE_CHECKING:
import pandas


class BatchMapper(Preprocessor):
"""Apply ``fn`` to batches of records of given dataset.

This is meant to be generic and supports low level operation on records.
One could easily leverage this preprocessor to achieve operations like
adding a new column or modifying a column in place.

Args:
fn: The udf function for batch operation.
"""

_is_fittable = False

def __init__(self, fn: Callable[["pandas.DataFrame"], "pandas.DataFrame"]):
self.fn = fn

def _transform_pandas(self, df: "pandas.DataFrame") -> "pandas.DataFrame":
return df.transform(self.fn)

def __repr__(self):
return f"<BatchMapper udf={self.fn}>"
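As a usage sketch (mirroring the test added later in this PR), BatchMapper can be applied to a ray.data.Dataset without fitting, since it is not fittable; the UDF name here is purely illustrative:

import pandas as pd

import ray
from ray.ml.preprocessors import BatchMapper

def double_a(df: pd.DataFrame) -> pd.DataFrame:
    df["A"] *= 2
    return df

mapper = BatchMapper(fn=double_a)
ds = ray.data.from_pandas(pd.DataFrame({"A": [1, 2, 3]}))

# fit() is a no-op for non-fittable preprocessors, so fit_transform() simply transforms.
transformed = mapper.fit_transform(ds)
print(transformed.to_pandas())  # "A" becomes [2, 4, 6]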
28 changes: 24 additions & 4 deletions python/ray/ml/preprocessors/chain.py
@@ -13,7 +13,30 @@ class Chain(Preprocessor):
preprocessors: The preprocessors that should be executed sequentially.
"""

_is_fittable = False
def fit_status(self):
fittable_count = 0
fitted_count = 0
for p in self.preprocessors:
# AIR does not support a chain of chained preprocessors at this point.
# Assert to explicitly call this out.
# This can be revisited if compelling use cases emerge.
assert not isinstance(p, Chain)
if p.fit_status() == Preprocessor.FitStatus.FITTED:
fittable_count += 1
fitted_count += 1
elif p.fit_status() == Preprocessor.FitStatus.NOT_FITTED:
fittable_count += 1
else:
assert p.fit_status() == Preprocessor.FitStatus.NOT_FITTABLE
if fittable_count > 0:
if fitted_count == fittable_count:
return Preprocessor.FitStatus.FITTED
elif fitted_count > 0:
return Preprocessor.FitStatus.PARTIALLY_FITTED
Comment on lines +36 to +37

Contributor:
Is this a valid state, and when would this happen? Is this just when a chain is created that contains some fitted and some unfitted preprocessors? Is that even a valid use case that we should allow?

Contributor Author:
Correct. I don't think this is necessarily a valid state to be in, but one may construct a Chain preprocessor incorrectly and end up in this mixed state. I'm trying to be defensive and explicit here. I am also open to adding another error that warns explicitly about this mixed state, which should not happen.

else:
return Preprocessor.FitStatus.NOT_FITTED
else:
return Preprocessor.FitStatus.NOT_FITTABLE

def __init__(self, *preprocessors: Preprocessor):
super().__init__()
@@ -40,8 +63,5 @@ def _transform_batch(self, df: DataBatchType) -> DataBatchType:
df = preprocessor.transform_batch(df)
return df

def check_is_fitted(self) -> bool:
return all(p.check_is_fitted() for p in self.preprocessors)

def __repr__(self):
return f"<Chain preprocessors={self.preprocessors}>"
42 changes: 39 additions & 3 deletions python/ray/ml/tests/test_preprocessors.py
@@ -5,6 +5,7 @@
import ray
from ray.ml.preprocessor import PreprocessorNotFittedException
from ray.ml.preprocessors import (
BatchMapper,
StandardScaler,
MinMaxScaler,
OrdinalEncoder,
@@ -500,10 +501,15 @@ def test_chain():
in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
ds = ray.data.from_pandas(in_df)

def udf(df):
df["A"] *= 2
return df

batch_mapper = BatchMapper(fn=udf)
imputer = SimpleImputer(["B"])
scaler = StandardScaler(["A", "B"])
encoder = LabelEncoder("C")
chain = Chain(scaler, imputer, encoder)
chain = Chain(scaler, imputer, encoder, batch_mapper)

# Fit data.
chain.fit(ds)
@@ -524,7 +530,7 @@ def test_chain():
transformed = chain.transform(ds)
out_df = transformed.to_pandas()

processed_col_a = [-1.0, -1.0, 1.0, 1.0]
processed_col_a = [-2.0, -2.0, 2.0, 2.0]
processed_col_b = [0.0, 0.0, 0.0, 0.0]
processed_col_c = [1, 0, 2, 2]
expected_df = pd.DataFrame.from_dict(
@@ -543,7 +549,7 @@

pred_out_df = chain.transform_batch(pred_in_df)

pred_processed_col_a = [1, 2, None]
pred_processed_col_a = [2, 4, None]
pred_processed_col_b = [-1.0, 0.0, 1.0]
pred_processed_col_c = [0, 2, None]
pred_expected_df = pd.DataFrame.from_dict(
@@ -557,6 +563,36 @@
assert pred_out_df.equals(pred_expected_df)


def test_batch_mapper():
"""Tests batch mapper functionality."""
old_column = [1, 2, 3, 4]
to_be_modified = [1, -1, 1, -1]
in_df = pd.DataFrame.from_dict(
{"old_column": old_column, "to_be_modified": to_be_modified}
)
ds = ray.data.from_pandas(in_df)

def add_and_modify_udf(df: "pd.DataFrame"):
df["new_col"] = df["old_column"] + 1
df["to_be_modified"] *= 2
return df

batch_mapper = BatchMapper(fn=add_and_modify_udf)
batch_mapper.fit(ds)
transformed = batch_mapper.transform(ds)
out_df = transformed.to_pandas()

expected_df = pd.DataFrame.from_dict(
{
"old_column": old_column,
"to_be_modified": [2, -2, 2, -2],
"new_col": [2, 3, 4, 5],
}
)

assert out_df.equals(expected_df)


if __name__ == "__main__":
import sys

2 changes: 1 addition & 1 deletion python/ray/ml/trainer.py
@@ -192,7 +192,7 @@ def preprocess_datasets(self) -> None:

if self.preprocessor:
train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
if train_dataset and not self.preprocessor.check_is_fitted():
if train_dataset:
Contributor:
Are there valid use cases in which an already-fitted preprocessor may be passed and we'd rather no-op than error here?

Contributor Author:
See @matthewdeng's preference for an explicit exception. :) Let's make a decision and stick to it.

Member:
I think we should allow a fitted preprocessor and basically no-op here. Why do we want to require an unfitted preprocessor? What if the entire preprocessor is not fittable?

Contributor Author:
We could do that. It's just that @matthewdeng has this concern about not silently no-oping (even with a warning message):

    Can we raise an exception in fit()/fit_transform() when already fitted instead? Logging is better than no logging, but I worry the behavior here isn't clear for users (I can see users thinking it should re-fit).

Member:
Printing an info or warning message sounds good.

Contributor:
I think that Preprocessor itself should error if .fit() is called on an already-fitted preprocessor, but I was less sure about whether Train, as a user of Preprocessor, should let these exceptions propagate. I think @matthewdeng is right: we should error here to ensure that the user doesn't think an overwriting or incremental fit is happening.

Member:
What about a partially fitted Chain? What are a user's options here?

Contributor Author:
Synced offline. @matthewdeng @gjoliver @clarkzinzow PTAL.

self.preprocessor.fit(train_dataset)

# Execute dataset transformations serially for now.
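Illustrative sketch (not part of the diff) of the user-facing effect of this change: preprocess_datasets() now calls preprocessor.fit() unconditionally, so passing an already-fitted preprocessor to a Trainer surfaces PreprocessorAlreadyFittedException instead of silently skipping the fit:

import pandas as pd

import ray
from ray.ml.preprocessor import PreprocessorAlreadyFittedException
from ray.ml.preprocessors import StandardScaler

train_dataset = ray.data.from_pandas(pd.DataFrame({"A": [1.0, 2.0, 3.0]}))

scaler = StandardScaler(["A"])
scaler.fit(train_dataset)

# This is effectively what preprocess_datasets() does next for an already-fitted preprocessor.
try:
    scaler.fit(train_dataset)
except PreprocessorAlreadyFittedException as exc:
    print(exc)  # "`fit` cannot be called multiple times. Create a new Preprocessor to fit a new Dataset."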