diff --git a/python/ray/ml/preprocessor.py b/python/ray/ml/preprocessor.py
index f3814512a2db..77b4ac8b92d8 100644
--- a/python/ray/ml/preprocessor.py
+++ b/python/ray/ml/preprocessor.py
@@ -1,4 +1,6 @@
 import abc
+import warnings
+from enum import Enum
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -8,12 +10,6 @@
     from ray.ml.predictor import DataBatchType
 
 
-class PreprocessorAlreadyFittedException(RuntimeError):
-    """Error raised when the preprocessor cannot be fitted again."""
-
-    pass
-
-
 class PreprocessorNotFittedException(RuntimeError):
     """Error raised when the preprocessor needs to be fitted first."""
 
@@ -29,24 +25,57 @@ class Preprocessor(abc.ABC):
     fitting, and uses these attributes to implement its normalization transform.
     """
 
+    class FitStatus(str, Enum):
+        """The fit status of a preprocessor."""
+
+        NOT_FITTABLE = "NOT_FITTABLE"
+        NOT_FITTED = "NOT_FITTED"
+        # Only meaningful for Chain preprocessors.
+        # At least one contained preprocessor in the chain preprocessor
+        # is fitted, and at least one that can be fitted is not fitted yet.
+        # This state can show up if the caller only interacts
+        # with the chain preprocessor through the intended Preprocessor APIs.
+        PARTIALLY_FITTED = "PARTIALLY_FITTED"
+        FITTED = "FITTED"
+
     # Preprocessors that do not need to be fitted must override this.
     _is_fittable = True
 
+    def fit_status(self) -> "Preprocessor.FitStatus":
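+        """Return the fit status of this preprocessor.
+
+        Non-fittable preprocessors always report ``NOT_FITTABLE``; fittable
+        preprocessors report ``FITTED`` once their fitted state attributes
+        have been set, and ``NOT_FITTED`` otherwise.
+        """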
+        if not self._is_fittable:
+            return Preprocessor.FitStatus.NOT_FITTABLE
+        elif self._check_is_fitted():
+            return Preprocessor.FitStatus.FITTED
+        else:
+            return Preprocessor.FitStatus.NOT_FITTED
+
     def fit(self, dataset: Dataset) -> "Preprocessor":
         """Fit this Preprocessor to the Dataset.
 
         Fitted state attributes will be directly set in the Preprocessor.
 
+        Calling it more than once will overwrite all previously fitted state:
+        ``preprocessor.fit(A).fit(B)`` is equivalent to ``preprocessor.fit(B)``.
+
         Args:
             dataset: Input dataset.
 
         Returns:
             Preprocessor: The fitted Preprocessor with state attributes.
         """
-        if self.check_is_fitted():
-            raise PreprocessorAlreadyFittedException(
-                "`fit` cannot be called multiple times. "
-                "Create a new Preprocessor to fit a new Dataset."
+        fit_status = self.fit_status()
+        if fit_status == Preprocessor.FitStatus.NOT_FITTABLE:
+            # No-op as there is no state to be fitted.
+            return self
+
+        if fit_status in (
+            Preprocessor.FitStatus.FITTED,
+            Preprocessor.FitStatus.PARTIALLY_FITTED,
+        ):
+            warnings.warn(
+                "`fit` has already been called on the preprocessor (or at least one "
+                "contained preprocessor, if this is a chain). "
+                "All previously fitted state will be overwritten!"
             )
 
         return self._fit(dataset)
@@ -54,6 +83,10 @@ def fit(self, dataset: Dataset) -> "Preprocessor":
     def fit_transform(self, dataset: Dataset) -> Dataset:
         """Fit this Preprocessor to the Dataset and then transform the Dataset.
 
+        Calling it more than once will overwrite all previously fitted state:
+        ``preprocessor.fit_transform(A).fit_transform(B)``
+        is equivalent to ``preprocessor.fit_transform(B)``.
+
         Args:
             dataset: Input Dataset.
 
@@ -71,10 +104,18 @@ def transform(self, dataset: Dataset) -> Dataset:
 
         Returns:
             ray.data.Dataset: The transformed Dataset.
+
+        Raises:
+            PreprocessorNotFittedException: If ``fit`` has not been called yet.
         """
-        if self._is_fittable and not self.check_is_fitted():
+        fit_status = self.fit_status()
+        if fit_status in (
+            Preprocessor.FitStatus.PARTIALLY_FITTED,
+            Preprocessor.FitStatus.NOT_FITTED,
+        ):
             raise PreprocessorNotFittedException(
-                "`fit` must be called before `transform`."
+                "`fit` must be called before `transform`, "
+                "or simply use fit_transform() to run both steps."
             )
         return self._transform(dataset)
@@ -87,13 +128,17 @@ def transform_batch(self, df: DataBatchType) -> DataBatchType:
 
         Returns:
             DataBatchType: The transformed data batch.
         """
-        if self._is_fittable and not self.check_is_fitted():
+        fit_status = self.fit_status()
+        if fit_status in (
+            Preprocessor.FitStatus.PARTIALLY_FITTED,
+            Preprocessor.FitStatus.NOT_FITTED,
+        ):
             raise PreprocessorNotFittedException(
                 "`fit` must be called before `transform_batch`."
             )
         return self._transform_batch(df)
 
-    def check_is_fitted(self) -> bool:
+    def _check_is_fitted(self) -> bool:
         """Returns whether this preprocessor is fitted.
 
         We use the convention that attributes with a trailing ``_`` are set after
diff --git a/python/ray/ml/preprocessors/__init__.py b/python/ray/ml/preprocessors/__init__.py
index 9e8c32cc86e6..872f24a9f174 100644
--- a/python/ray/ml/preprocessors/__init__.py
+++ b/python/ray/ml/preprocessors/__init__.py
@@ -1,9 +1,11 @@
+from ray.ml.preprocessors.batch_mapper import BatchMapper
 from ray.ml.preprocessors.chain import Chain
 from ray.ml.preprocessors.encoder import OrdinalEncoder, OneHotEncoder, LabelEncoder
 from ray.ml.preprocessors.imputer import SimpleImputer
 from ray.ml.preprocessors.scaler import StandardScaler, MinMaxScaler
 
 __all__ = [
+    "BatchMapper",
     "Chain",
     "LabelEncoder",
     "MinMaxScaler",
diff --git a/python/ray/ml/preprocessors/batch_mapper.py b/python/ray/ml/preprocessors/batch_mapper.py
new file mode 100644
index 000000000000..abd27a3b5a95
--- /dev/null
+++ b/python/ray/ml/preprocessors/batch_mapper.py
@@ -0,0 +1,29 @@
+from typing import Callable, TYPE_CHECKING
+
+from ray.ml.preprocessor import Preprocessor
+
+if TYPE_CHECKING:
+    import pandas
+
+
+class BatchMapper(Preprocessor):
+    """Apply ``fn`` to batches of records of the given dataset.
+
+    This is meant to be generic and supports low-level operations on records.
+    One could easily leverage this preprocessor to achieve operations like
+    adding a new column or modifying a column in place.
+
+    Args:
+        fn: The user-defined function to apply to each batch of records.
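+
+    Example (a minimal sketch; the ``value`` column below is illustrative):
+
+        .. code-block:: python
+
+            import pandas as pd
+            import ray
+            from ray.ml.preprocessors import BatchMapper
+
+            ds = ray.data.from_pandas(pd.DataFrame({"value": [1, 2, 3]}))
+            # Double every value in each batch. No `fit` is needed first,
+            # since BatchMapper is not fittable.
+            mapper = BatchMapper(fn=lambda df: df * 2)
+            ds = mapper.transform(ds)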
+    """
+
+    _is_fittable = False
+
+    def __init__(self, fn: Callable[["pandas.DataFrame"], "pandas.DataFrame"]):
+        self.fn = fn
+
+    def _transform_pandas(self, df: "pandas.DataFrame") -> "pandas.DataFrame":
+        return df.transform(self.fn)
+
+    def __repr__(self):
+        return f"<BatchMapper fn={self.fn}>"
diff --git a/python/ray/ml/preprocessors/chain.py b/python/ray/ml/preprocessors/chain.py
index 72b1ae035fd0..5d89787e973f 100644
--- a/python/ray/ml/preprocessors/chain.py
+++ b/python/ray/ml/preprocessors/chain.py
@@ -13,7 +13,32 @@ class Chain(Preprocessor):
         preprocessors: The preprocessors that should be executed sequentially.
     """
 
-    _is_fittable = False
+    def fit_status(self):
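+        """Return the fit status of the chain as a whole.
+
+        Returns ``FITTED`` if every fittable contained preprocessor is fitted,
+        ``PARTIALLY_FITTED`` if only some of them are, ``NOT_FITTED`` if none
+        are, and ``NOT_FITTABLE`` if the chain contains no fittable
+        preprocessors.
+        """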
+        fittable_count = 0
+        fitted_count = 0
+        for p in self.preprocessors:
+            # AIR does not support a chain of chained preprocessors at this point.
+            # Assert to explicitly call this out.
+            # This can be revisited if compelling use cases emerge.
+            assert not isinstance(
+                p, Chain
+            ), "A chain preprocessor should not contain another chain preprocessor."
+            if p.fit_status() == Preprocessor.FitStatus.FITTED:
+                fittable_count += 1
+                fitted_count += 1
+            elif p.fit_status() == Preprocessor.FitStatus.NOT_FITTED:
+                fittable_count += 1
+            else:
+                assert p.fit_status() == Preprocessor.FitStatus.NOT_FITTABLE
+        if fittable_count > 0:
+            if fitted_count == fittable_count:
+                return Preprocessor.FitStatus.FITTED
+            elif fitted_count > 0:
+                return Preprocessor.FitStatus.PARTIALLY_FITTED
+            else:
+                return Preprocessor.FitStatus.NOT_FITTED
+        else:
+            return Preprocessor.FitStatus.NOT_FITTABLE
 
     def __init__(self, *preprocessors: Preprocessor):
         super().__init__()
@@ -40,8 +65,5 @@ def _transform_batch(self, df: DataBatchType) -> DataBatchType:
             df = preprocessor.transform_batch(df)
         return df
 
-    def check_is_fitted(self) -> bool:
-        return all(p.check_is_fitted() for p in self.preprocessors)
-
     def __repr__(self):
         return f"<Chain preprocessors={self.preprocessors}>"
diff --git a/python/ray/ml/tests/test_preprocessors.py b/python/ray/ml/tests/test_preprocessors.py
index 4843221ebd88..ab72bfec5f03 100644
--- a/python/ray/ml/tests/test_preprocessors.py
+++ b/python/ray/ml/tests/test_preprocessors.py
@@ -1,3 +1,6 @@
+import warnings
+from unittest.mock import patch
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -5,6 +8,7 @@
 import ray
 from ray.ml.preprocessor import PreprocessorNotFittedException
 from ray.ml.preprocessors import (
+    BatchMapper,
     StandardScaler,
     MinMaxScaler,
     OrdinalEncoder,
@@ -75,6 +79,34 @@ def test_standard_scaler():
     assert pred_out_df.equals(pred_expected_df)
 
 
+@patch.object(warnings, "warn")
+def test_fit_twice(mocked_warn):
+    """Tests that a warning is printed when fitting more than once."""
+    col_a = [-1, 0, 1]
+    col_b = [1, 3, 5]
+    col_c = [1, 1, None]
+    in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
+    ds = ray.data.from_pandas(in_df)
+
+    scaler = MinMaxScaler(["B", "C"])
+
+    # Fit data.
+    scaler.fit(ds)
+    assert scaler.stats_ == {"min(B)": 1, "max(B)": 5, "min(C)": 1, "max(C)": 1}
+
+    ds = ds.map_batches(lambda x: x * 2)
+    # Fit again.
+    scaler.fit(ds)
+    # Assert that the fitted state corresponds to the second ds.
+    assert scaler.stats_ == {"min(B)": 2, "max(B)": 10, "min(C)": 2, "max(C)": 2}
+    msg = (
+        "`fit` has already been called on the preprocessor (or at least one "
+        "contained preprocessor, if this is a chain). "
+        "All previously fitted state will be overwritten!"
+    )
+    mocked_warn.assert_called_once_with(msg)
+
+
 def test_min_max_scaler():
     """Tests basic MinMaxScaler functionality."""
     col_a = [-1, 0, 1]
@@ -500,10 +532,15 @@ def test_chain():
     in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
     ds = ray.data.from_pandas(in_df)
 
+    def udf(df):
+        df["A"] *= 2
+        return df
+
+    batch_mapper = BatchMapper(fn=udf)
     imputer = SimpleImputer(["B"])
     scaler = StandardScaler(["A", "B"])
     encoder = LabelEncoder("C")
-    chain = Chain(scaler, imputer, encoder)
+    chain = Chain(scaler, imputer, encoder, batch_mapper)
 
     # Fit data.
     chain.fit(ds)
@@ -524,7 +561,7 @@ def test_chain():
     transformed = chain.transform(ds)
     out_df = transformed.to_pandas()
 
-    processed_col_a = [-1.0, -1.0, 1.0, 1.0]
+    processed_col_a = [-2.0, -2.0, 2.0, 2.0]
     processed_col_b = [0.0, 0.0, 0.0, 0.0]
     processed_col_c = [1, 0, 2, 2]
     expected_df = pd.DataFrame.from_dict(
@@ -543,7 +580,7 @@ def test_chain():
 
     pred_out_df = chain.transform_batch(pred_in_df)
 
-    pred_processed_col_a = [1, 2, None]
+    pred_processed_col_a = [2, 4, None]
     pred_processed_col_b = [-1.0, 0.0, 1.0]
     pred_processed_col_c = [0, 2, None]
     pred_expected_df = pd.DataFrame.from_dict(
@@ -557,6 +594,36 @@ def test_chain():
     assert pred_out_df.equals(pred_expected_df)
 
 
+def test_batch_mapper():
+    """Tests batch mapper functionality."""
+    old_column = [1, 2, 3, 4]
+    to_be_modified = [1, -1, 1, -1]
+    in_df = pd.DataFrame.from_dict(
+        {"old_column": old_column, "to_be_modified": to_be_modified}
+    )
+    ds = ray.data.from_pandas(in_df)
+
+    def add_and_modify_udf(df: "pd.DataFrame"):
+        df["new_col"] = df["old_column"] + 1
+        df["to_be_modified"] *= 2
+        return df
+
+    batch_mapper = BatchMapper(fn=add_and_modify_udf)
+    batch_mapper.fit(ds)
+    transformed = batch_mapper.transform(ds)
+    out_df = transformed.to_pandas()
+
+    expected_df = pd.DataFrame.from_dict(
+        {
+            "old_column": old_column,
+            "to_be_modified": [2, -2, 2, -2],
+            "new_col": [2, 3, 4, 5],
+        }
+    )
+
+    assert out_df.equals(expected_df)
+
+
 if __name__ == "__main__":
     import sys
diff --git a/python/ray/ml/trainer.py b/python/ray/ml/trainer.py
index cf48098e8bf8..89aecdf570a2 100644
--- a/python/ray/ml/trainer.py
+++ b/python/ray/ml/trainer.py
@@ -245,7 +245,7 @@ def preprocess_datasets(self) -> None:
 
         if self.preprocessor:
             train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
-            if train_dataset and not self.preprocessor.check_is_fitted():
+            if train_dataset:
                 self.preprocessor.fit(train_dataset)
 
         # Execute dataset transformations serially for now.