[AIR] Add Scaling Config validation #23889
Conversation
python/ray/ml/config.py
Outdated
scaling_config_arg_name: Name of the ScalingConfig argument to be used
    in the exception message.
exc_obj_name: Name of the object calling this method to be used
    in the exception message.
As mentioned in the previous PR, let's remove these arguments. The enclosing method can re-raise the exception so that the affected key names are in the stack trace, but we should avoid passing strings just for raising error messages
I wonder if we can return a list of bad keys instead, so that the enclosing method can raise an exception with a nice message. I think a single exception would be better from the user perspective than a long stack trace
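The "return a list of bad keys" suggestion could look roughly like the following sketch; the helper name and signature are hypothetical, not code from this PR:

```python
def get_disallowed_keys(config: dict, allowed_keys) -> list:
    """Return the config keys that are not in the allowed set.

    Hypothetical helper illustrating the suggestion above: the enclosing
    method can inspect the returned list and raise a single exception with
    a readable message, instead of surfacing a long stack trace.
    """
    allowed = set(allowed_keys)
    return [key for key in config if key not in allowed]


bad_keys = get_disallowed_keys(
    {"num_workers": 2, "foo": 1}, ["num_workers", "use_gpu"]
)
# bad_keys == ["foo"]
```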
I would be open to raise a custom exception here that stores the keys as an attribute or a hint on how to resolve things.
ray.air.exceptions.ConfigError
or so
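A minimal sketch of such a custom exception, assuming the proposed (hypothetical) `ray.air.exceptions.ConfigError`; the class name, attributes, and message format here are illustrative only:

```python
class ConfigError(ValueError):
    """Raised when a config contains keys a Trainer does not support.

    Hypothetical sketch: stores the offending keys as an attribute for
    programmatic access and accepts an optional hint on how to resolve
    the problem, as suggested in the review.
    """

    def __init__(self, bad_keys, hint=""):
        self.bad_keys = list(bad_keys)
        self.hint = hint
        message = f"Invalid config keys: {self.bad_keys}."
        if hint:
            message += f" {hint}"
        super().__init__(message)


try:
    raise ConfigError(["num_workers"], hint="This trainer is not scalable.")
except ConfigError as e:
    assert e.bad_keys == ["num_workers"]
```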
@krfricke updated, PTAL
Thanks @Yard1, this looks great so far! I think we should also do some more refactoring to make the allowed keys explicit for each Trainer.
+1 on @amogkam's suggestions. I think you want this end state, right?
class Trainer:
    _scaling_config_supported_keys = []  # Not scalable by default

    @classmethod
    def validate_config(cls, scaling_config):
        ...  # default validator

class XGBoostTrainer(Trainer):
    _scaling_config_supported_keys = ["num_workers", "use_gpu"]
I'd also like to see this PR show the whole integration path with Trainer, rather than adding a utility function without showing how it works with other classes. This will make it easier to review the interfaces rather than the implementation.
I've updated the PR to address the comments from the last feedback round:
- The validation method has been split into a dict part and a dataclass part
- It has been moved to a utility function and generalized to arbitrary dataclass and dict objects
- The Trainer now has a private method to return a validated scaling config dataclass
- The code is now used in the DP/GBDP trainer classes.
I'll also update the sklearn trainer to utilize this. PTAL
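The generic dataclass validation utility described above could be sketched roughly as follows; the function name, the `ScalingConfigDataClass` fields, and the exact error message are assumptions for illustration, not the PR's actual code:

```python
from dataclasses import dataclass, fields


def ensure_only_allowed_dataclass_keys_updated(dataclass_instance, allowed_keys):
    """Raise ValueError if any field outside allowed_keys was changed.

    Sketch of a generic validator: compares each field of the given
    dataclass against a freshly constructed default instance, so it works
    for arbitrary dataclasses whose fields all have defaults.
    """
    default = type(dataclass_instance)()
    allowed = set(allowed_keys)
    bad_keys = [
        f.name
        for f in fields(dataclass_instance)
        if f.name not in allowed
        and getattr(dataclass_instance, f.name) != getattr(default, f.name)
    ]
    if bad_keys:
        raise ValueError(
            f"Key(s) {bad_keys} are not allowed to be set for this trainer. "
            f"Allowed keys: {sorted(allowed)}."
        )


@dataclass
class ScalingConfigDataClass:
    # Illustrative fields only; the real dataclass has more.
    num_workers: int = 0
    use_gpu: bool = False
    trainer_resources: dict = None


cfg = ScalingConfigDataClass(num_workers=2)
ensure_only_allowed_dataclass_keys_updated(cfg, ["num_workers", "use_gpu"])  # passes
```

A trainer subclass would then call this with its own `_scaling_config_supported_keys`, raising a single clear error for any unsupported key.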
ScalingConfigDataClass.validate_config
Looks great, thanks @krfricke!
scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
    self.scaling_config
)
Should this check be in Trainer itself?
I think it's fine, we already call it in Trainer as well (in as_trainable). This is more of a way to get the dataclass inside the training loop.
LGTM! Thanks for making the changes!
Needs an approval from @krfricke before we can merge.
Implements `SklearnTrainer` and `SklearnPredictor`. Full parallelism with joblib, plus support for GPU-enabled estimators like cuML. The interface has been modified slightly by the addition of several arguments that were required for full functionality. I haven't tested cuML yet; will do it later. Depends on #23889 Co-authored-by: Kai Fricke <[email protected]>
Why are these changes needed?
Adds a ScalingConfigDataClass.validate_config classmethod to allow for a generic way of validating ScalingConfigs by allowing only certain keys.
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.