[air - preprocessor] Add BatchMapper. #23700

Merged Apr 14, 2022 (14 commits)
73 changes: 59 additions & 14 deletions python/ray/ml/preprocessor.py
@@ -1,4 +1,6 @@
import abc
import warnings
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
@@ -8,12 +10,6 @@
from ray.ml.predictor import DataBatchType


class PreprocessorAlreadyFittedException(RuntimeError):
"""Error raised when the preprocessor cannot be fitted again."""

pass


class PreprocessorNotFittedException(RuntimeError):
"""Error raised when the preprocessor needs to be fitted first."""

@@ -29,31 +25,68 @@ class Preprocessor(abc.ABC):
fitting, and uses these attributes to implement its normalization transform.
"""

class FitStatus(str, Enum):
"""The fit status of preprocessor."""

NOT_FITTABLE = "NOT_FITTABLE"
NOT_FITTED = "NOT_FITTED"
# Only meaningful for Chain preprocessors.
# At least one contained preprocessor in the chain is fitted and
# at least one that can be fitted is not fitted yet. This state
# should not show up if the caller only interacts with the chain
# preprocessor through the intended Preprocessor APIs.
PARTIALLY_FITTED = "PARTIALLY_FITTED"
FITTED = "FITTED"

# Preprocessors that do not need to be fitted must override this.
_is_fittable = True

def fit_status(self) -> "Preprocessor.FitStatus":
if not self._is_fittable:
return Preprocessor.FitStatus.NOT_FITTABLE
elif self._check_is_fitted():
return Preprocessor.FitStatus.FITTED
else:
return Preprocessor.FitStatus.NOT_FITTED

def fit(self, dataset: Dataset) -> "Preprocessor":
"""Fit this Preprocessor to the Dataset.

Fitted state attributes will be directly set in the Preprocessor.

Calling it more than once will overwrite all previously fitted state:
``preprocessor.fit(A).fit(B)`` is equivalent to ``preprocessor.fit(B)``.

Args:
dataset: Input dataset.

Returns:
Preprocessor: The fitted Preprocessor with state attributes.
"""
if self.check_is_fitted():
raise PreprocessorAlreadyFittedException(
"`fit` cannot be called multiple times. "
"Create a new Preprocessor to fit a new Dataset."
fit_status = self.fit_status()
if fit_status == Preprocessor.FitStatus.NOT_FITTABLE:
# No-op as there is no state to be fitted.
return self

if fit_status in (
Preprocessor.FitStatus.FITTED,
Preprocessor.FitStatus.PARTIALLY_FITTED,
):
warnings.warn(
"`fit` has already been called on the preprocessor (or at least one "
"contained preprocessors if this is a chain). "
"All previously fitted state will be overwritten!"
)

return self._fit(dataset)

def fit_transform(self, dataset: Dataset) -> Dataset:
"""Fit this Preprocessor to the Dataset and then transform the Dataset.

Calling it more than once will overwrite all previously fitted state:
``preprocessor.fit_transform(A).fit_transform(B)``
is equivalent to ``preprocessor.fit_transform(B)``.

Args:
dataset: Input Dataset.

@@ -71,10 +104,18 @@ def transform(self, dataset: Dataset) -> Dataset:

Returns:
ray.data.Dataset: The transformed Dataset.

Raises:
PreprocessorNotFittedException: If ``fit`` has not been called yet.
"""
if self._is_fittable and not self.check_is_fitted():
fit_status = self.fit_status()
if fit_status in (
Preprocessor.FitStatus.PARTIALLY_FITTED,
Preprocessor.FitStatus.NOT_FITTED,
):
raise PreprocessorNotFittedException(
"`fit` must be called before `transform`."
"`fit` must be called before `transform`, "
"or simply use fit_transform() to run both steps"
)
return self._transform(dataset)

@@ -87,13 +128,17 @@ def transform_batch(self, df: DataBatchType) -> DataBatchType:
Returns:
DataBatchType: The transformed data batch.
"""
if self._is_fittable and not self.check_is_fitted():
fit_status = self.fit_status()
if fit_status in (
Preprocessor.FitStatus.PARTIALLY_FITTED,
Preprocessor.FitStatus.NOT_FITTED,
):
raise PreprocessorNotFittedException(
"`fit` must be called before `transform_batch`."
)
return self._transform_batch(df)

def check_is_fitted(self) -> bool:
def _check_is_fitted(self) -> bool:
"""Returns whether this preprocessor is fitted.

We use the convention that attributes with a trailing ``_`` are set after
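For illustration, a minimal sketch of the resulting fit lifecycle (assuming only the APIs shown in this diff plus MinMaxScaler from ray.ml.preprocessors; the warning text matches the message asserted in the tests below):

import warnings

import pandas as pd
import ray
from ray.ml.preprocessor import PreprocessorNotFittedException
from ray.ml.preprocessors import MinMaxScaler

ds = ray.data.from_pandas(pd.DataFrame({"B": [1, 3, 5]}))
scaler = MinMaxScaler(["B"])

# Transforming before fitting still raises.
assert scaler.fit_status() == scaler.FitStatus.NOT_FITTED
try:
    scaler.transform(ds)
except PreprocessorNotFittedException:
    pass

scaler.fit(ds)
assert scaler.fit_status() == scaler.FitStatus.FITTED

# Refitting no longer raises the removed PreprocessorAlreadyFittedException;
# it warns and overwrites the previously fitted state.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    scaler.fit(ds)
assert any("already been called" in str(w.message) for w in caught)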
2 changes: 2 additions & 0 deletions python/ray/ml/preprocessors/__init__.py
@@ -1,9 +1,11 @@
from ray.ml.preprocessors.batch_mapper import BatchMapper
from ray.ml.preprocessors.chain import Chain
from ray.ml.preprocessors.encoder import OrdinalEncoder, OneHotEncoder, LabelEncoder
from ray.ml.preprocessors.imputer import SimpleImputer
from ray.ml.preprocessors.scaler import StandardScaler, MinMaxScaler

__all__ = [
"BatchMapper",
"Chain",
"LabelEncoder",
"MinMaxScaler",
29 changes: 29 additions & 0 deletions python/ray/ml/preprocessors/batch_mapper.py
@@ -0,0 +1,29 @@
from typing import Callable, TYPE_CHECKING

from ray.ml.preprocessor import Preprocessor

if TYPE_CHECKING:
import pandas


class BatchMapper(Preprocessor):
"""Apply ``fn`` to batches of records of given dataset.

This preprocessor is meant to be generic and supports low-level
operations on batches of records. For example, it can be used to add
a new column or to modify a column in place.

Args:
fn: The user-defined function (UDF) to apply to each batch.
"""

_is_fittable = False

def __init__(self, fn: Callable[["pandas.DataFrame"], "pandas.DataFrame"]):
self.fn = fn

def _transform_pandas(self, df: "pandas.DataFrame") -> "pandas.DataFrame":
return df.transform(self.fn)

def __repr__(self):
return f"<BatchMapper udf={getattr(self.fn, __name__) or self.fn}>"
30 changes: 26 additions & 4 deletions python/ray/ml/preprocessors/chain.py
@@ -13,7 +13,32 @@ class Chain(Preprocessor):
preprocessors: The preprocessors that should be executed sequentially.
"""

_is_fittable = False
def fit_status(self):
fittable_count = 0
fitted_count = 0
for p in self.preprocessors:
# AIR does not support a chain of chained preprocessors at this point.
# Assert to explicitly call this out.
# This can be revisited if compelling use cases emerge.
assert not isinstance(
p, Chain
), "A chain preprocessor should not contain another chain preprocessor."
if p.fit_status() == Preprocessor.FitStatus.FITTED:
fittable_count += 1
fitted_count += 1
elif p.fit_status() == Preprocessor.FitStatus.NOT_FITTED:
fittable_count += 1
else:
assert p.fit_status() == Preprocessor.FitStatus.NOT_FITTABLE
if fittable_count > 0:
if fitted_count == fittable_count:
return Preprocessor.FitStatus.FITTED
elif fitted_count > 0:
return Preprocessor.FitStatus.PARTIALLY_FITTED
Comment on lines +36 to +37

Reviewer (Contributor): Is this a valid state, and when would this happen? Is this just when a chain is created that contains some fitted and some unfitted preprocessors? Is that even a valid use case that we should allow?

Author (xwjiang2010): Correct. I don't think this is necessarily a valid state to be in, but one may construct a chain preprocessor incorrectly and end up in this mixed state. I am trying to be defensive and explicit here. I am also open to adding another error that warns explicitly about this mixed state, which should not happen.

else:
return Preprocessor.FitStatus.NOT_FITTED
else:
return Preprocessor.FitStatus.NOT_FITTABLE

def __init__(self, *preprocessors: Preprocessor):
super().__init__()
@@ -40,8 +65,5 @@ def _transform_batch(self, df: DataBatchType) -> DataBatchType:
df = preprocessor.transform_batch(df)
return df

def check_is_fitted(self) -> bool:
return all(p.check_is_fitted() for p in self.preprocessors)

def __repr__(self):
return f"<Chain preprocessors={self.preprocessors}>"
73 changes: 70 additions & 3 deletions python/ray/ml/tests/test_preprocessors.py
@@ -1,10 +1,14 @@
import warnings
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest

import ray
from ray.ml.preprocessor import PreprocessorNotFittedException
from ray.ml.preprocessors import (
BatchMapper,
StandardScaler,
MinMaxScaler,
OrdinalEncoder,
@@ -75,6 +79,34 @@ def test_standard_scaler():
assert pred_out_df.equals(pred_expected_df)


@patch.object(warnings, "warn")
def test_fit_twice(mocked_warn):
"""Tests that a warning msg should be printed."""
col_a = [-1, 0, 1]
col_b = [1, 3, 5]
col_c = [1, 1, None]
in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
ds = ray.data.from_pandas(in_df)

scaler = MinMaxScaler(["B", "C"])

# Fit data.
scaler.fit(ds)
assert scaler.stats_ == {"min(B)": 1, "max(B)": 5, "min(C)": 1, "max(C)": 1}

ds = ds.map_batches(lambda x: x * 2)
# Fit again
scaler.fit(ds)
# Assert that the fitted state corresponds to the second ds.
assert scaler.stats_ == {"min(B)": 2, "max(B)": 10, "min(C)": 2, "max(C)": 2}
msg = (
"`fit` has already been called on the preprocessor (or at least one "
"contained preprocessors if this is a chain). "
"All previously fitted state will be overwritten!"
)
mocked_warn.assert_called_once_with(msg)


def test_min_max_scaler():
"""Tests basic MinMaxScaler functionality."""
col_a = [-1, 0, 1]
@@ -500,10 +532,15 @@ def test_chain():
in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
ds = ray.data.from_pandas(in_df)

def udf(df):
df["A"] *= 2
return df

batch_mapper = BatchMapper(fn=udf)
imputer = SimpleImputer(["B"])
scaler = StandardScaler(["A", "B"])
encoder = LabelEncoder("C")
chain = Chain(scaler, imputer, encoder)
chain = Chain(scaler, imputer, encoder, batch_mapper)

# Fit data.
chain.fit(ds)
@@ -524,7 +561,7 @@ def test_chain():
transformed = chain.transform(ds)
out_df = transformed.to_pandas()

processed_col_a = [-1.0, -1.0, 1.0, 1.0]
processed_col_a = [-2.0, -2.0, 2.0, 2.0]
processed_col_b = [0.0, 0.0, 0.0, 0.0]
processed_col_c = [1, 0, 2, 2]
expected_df = pd.DataFrame.from_dict(
@@ -543,7 +580,7 @@

pred_out_df = chain.transform_batch(pred_in_df)

pred_processed_col_a = [1, 2, None]
pred_processed_col_a = [2, 4, None]
pred_processed_col_b = [-1.0, 0.0, 1.0]
pred_processed_col_c = [0, 2, None]
pred_expected_df = pd.DataFrame.from_dict(
@@ -557,6 +594,36 @@
assert pred_out_df.equals(pred_expected_df)


def test_batch_mapper():
"""Tests batch mapper functionality."""
old_column = [1, 2, 3, 4]
to_be_modified = [1, -1, 1, -1]
in_df = pd.DataFrame.from_dict(
{"old_column": old_column, "to_be_modified": to_be_modified}
)
ds = ray.data.from_pandas(in_df)

def add_and_modify_udf(df: "pd.DataFrame"):
df["new_col"] = df["old_column"] + 1
df["to_be_modified"] *= 2
return df

batch_mapper = BatchMapper(fn=add_and_modify_udf)
batch_mapper.fit(ds)
transformed = batch_mapper.transform(ds)
out_df = transformed.to_pandas()

expected_df = pd.DataFrame.from_dict(
{
"old_column": old_column,
"to_be_modified": [2, -2, 2, -2],
"new_col": [2, 3, 4, 5],
}
)

assert out_df.equals(expected_df)


if __name__ == "__main__":
import sys

2 changes: 1 addition & 1 deletion python/ray/ml/trainer.py
@@ -245,7 +245,7 @@ def preprocess_datasets(self) -> None:

if self.preprocessor:
train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
if train_dataset and not self.preprocessor.check_is_fitted():
if train_dataset:
Reviewer (Contributor): Are there valid use cases in which an already-fitted preprocessor may be passed and we'd rather no-op than error here?

Author (xwjiang2010): See @matthewdeng's preference about wanting an explicit exception. :) Let's make a decision and stick to it.

Reviewer (Member): I think we should allow a fitted preprocessor and basically no-op here. Why do we want to require an unfitted preprocessor? What if the entire preprocessor is not fittable?

Author (xwjiang2010): We could do that. It's just that @matthewdeng has this concern about not silently no-oping (even with a warning msg): "Can we raise an exception in fit()/fit_transform() when already fitted instead? Logging is better than no logging, but I worry the behavior here isn't clear for users (I can see users thinking it should re-fit)."

Reviewer (Member): Printing an info or warning msg sounds good.

Reviewer (Contributor): So I think that Preprocessor itself should error if .fit() is called on an already-fitted preprocessor, but I was less sure about whether Train, as a user of Preprocessor, should let these exceptions happen. I think that @matthewdeng is right: we should error here to ensure that the user doesn't think that an overwriting or incremental fit is happening.

Reviewer (Member): What about a partially fitted chain? What are a user's options here?

Author (xwjiang2010): Synced offline. @matthewdeng @gjoliver @clarkzinzow PTAL.

self.preprocessor.fit(train_dataset)

# Execute dataset transformations serially for now.
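The thread resolved to the always-refit-with-warning behavior shown in the diff. As a standalone sketch of the merged logic (a hypothetical free-function rendering; assumes TRAIN_DATASET_KEY == "train"):

def preprocess_datasets(preprocessor, datasets):
    # Hypothetical standalone version of the trainer logic above.
    train_dataset = datasets.get("train", None)
    if preprocessor and train_dataset:
        # Always (re)fit; Preprocessor.fit itself warns when previously
        # fitted state would be overwritten, instead of raising.
        preprocessor.fit(train_dataset)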