[air - preprocessor] Add BatchMapper. #23700
Conversation
python/ray/ml/preprocessor.py (Outdated)

            raise PreprocessorNotFittedException(
                "`fit` must be called before `transform_batch`."
            )
        return self._transform_batch(df)

    def should_fit(self):
sounds more like can_fit(self) or fittable(self) to me.
btw why is it check_is_fitted() and not is_fitted() ...
Hmmm, got it. We can decide on the naming. But the semantics are basically:
- fittable is an inherent attribute of the "type" of the preprocessor. It also implies whether a fit method is meaningful at all throughout the entire lifetime of this preprocessor.
- should_fit/can_fit depends on the state a preprocessor is currently in (assuming it's fittable).

Exposing check_is_fitted alone is not enough, as you can see in trainer.py: it only checks check_is_fitted in the current impl, which leads to a crash in the case of non-fittable preprocessors. That's why the proposal is to add should_fit.

check_is_fitted vs. is_fitted, or can vs. should: I don't have much preference. @clarkzinzow @matthewdeng, maybe weigh in as original authors of the API?
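For concreteness, a minimal sketch of how the two notions could be separated; the method names match the discussion, but the bodies are illustrative assumptions, not the actual implementation:

class Preprocessor:
    # Inherent to the preprocessor *type*: is `fit` meaningful at all
    # during this preprocessor's lifetime?
    _is_fittable = True

    def __init__(self):
        self._fitted = False  # hypothetical stand-in for real fitted state

    def check_is_fitted(self) -> bool:
        # State-dependent: has this instance already been fit?
        return self._fitted

    def should_fit(self) -> bool:
        # True only when the preprocessor is fittable AND not yet fitted.
        return self._is_fittable and not self.check_is_fitted()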
agree with the change. I am just nit-picking the naming.
hope to get things named consistently. thanks :)
I opted for should_fit over can_fit since it's not indicating an optional operation for fittable preprocessors; it's a necessary operation: if a fittable preprocessor is not fit before calling .transform(), it will fail. An argument could even be made for needs_fit.
We can change check_is_fitted to is_fitted.

should_fit functionally is a bit strange to me, at least as a public API. In particular, I want to avoid the case where the user does something like:

    if preprocessor.should_fit():
        preprocessor.fit()

It's not clear how to differentiate the case where the preprocessor is fitted from the case where the preprocessor was already fitted before.
@matthewdeng hmm, mind elaborating a bit more?

So there are 3 cases:
1. not fittable
2. fittable and fitted
3. fittable and not fitted yet

should_fit == case 3
is_fitted == case 2, conditioned on (2 + 3)

Another way that I can see to work is to just enforce "at most once" fitting semantics internally, so the caller doesn't have to call should_fit before fit. Which one do you prefer? Or are you proposing an alternative?
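A sketch of that "at most once" alternative, where fit() guards its own state so the caller never consults should_fit(); the class name, exception types, and messages are assumptions:

class AtMostOncePreprocessor:
    """Sketch: fitting is enforced internally to happen at most once."""

    _is_fittable = True

    def __init__(self):
        self._fitted = False  # hypothetical internal fitted flag

    def fit(self, ds):
        if not self._is_fittable:
            raise RuntimeError("This preprocessor is not fittable.")
        if self._fitted:
            raise RuntimeError("fit() may be called at most once.")
        # ... compute fitted state from `ds` here ...
        self._fitted = True
        return self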
I think I lack the context for why you want to bundle these 2 things together in the first place. But in my mind, the most intuitive behavior is (see the sketch below):
- if calling fit() or fit_transform() and the preprocessor is not fittable, throw an exception.
- if calling fit() and it is already fitted, print a warning msg and no-op.
- if calling fit_transform() and it is already fitted, print a warning msg, then proceed to transform().
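A runnable sketch of these three rules; the class shape, messages, and the _fitted flag are illustrative assumptions:

import warnings

class Preprocessor:
    _is_fittable = True

    def __init__(self):
        self._fitted = False

    def fit(self, ds):
        # Rule 1: error for non-fittable preprocessors.
        if not self._is_fittable:
            raise RuntimeError("fit() called on a non-fittable preprocessor.")
        # Rule 2: warn and no-op when already fitted.
        if self._fitted:
            warnings.warn("Preprocessor is already fitted; fit() is a no-op.")
            return self
        # ... compute fitted state from `ds` here ...
        self._fitted = True
        return self

    def fit_transform(self, ds):
        if not self._is_fittable:
            raise RuntimeError("fit_transform() called on a non-fittable preprocessor.")
        # Rule 3: warn when already fitted, then proceed to transform().
        if self._fitted:
            warnings.warn("Preprocessor is already fitted; skipping fit.")
        else:
            self.fit(ds)
        return self.transform(ds)

    def transform(self, ds):
        return ds  # placeholder transform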
Can we raise an exception in fit()/fit_transform() when already fitted instead? Logging is better than no logging, but I worry the behavior here isn't clear for users (I can see users thinking it should re-fit).
Looking good, the biggest remaining things are:
- We need to modify Chain to work correctly with the new should_fit() API.
- I think that BatchMapper.fit() should be a no-op in order for Chain to be able to naively call .fit() and .fit_transform() on all of its preprocessors, which should be cleaner.
ray/python/ray/ml/preprocessors/chain.py
Lines 22 to 47 in 8b8afd5
def _fit(self, ds: Dataset) -> Preprocessor:
    for preprocessor in self.preprocessors[:-1]:
        ds = preprocessor.fit_transform(ds)
    self.preprocessors[-1].fit(ds)
    return self

def fit_transform(self, ds: Dataset) -> Dataset:
    for preprocessor in self.preprocessors:
        ds = preprocessor.fit_transform(ds)
    return ds

def _transform(self, ds: Dataset) -> Dataset:
    for preprocessor in self.preprocessors:
        ds = preprocessor.transform(ds)
    return ds

def _transform_batch(self, df: DataBatchType) -> DataBatchType:
    for preprocessor in self.preprocessors:
        df = preprocessor.transform_batch(df)
    return df

def check_is_fitted(self) -> bool:
    return all(p.check_is_fitted() for p in self.preprocessors)

def __repr__(self):
    return f"<Chain preprocessors={self.preprocessors}>"
python/ray/ml/preprocessor.py (Outdated)

        Args:
            dataset: Input dataset.

        Returns:
            Preprocessor: The fitted Preprocessor with state attributes.
        """
        assert self._is_fittable, "One is expected to call `should_fit` before `fit`."
Could also make this a no-op when not self._is_fittable, which would be more friendly to e.g. chain preprocessors.
See comment above.
@clarkzinzow I see.
Looking at the Chain preprocessor, _is_fittable is set to False. Are users supposed to override this when constructing their Chain preprocessor?
python/ray/ml/preprocessor.py (Outdated)

@@ -60,6 +64,8 @@ def fit_transform(self, dataset: Dataset) -> Dataset:
        Returns:
            ray.data.Dataset: The transformed Dataset.
        """
        assert self._is_fittable, "One is expected to call `should_fit` before `fit`."
This error message looks weird. Why don't you check assert self.should_fit() here as well?
python/ray/ml/preprocessor.py (Outdated)

            raise PreprocessorNotFittedException(
                "`fit` must be called before `transform_batch`."
            )
        return self._transform_batch(df)

    def should_fit(self):
I think I lack the context for why you want to bundle these 2 things together in the first place. But in my mind, the most intuitive behavior is:
- if calling fit() or fit_transform() and the preprocessor is not fittable, throw an exception.
- if calling fit() and it is already fitted, print a warning msg and no-op.
- if calling fit_transform() and it is already fitted, print a warning msg, then proceed to transform().

@gjoliver @matthewdeng @clarkzinzow
    elif fitted_count > 0:
        return Preprocessor.FitStatus.PARTIALLY_FITTED
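For context, a sketch of the fit-status aggregation this branch could belong to; the enum values other than PARTIALLY_FITTED and the helper function are assumptions based on this thread:

from enum import Enum

class FitStatus(str, Enum):
    NOT_FITTABLE = "NOT_FITTABLE"
    NOT_FITTED = "NOT_FITTED"
    PARTIALLY_FITTED = "PARTIALLY_FITTED"
    FITTED = "FITTED"

def chain_fit_status(preprocessors) -> FitStatus:
    # Hypothetical aggregation over a Chain's member preprocessors.
    fittable = [p for p in preprocessors if p.fit_status() != FitStatus.NOT_FITTABLE]
    if not fittable:
        return FitStatus.NOT_FITTABLE
    fitted_count = sum(p.fit_status() == FitStatus.FITTED for p in fittable)
    if fitted_count == len(fittable):
        return FitStatus.FITTED
    elif fitted_count > 0:
        return FitStatus.PARTIALLY_FITTED
    return FitStatus.NOT_FITTED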
Is this a valid state, and when would this happen? Is this just when a chain is created that contains some fitted and some unfitted preprocessors? Is that even a valid use case that we should allow?
Correct. I don't think this is necessarily a valid state to be in, but one may construct a chain preprocessor incorrectly and end up in this mixed state. Trying to be defensive and explicit here.
I am also open to adding another error to warn explicitly about this mixed state, which should not happen.
@@ -192,7 +192,7 @@ def preprocess_datasets(self) -> None:

     if self.preprocessor:
         train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
-        if train_dataset and not self.preprocessor.check_is_fitted():
+        if train_dataset:
Are there valid use cases in which an already-fitted preprocessor may be passed and we'd rather no-op than error here?
See @matthewdeng's preference above about wanting an explicit exception. :)
Let's make a decision and stick to it.
I think we should allow a fitted preprocessor, and basically no-op here.
Why do we want to require an unfitted preprocessor? What if the entire preprocessor is not fittable?
We could do that. It's just that @matthewdeng has this concern about not silently no-op'ing (even with a warning msg):
"Can we raise an exception in fit()/fit_transform() when already fitted instead? Logging is better than no logging, but I worry the behavior here isn't clear for users (I can see users thinking it should re-fit)."
Printing an info or warning msg sounds good.
So I think that Preprocessor itself should error if .fit() is called on an already fitted preprocessor, but I was less sure about whether Train, as a user of Preprocessor, should let these exceptions happen. I think that @matthewdeng is right; we should error here to ensure that the user doesn't think that an overwriting or incremental fit is happening.
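A sketch of the erroring check argued for here, in the spirit of the preprocess_datasets hunk above; the exception type and message are assumptions:

if self.preprocessor:
    train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
    if train_dataset:
        if self.preprocessor.check_is_fitted():
            # Fail loudly rather than silently skipping or re-fitting.
            raise RuntimeError(
                "The provided preprocessor is already fitted; pass an "
                "unfitted preprocessor instead."
            )
        self.preprocessor.fit(train_dataset)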
What about a partially fitted chain? What are a user's options here?
Synced offline.
@matthewdeng @gjoliver @clarkzinzow PTAL.
LGTM, only nits! IMO good to merge after one other ML team reviewer approval.
New functionality looks great, thanks for iterating on this!
LGTM - can you update the PR summary and add a description (including the fit status changes)?
Why are these changes needed?
- Add the BatchMapper preprocessor.
- Update the semantics of preprocessor.fit() to allow multiple fits, following the scikit-learn example.
- Introduce FitStatus to explicitly incorporate the Chain case.
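A usage sketch of the new preprocessor; the import path follows this PR's python/ray/ml layout, and the column and function names are made up:

import pandas as pd
import ray
from ray.ml.preprocessors import BatchMapper

def double_value(batch: pd.DataFrame) -> pd.DataFrame:
    batch["value"] = batch["value"] * 2
    return batch

ds = ray.data.from_pandas(pd.DataFrame({"value": [1, 2, 3]}))
prep = BatchMapper(fn=double_value)
# BatchMapper is not fittable, so transform() works without a prior fit().
transformed = prep.transform(ds)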
Related issue number

Checks
- I've run scripts/format.sh to lint the changes in this PR.