[AIR] Add guide on how to implement a custom predictor (#31392)
Custom predictors let users bring their own models to scalable batch inference with Ray, but there's currently no guide for writing one. This PR adds such a guide.

Signed-off-by: Balaji Veeramani <[email protected]>
Co-authored-by: Amog Kamsetty <[email protected]>
bveeramani and amogkam authored Jan 11, 2023
1 parent 1f51f2c commit 30f8187
Showing 5 changed files with 462 additions and 10 deletions.
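
In short, the doc_code files below demonstrate one pattern: subclass Predictor, implement _predict_numpy or _predict_pandas, and expose a from_checkpoint constructor. A minimal sketch of that pattern, assuming only the ray.train.predictor.Predictor interface used below (IdentityModel and IdentityPredictor are hypothetical names, not part of this PR):

from typing import Dict, Optional

import numpy as np

from ray.data.preprocessor import Preprocessor
from ray.train.predictor import Predictor


class IdentityModel:
    """Hypothetical stand-in for a real framework model."""

    def __call__(self, batch: np.ndarray) -> np.ndarray:
        return batch  # A real model would run inference here.


class IdentityPredictor(Predictor):
    def __init__(
        self,
        model: IdentityModel,
        preprocessor: Optional[Preprocessor] = None,
    ):
        self.model = model
        super().__init__(preprocessor)

    def _predict_numpy(
        self,
        data: Dict[str, np.ndarray],
        dtype: Optional[np.dtype] = None,
    ) -> Dict[str, np.ndarray]:
        # Return model outputs as a dict of NumPy arrays under a "predictions" key.
        return {"predictions": self.model(next(iter(data.values())))}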
1 change: 1 addition & 0 deletions doc/requirements-doc.txt
@@ -60,6 +60,7 @@ sphinx-sitemap==2.2.0
sphinx-thebe==0.1.1
autodoc_pydantic==1.6.1
sphinxcontrib-redoc==1.6.0
sphinx-tabs

# MyST
myst-parser==0.15.2
4 changes: 4 additions & 0 deletions doc/source/conf.py
@@ -51,6 +51,7 @@
"sphinx_thebe",
"sphinxcontrib.autodoc_pydantic",
"sphinxcontrib.redoc",
"sphinx_tabs.tabs",
]

myst_enable_extensions = [
@@ -91,6 +92,9 @@
copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
copybutton_prompt_is_regexp = True

# By default, tabs can be closed by selecting an open tab. We disable this
# functionality with the `sphinx_tabs_disable_tab_closing` option.
sphinx_tabs_disable_tab_closing = True

# There's a flaky autodoc import for "TensorFlowVariables" that fails depending on the doc structure / order
# of imports.
95 changes: 95 additions & 0 deletions doc/source/ray-air/doc_code/mxnet_predictor.py
@@ -0,0 +1,95 @@
# fmt: off
# __mxnetpredictor_imports_start__
import os
from typing import Dict, Optional, Union

import mxnet as mx
import numpy as np
from mxnet import gluon

import ray
from ray.air import Checkpoint
from ray.data.preprocessor import Preprocessor
from ray.data.preprocessors import BatchMapper
from ray.train.batch_predictor import BatchPredictor
from ray.train.predictor import Predictor
# __mxnetpredictor_imports_end__


# __mxnetpredictor_signature_start__
class MXNetPredictor(Predictor):
    ...
    # __mxnetpredictor_signature_end__

    # __mxnetpredictor_init_start__
    def __init__(
        self,
        net: gluon.Block,
        preprocessor: Optional[Preprocessor] = None,
    ):
        self.net = net
        super().__init__(preprocessor)
    # __mxnetpredictor_init_end__

    # __mxnetpredictor_from_checkpoint_start__
    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: Checkpoint,
        net: gluon.Block,
        preprocessor: Optional[Preprocessor] = None,
    ) -> Predictor:
        # Load the parameters stored in the checkpoint into the provided network.
        with checkpoint.as_directory() as directory:
            path = os.path.join(directory, "net.params")
            net.load_parameters(path)
        return cls(net, preprocessor=preprocessor)
    # __mxnetpredictor_from_checkpoint_end__

    # __mxnetpredictor_predict_numpy_start__
    def _predict_numpy(
        self,
        data: Union[np.ndarray, Dict[str, np.ndarray]],
        dtype: Optional[np.dtype] = None,
    ) -> Dict[str, np.ndarray]:
        # If `data` looks like `{"features": array([...])}`, unwrap the `dict` and pass
        # the array directly to the model.
        if isinstance(data, dict) and len(data) == 1:
            data = next(iter(data.values()))

        # Convert the NumPy batch to an MXNet NDArray and run a forward pass.
        inputs = mx.nd.array(data, dtype=dtype)
        outputs = self.net(inputs).asnumpy()

        # Outputs must be returned as a dict of NumPy arrays.
        return {"predictions": outputs}
    # __mxnetpredictor_predict_numpy_end__


# __mxnetpredictor_model_start__
net = gluon.model_zoo.vision.resnet50_v1(pretrained=True)
# __mxnetpredictor_model_end__

# __mxnetpredictor_checkpoint_start__
os.makedirs("checkpoint", exist_ok=True)
net.save_parameters("checkpoint/net.params")
checkpoint = Checkpoint.from_directory("checkpoint")
# __mxnetpredictor_checkpoint_end__

# __mxnetpredictor_predict_start__
# These images aren't normalized. In practice, normalize images before inference.
dataset = ray.data.read_images(
    "s3://anonymous@air-example-data-2/imagenet-sample-images", size=(224, 224)
)


def preprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    # (B, H, W, C) -> (B, C, H, W)
    batch["image"] = batch["image"].transpose(0, 3, 1, 2)
    return batch


preprocessor = BatchMapper(preprocess, batch_format="numpy")
predictor = BatchPredictor.from_checkpoint(
    checkpoint, MXNetPredictor, net=net, preprocessor=preprocessor
)
predictor.predict(dataset)
# __mxnetpredictor_predict_end__
# fmt: on
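
BatchPredictor.predict returns a Ray Dataset, so one way to spot-check the example's output is to peek at a row. A short sketch, assuming the predictor and dataset defined above; this snippet isn't part of the committed file:

# Peek at one prediction row. Each row carries the "predictions" key
# that `_predict_numpy` returns. Assumes `predictor` and `dataset` above.
predictions = predictor.predict(dataset)
predictions.show(1)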
72 changes: 72 additions & 0 deletions doc/source/ray-air/doc_code/statsmodel_predictor.py
@@ -0,0 +1,72 @@
# fmt: off
# __statsmodelpredictor_imports_start__
import os
from typing import Optional

import numpy as np  # noqa: F401 (used by the `np.log` term in the model formula)
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.base.model import Results
from statsmodels.regression.linear_model import OLSResults

import ray
from ray.air import Checkpoint
from ray.data.preprocessor import Preprocessor
from ray.train.batch_predictor import BatchPredictor
from ray.train.predictor import Predictor
# __statsmodelpredictor_imports_end__


# __statsmodelpredictor_signature_start__
class StatsmodelPredictor(Predictor):
    ...
    # __statsmodelpredictor_signature_end__

    # __statsmodelpredictor_init_start__
    def __init__(self, results: Results, preprocessor: Optional[Preprocessor] = None):
        self.results = results
        super().__init__(preprocessor)
    # __statsmodelpredictor_init_end__

    # __statsmodelpredictor_predict_pandas_start__
    def _predict_pandas(self, data: pd.DataFrame) -> pd.DataFrame:
        # `Results.predict` returns a Series; wrap it in a single-column DataFrame.
        predictions: pd.Series = self.results.predict(data)
        return predictions.to_frame(name="predictions")
    # __statsmodelpredictor_predict_pandas_end__

    # __statsmodelpredictor_from_checkpoint_start__
    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: Checkpoint,
        filename: str,
        preprocessor: Optional[Preprocessor] = None,
    ) -> Predictor:
        # Restore the serialized OLS results from the checkpoint directory.
        with checkpoint.as_directory() as directory:
            path = os.path.join(directory, filename)
            results = OLSResults.load(path)
        return cls(results, preprocessor)
    # __statsmodelpredictor_from_checkpoint_end__


# __statsmodelpredictor_model_start__
data: pd.DataFrame = sm.datasets.get_rdataset("Guerry", "HistData").data
results = smf.ols("Lottery ~ Literacy + np.log(Pop1831)", data=data).fit()
# __statsmodelpredictor_model_end__

# __statsmodelpredictor_checkpoint_start__
os.makedirs("checkpoint", exist_ok=True)
results.save("checkpoint/guerry.pickle")
checkpoint = Checkpoint.from_directory("checkpoint")
# __statsmodelpredictor_checkpoint_end__

# __statsmodelpredictor_predict_start__
predictor = BatchPredictor.from_checkpoint(
    checkpoint, StatsmodelPredictor, filename="guerry.pickle"
)
# This is the same data we trained our model on. Don't do this in practice.
dataset = ray.data.from_pandas(data)
predictor.predict(dataset)
# __statsmodelpredictor_predict_end__
# fmt: on
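
Likewise, the statsmodels example's predictions can be pulled back into pandas for inspection. A sketch under the same assumptions, also not part of the committed file:

# Collect the distributed predictions into a pandas DataFrame.
# Assumes the `predictor` and `dataset` defined above.
predictions = predictor.predict(dataset)
print(predictions.to_pandas().head())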