diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index 83e56f515cb1..745688ffb8c0 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -60,6 +60,7 @@ sphinx-sitemap==2.2.0 sphinx-thebe==0.1.1 autodoc_pydantic==1.6.1 sphinxcontrib-redoc==1.6.0 +sphinx-tabs # MyST myst-parser==0.15.2 diff --git a/doc/source/conf.py b/doc/source/conf.py index 27e72061dca1..b6a1b9a4e140 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -51,6 +51,7 @@ "sphinx_thebe", "sphinxcontrib.autodoc_pydantic", "sphinxcontrib.redoc", + "sphinx_tabs.tabs", ] myst_enable_extensions = [ @@ -91,6 +92,9 @@ copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " copybutton_prompt_is_regexp = True +# By default, tabs can be closed by selecting an open tab. We disable this +# functionality with the `sphinx_tabs_disable_tab_closing` option. +sphinx_tabs_disable_tab_closing = True # There's a flaky autodoc import for "TensorFlowVariables" that fails depending on the doc structure / order # of imports. diff --git a/doc/source/ray-air/doc_code/mxnet_predictor.py b/doc/source/ray-air/doc_code/mxnet_predictor.py new file mode 100644 index 000000000000..55ca0a97eae9 --- /dev/null +++ b/doc/source/ray-air/doc_code/mxnet_predictor.py @@ -0,0 +1,95 @@ +# fmt: off +# __mxnetpredictor_imports_start__ +import os +from typing import Dict, Optional, Union + +import mxnet as mx +import numpy as np +from mxnet import gluon + +import ray +from ray.air import Checkpoint +from ray.data.preprocessor import Preprocessor +from ray.data.preprocessors import BatchMapper +from ray.train.batch_predictor import BatchPredictor +from ray.train.predictor import Predictor +# __mxnetpredictor_imports_end__ + + +# __mxnetpredictor_signature_start__ +class MXNetPredictor(Predictor): + ... + # __mxnetpredictor_signature_end__ + + # __mxnetpredictor_init_start__ + def __init__( + self, + net: gluon.Block, + preprocessor: Optional[Preprocessor] = None, + ): + self.net = net + super().__init__(preprocessor) + # __mxnetpredictor_init_end__ + + # __mxnetpredictor_from_checkpoint_start__ + @classmethod + def from_checkpoint( + cls, + checkpoint: Checkpoint, + net: gluon.Block, + preprocessor: Optional[Preprocessor] = None, + ) -> Predictor: + with checkpoint.as_directory() as directory: + path = os.path.join(directory, "net.params") + net.load_parameters(path) + return cls(net, preprocessor=preprocessor) + # __mxnetpredictor_from_checkpoint_end__ + + # __mxnetpredictor_predict_numpy_start__ + def _predict_numpy( + self, + data: Union[np.ndarray, Dict[str, np.ndarray]], + dtype: Optional[np.dtype] = None, + ) -> Dict[str, np.ndarray]: + # If `data` looks like `{"features": array([...])}`, unwrap the `dict` and pass + # the array directly to the model. + if isinstance(data, dict) and len(data) == 1: + data = next(iter(data.values())) + + inputs = mx.nd.array(data, dtype=dtype) + outputs = self.net(inputs).asnumpy() + + return {"predictions": outputs} +# __mxnetpredictor_predict_numpy_end__ + + +# __mxnetpredictor_model_start__ +net = gluon.model_zoo.vision.resnet50_v1(pretrained=True) +# __mxnetpredictor_model_end__ + +# __mxnetpredictor_checkpoint_start__ +os.makedirs("checkpoint", exist_ok=True) +net.save_parameters("checkpoint/net.params") +checkpoint = Checkpoint.from_directory("checkpoint") +# __mxnetpredictor_checkpoint_end__ + +# __mxnetpredictor_predict_start__ +# These images aren't normalized. In practice, normalize images before inference. +dataset = ray.data.read_images( + "s3://anonymous@air-example-data-2/imagenet-sample-images", size=(224, 224) +) + + +def preprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + # (B, H, W, C) -> (B, C, H, W) + batch["image"] = batch["image"].transpose(0, 3, 1, 2) + return batch + + +preprocessor = BatchMapper(preprocess, batch_format="numpy") +predictor = BatchPredictor.from_checkpoint( + checkpoint, MXNetPredictor, net=net, preprocessor=preprocessor +) +predictor.predict(dataset) +# __mxnetpredictor_predict_end__ +# fmt: on diff --git a/doc/source/ray-air/doc_code/statsmodel_predictor.py b/doc/source/ray-air/doc_code/statsmodel_predictor.py new file mode 100644 index 000000000000..4b9804fb88f6 --- /dev/null +++ b/doc/source/ray-air/doc_code/statsmodel_predictor.py @@ -0,0 +1,72 @@ +# fmt: off +# __statsmodelpredictor_imports_start__ +import os +from typing import Optional + +import numpy as np # noqa: F401 +import pandas as pd +import statsmodels.api as sm +import statsmodels.formula.api as smf +from statsmodels.base.model import Results +from statsmodels.regression.linear_model import OLSResults + +import ray +from ray.air import Checkpoint +from ray.data.preprocessor import Preprocessor +from ray.train.batch_predictor import BatchPredictor +from ray.train.predictor import Predictor +# __statsmodelpredictor_imports_end__ + + +# __statsmodelpredictor_signature_start__ +class StatsmodelPredictor(Predictor): + ... + # __statsmodelpredictor_signature_end__ + + # __statsmodelpredictor_init_start__ + def __init__(self, results: Results, preprocessor: Optional[Preprocessor] = None): + self.results = results + super().__init__(preprocessor) + # __statsmodelpredictor_init_end__ + + # __statsmodelpredictor_predict_pandas_start__ + def _predict_pandas(self, data: pd.DataFrame) -> pd.DataFrame: + predictions: pd.Series = self.results.predict(data) + return predictions.to_frame(name="predictions") + # __statsmodelpredictor_predict_pandas_end__ + + # __statsmodelpredictor_from_checkpoint_start__ + @classmethod + def from_checkpoint( + cls, + checkpoint: Checkpoint, + filename: str, + preprocessor: Optional[Preprocessor] = None, + ) -> Predictor: + with checkpoint.as_directory() as directory: + path = os.path.join(directory, filename) + results = OLSResults.load(path) + return cls(results, preprocessor) +# __statsmodelpredictor_from_checkpoint_end__ + + +# __statsmodelpredictor_model_start__ +data: pd.DataFrame = sm.datasets.get_rdataset("Guerry", "HistData").data +results = smf.ols("Lottery ~ Literacy + np.log(Pop1831)", data=data).fit() +# __statsmodelpredictor_model_end__ + +# __statsmodelpredictor_checkpoint_start__ +os.makedirs("checkpoint", exist_ok=True) +results.save("checkpoint/guerry.pickle") +checkpoint = Checkpoint.from_directory("checkpoint") +# __statsmodelpredictor_checkpoint_end__ + +# __statsmodelpredictor_predict_start__ +predictor = BatchPredictor.from_checkpoint( + checkpoint, StatsmodelPredictor, filename="guerry.pickle" +) +# This is the same data we trained our model on. Don't do this in practice. +dataset = ray.data.from_pandas(data) +predictor.predict(dataset) +# __statsmodelpredictor_predict_end__ +# fmt: on diff --git a/doc/source/ray-air/predictors.rst b/doc/source/ray-air/predictors.rst index 58f471ffdccd..8145c531284b 100644 --- a/doc/source/ray-air/predictors.rst +++ b/doc/source/ray-air/predictors.rst @@ -116,6 +116,295 @@ Text Coming soon! +Developer Guide: Implementing your own Predictor +------------------------------------------------ + +If you're using an unsupported framework, or if built-in predictors are too inflexible, +you may need to implement a custom predictor. + +To implement a custom :class:`~ray.train.predictor.Predictor`, +subclass :class:`~ray.train.predictor.Predictor` and implement: + +* :meth:`~ray.train.predictor.Predictor.__init__` +* :meth:`~ray.train.predictor.Predictor._predict_numpy` or :meth:`~ray.train.predictor.Predictor._predict_pandas` +* :meth:`~ray.train.predictor.Predictor.from_checkpoint` + +.. tip:: + You don't need to implement both + :meth:`~ray.train.predictor.Predictor._predict_numpy` and + :meth:`~ray.train.predictor.Predictor._predict_pandas`. Pick the method that's + easiest to implement. If both are implemented, override + :meth:`~ray.train.predictor.Predictor.preferred_batch_format` to specify which format + is more performant. This allows upstream producers to choose the best format. + +Examples +~~~~~~~~ + +We'll walk through how to implement a predictor for two frameworks: + +* MXNet -- a deep learning framework like Torch. +* statsmodel -- a Python library that provides regression and linear models. + +For more examples, read the source code of built-in predictors like +:class:`~ray.train.torch.TorchPredictor`, +:class:`~ray.train.xgboost.XGBoostPredictor`, and +:class:`~ray.train.sklearn.SklearnPredictor`. + +Before you begin +**************** + +.. tabs:: + + .. group-tab:: MXNet + + First, install MXNet and Ray AIR. + + .. code-block:: console + + pip install mxnet 'ray[air]' + + Then, import the objects required for this example. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_imports_start__ + :end-before: __mxnetpredictor_imports_end__ + + Finally, create a stub for the `MXNetPredictor` class. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_signature_start__ + :end-before: __mxnetpredictor_signature_end__ + + .. group-tab:: statsmodel + + First, install statsmodel and Ray AIR. + + .. code-block:: console + + pip install statsmodel 'ray[air]' + + Then, import the objects required for this example. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_imports_start__ + :end-before: __statsmodelpredictor_imports_end__ + + Finally, create a stub the `StatsmodelPredictor` class. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_signature_start__ + :end-before: __statsmodelpredictor_signature_end__ + +Create a model +************** + +.. tabs:: + + .. group-tab:: MXNet + + You'll need to pass a model to the ``MXNetPredictor`` constructor. + + To create the model, load a pre-trained computer vision model from the MXNet + model zoo. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_model_start__ + :end-before: __mxnetpredictor_model_end__ + + .. group-tab:: statsmodel + + You'll need to pass a model to the ``StatsmodelPredictor`` constructor. + + To create the model, fit a linear model on the + `Guerry dataset `_. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_model_start__ + :end-before: __statsmodelpredictor_model_end__ + + +Implement `__init__` +******************** + +.. tabs:: + + .. group-tab:: MXNet + + Use the constructor to set instance attributes required for prediction. In + the code snippet below, we assign the model to an attribute named ``net``. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_init_start__ + :end-before: __mxnetpredictor_init_end__ + + .. warning:: + You must call the base class' constructor; otherwise, + `Predictor.predict ` raises a + ``NotImplementedError``. + + .. group-tab:: statsmodel + + Use the constructor to set instance attributes required for prediction. In + the code snippet below, we assign the fitted model to an attribute named + ``results``. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_init_start__ + :end-before: __statsmodelpredictor_init_end__ + + .. warning:: + You must call the base class' constructor; otherwise, + `Predictor.predict ` raises a + ``NotImplementedError``. + +Implement `from_checkpoint` +*************************** + +.. tabs:: + + .. group-tab:: MXNet + + :meth:`~ray.train.predictor.from_checkpoint` creates a + :class:`~ray.train.predictor.Predictor` from a + :class:`~ray.air.checkpoint.Checkpoint`. + + Before implementing :meth:`~ray.train.predictor.from_checkpoint`, + save the model parameters to a directory, and create a + :class:`~ray.air.checkpoint.Checkpoint` from that directory. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_checkpoint_start__ + :end-before: __mxnetpredictor_checkpoint_end__ + + Then, implement :meth:`~ray.train.predictor.from_checkpoint`. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_from_checkpoint_start__ + :end-before: __mxnetpredictor_from_checkpoint_end__ + + .. group-tab:: statsmodel + + :meth:`~ray.train.predictor.from_checkpoint` creates a + :class:`~ray.train.predictor.Predictor` from a + :class:`~ray.air.checkpoint.Checkpoint`. + + Before implementing :meth:`~ray.train.predictor.from_checkpoint`, + save the fitten model to a directory, and create a + :class:`~ray.air.checkpoint.Checkpoint` from that directory. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_checkpoint_start__ + :end-before: __statsmodelpredictor_checkpoint_end__ + + Then, implement :meth:`~ray.train.predictor.from_checkpoint`. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_from_checkpoint_start__ + :end-before: __statsmodelpredictor_from_checkpoint_end__ + +Implement `_predict_numpy` or `_predict_pandas` +*********************************************** + +.. tabs:: + + .. group-tab:: MXNet + + Because MXNet models accept tensors as input, you should implement + :meth:`~ray.train.predictor.Predictor._predict_numpy`. + + :meth:`~ray.train.predictor.Predictor._predict_numpy` performs inference on a + batch of NumPy data. It accepts a ``np.ndarray`` or ``dict[str, np.ndarray]`` as + input and returns a ``np.ndarray`` or ``dict[str, np.ndarray]`` as output. + + The input type is determined by the type of :class:`~ray.data.Dataset` passed to + :meth:`BatchPredictor.predict `. + If your dataset has columns, the input is a ``dict``; otherwise, the input is a + ``np.ndarray``. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_predict_numpy_start__ + :end-before: __mxnetpredictor_predict_numpy_end__ + + .. group-tab:: statsmodel + + Because your OLS model accepts dataframes as input, you should implement + :meth:`~ray.train.predictor.Predictor._predict_pandas`. + + :meth:`~ray.train.predictor.Predictor._predict_pandas` performs inference on a + batch of pandas data. It accepts a ``pandas.DataFrame`` as input and return a + ``pandas.DataFrame`` as output. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_predict_pandas_start__ + :end-before: __statsmodelpredictor_predict_pandas_end__ + + +Perform inference +***************** + +.. tabs:: + + .. group-tab:: MXNet + + To perform inference with the completed ``MXNetPredictor``: + + 1. Create a :class:`~ray.train.batch_predictor.BatchPredictor` from your + checkpoint. + 2. Read sample images into a :class:`~ray.data.Dataset`. + 3. Call :class:`~ray.train.batch_predictor.BatchPredictor.predict` to classify + the images in the dataset. + + .. literalinclude:: doc_code/mxnet_predictor.py + :language: python + :dedent: + :start-after: __mxnetpredictor_predict_start__ + :end-before: __mxnetpredictor_predict_end__ + + .. group-tab:: statsmodel + + To perform inference with the completed ``StatsmodelPredictor``: + + 1. Create a :class:`~ray.train.batch_predictor.BatchPredictor` from your + checkpoint. + 2. Read the Guerry dataset into a :class:`~ray.data.Dataset`. + 3. Call :class:`~ray.train.batch_predictor.BatchPredictor.predict` to perform + regression on the samples in the dataset. + + .. literalinclude:: doc_code/statsmodel_predictor.py + :language: python + :dedent: + :start-after: __statsmodelpredictor_predict_start__ + :end-before: __statsmodelpredictor_predict_end__ + + .. _pipelined-prediction: Lazy/Pipelined Prediction (experimental) @@ -140,13 +429,4 @@ Execution can be triggered by pulling from the pipeline, as shown in the example Online Inference ---------------- -Check out the :ref:`air-serving-guide` for details on how to perform online inference with AIR. - - -Developer Guide: Implementing your own Predictor ------------------------------------------------- -To implement a new Predictor for your particular framework, you should subclass the base ``Predictor`` and implement the following two methods: - -1. ``_predict_pandas``: Given a pandas.DataFrame input, return a pandas.DataFrame containing predictions. -2. ``from_checkpoint``: Logic for creating a Predictor from an :ref:`AIR Checkpoint `. -3. Optionally ``_predict_numpy`` for better performance when working with tensor data to avoid extra copies from Pandas conversions. +Check out the :ref:`air-serving-guide` for details on how to perform online inference with AIR. \ No newline at end of file