diff --git a/doc/source/ray-air/package-ref.rst b/doc/source/ray-air/package-ref.rst index cfad36f71ac1..a2bc4e54eb53 100644 --- a/doc/source/ray-air/package-ref.rst +++ b/doc/source/ray-air/package-ref.rst @@ -35,6 +35,8 @@ Trainer .. autoclass:: ray.train.trainer.BaseTrainer :members: + .. automethod:: __init__ + Abstract Classes ################ @@ -42,10 +44,14 @@ Abstract Classes :members: :show-inheritance: + .. automethod:: __init__ + .. autoclass:: ray.train.gbdt_trainer.GBDTTrainer :members: :show-inheritance: + .. automethod:: __init__ + .. _air-results-ref: Training Result @@ -128,22 +134,46 @@ Trainer and Predictor Integrations XGBoost ####### +.. autoclass:: ray.train.xgboost.XGBoostTrainer + :members: + :show-inheritance: + + .. automethod:: __init__ + + .. automodule:: ray.train.xgboost :members: + :exclude-members: XGBoostTrainer :show-inheritance: LightGBM ######## +.. autoclass:: ray.train.lightgbm.LightGBMTrainer + :members: + :show-inheritance: + + .. automethod:: __init__ + + .. automodule:: ray.train.lightgbm :members: + :exclude-members: LightGBMTrainer :show-inheritance: TensorFlow ########## +.. autoclass:: ray.train.tensorflow.TensorflowTrainer + :members: + :show-inheritance: + + .. automethod:: __init__ + + .. automodule:: ray.train.tensorflow :members: + :exclude-members: TensorflowTrainer :show-inheritance: .. _air-pytorch-ref: @@ -151,29 +181,61 @@ TensorFlow PyTorch ####### +.. autoclass:: ray.train.torch.TorchTrainer + :members: + :show-inheritance: + + .. automethod:: __init__ + + .. automodule:: ray.train.torch :members: + :exclude-members: TorchTrainer :show-inheritance: Horovod ####### +.. autoclass:: ray.train.horovod.HorovodTrainer + :members: + :show-inheritance: + + .. automethod:: __init__ + + .. automodule:: ray.train.horovod :members: + :exclude-members: HorovodTrainer :show-inheritance: HuggingFace ########### +.. autoclass:: ray.train.huggingface.HuggingFaceTrainer + :members: + :show-inheritance: + + .. automethod:: __init__ + + .. automodule:: ray.train.huggingface :members: + :exclude-members: HuggingFaceTrainer :show-inheritance: Scikit-Learn ############ +.. autoclass:: ray.train.sklearn.SklearnTrainer + :members: + :show-inheritance: + + .. automethod:: __init__ + + .. automodule:: ray.train.sklearn :members: + :exclude-members: SklearnTrainer :show-inheritance: diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 1f5a00735c94..e5fe65d05475 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -183,7 +183,7 @@ def __repr__(self): return f"<{self.__class__.__name__}>" def __new__(cls, *args, **kwargs): - """Store the init args as attributes so this can be merged with Tune hparams.""" + # Store the init args as attributes so this can be merged with Tune hparams. trainer = super(BaseTrainer, cls).__new__(cls) parameters = inspect.signature(cls.__init__).parameters parameters = list(parameters.keys()) @@ -252,7 +252,7 @@ def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfi def setup(self) -> None: """Called during fit() to perform initial setup on the Trainer. - Note: this method is run on a remote process. + .. note:: This method is run on a remote process. This method will not be called on the driver, so any expensive setup operations should be placed here and not in ``__init__``. @@ -265,7 +265,7 @@ def setup(self) -> None: def preprocess_datasets(self) -> None: """Called during fit() to preprocess dataset attributes with preprocessor. - Note: This method is run on a remote process. + .. note:: This method is run on a remote process. This method is called prior to entering the training_loop. @@ -310,15 +310,16 @@ def training_loop(self) -> None: this training loop. Example: - .. code-block: python - from ray.train.trainer import BaseTrainer + .. code-block: python - class MyTrainer(BaseTrainer): - def training_loop(self): - for epoch_idx in range(5): - ... - session.report({"epoch": epoch_idx}) + from ray.train.trainer import BaseTrainer + + class MyTrainer(BaseTrainer): + def training_loop(self): + for epoch_idx in range(5): + ... + session.report({"epoch": epoch_idx}) """ raise NotImplementedError