-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AIR] Add guide on how to implement a custom predictor (#31392)
Custom predictors allow users to port their models to do scalable batch inference with Ray, but there's no guide for doing this. This PR adds such a guide. Signed-off-by: Balaji Veeramani <[email protected]> Co-authored-by: Amog Kamsetty <[email protected]>
- Loading branch information
1 parent
1f51f2c
commit 30f8187
Showing
5 changed files
with
462 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# fmt: off | ||
# __mxnetpredictor_imports_start__ | ||
import os | ||
from typing import Dict, Optional, Union | ||
|
||
import mxnet as mx | ||
import numpy as np | ||
from mxnet import gluon | ||
|
||
import ray | ||
from ray.air import Checkpoint | ||
from ray.data.preprocessor import Preprocessor | ||
from ray.data.preprocessors import BatchMapper | ||
from ray.train.batch_predictor import BatchPredictor | ||
from ray.train.predictor import Predictor | ||
# __mxnetpredictor_imports_end__ | ||
|
||
|
||
# __mxnetpredictor_signature_start__
class MXNetPredictor(Predictor):
    """Predictor that runs batch inference with an MXNet Gluon network."""
    ...
    # __mxnetpredictor_signature_end__

    # __mxnetpredictor_init_start__
    def __init__(
        self,
        net: gluon.Block,
        preprocessor: Optional[Preprocessor] = None,
    ):
        """Keep a handle to the network; the base class owns the preprocessor."""
        self.net = net
        super().__init__(preprocessor)
    # __mxnetpredictor_init_end__

    # __mxnetpredictor_from_checkpoint_start__
    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: Checkpoint,
        net: gluon.Block,
        preprocessor: Optional[Preprocessor] = None,
    ) -> Predictor:
        """Restore weights from ``checkpoint`` into ``net`` and build a predictor.

        The checkpoint directory is expected to hold a ``net.params`` file
        written by ``gluon.Block.save_parameters``.
        """
        with checkpoint.as_directory() as checkpoint_dir:
            params_file = os.path.join(checkpoint_dir, "net.params")
            net.load_parameters(params_file)
            return cls(net, preprocessor=preprocessor)
    # __mxnetpredictor_from_checkpoint_end__

    # __mxnetpredictor_predict_numpy_start__
    def _predict_numpy(
        self,
        data: Union[np.ndarray, Dict[str, np.ndarray]],
        dtype: Optional[np.dtype] = None,
    ) -> Dict[str, np.ndarray]:
        """Run the network on a NumPy batch and return a ``predictions`` column."""
        # A single-column batch like {"features": array([...])} is unwrapped so
        # the bare array can be handed straight to the model.
        if isinstance(data, dict) and len(data) == 1:
            (data,) = data.values()

        model_input = mx.nd.array(data, dtype=dtype)
        model_output = self.net(model_input).asnumpy()
        return {"predictions": model_output}
    # __mxnetpredictor_predict_numpy_end__
|
||
|
||
# __mxnetpredictor_model_start__
# Download a pretrained ResNet-50 from the Gluon model zoo.
net = gluon.model_zoo.vision.resnet50_v1(pretrained=True)
# __mxnetpredictor_model_end__

# __mxnetpredictor_checkpoint_start__
# Persist the network weights to disk, then wrap that directory as an AIR
# Checkpoint so `MXNetPredictor.from_checkpoint` can restore them later.
os.makedirs("checkpoint", exist_ok=True)
net.save_parameters("checkpoint/net.params")
checkpoint = Checkpoint.from_directory("checkpoint")
# __mxnetpredictor_checkpoint_end__

# __mxnetpredictor_predict_start__
# These images aren't normalized. In practice, normalize images before inference.
# size=(224, 224) resizes every image to the resolution used below.
dataset = ray.data.read_images(
    "s3://anonymous@air-example-data-2/imagenet-sample-images", size=(224, 224)
)
def preprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Move the channel axis first: (B, H, W, C) -> (B, C, H, W)."""
    batch["image"] = np.transpose(batch["image"], (0, 3, 1, 2))
    return batch
||
# BatchMapper applies `preprocess` to each batch; batch_format="numpy" means
# batches arrive as dicts of NumPy arrays.
preprocessor = BatchMapper(preprocess, batch_format="numpy")
predictor = BatchPredictor.from_checkpoint(
    checkpoint, MXNetPredictor, net=net, preprocessor=preprocessor
)
predictor.predict(dataset)
# __mxnetpredictor_predict_end__
# fmt: on
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# fmt: off | ||
# __statsmodelpredictor_imports_start__ | ||
import os | ||
from typing import Optional | ||
|
||
import numpy as np # noqa: F401 | ||
import pandas as pd | ||
import statsmodels.api as sm | ||
import statsmodels.formula.api as smf | ||
from statsmodels.base.model import Results | ||
from statsmodels.regression.linear_model import OLSResults | ||
|
||
import ray | ||
from ray.air import Checkpoint | ||
from ray.data.preprocessor import Preprocessor | ||
from ray.train.batch_predictor import BatchPredictor | ||
from ray.train.predictor import Predictor | ||
# __statsmodelpredictor_imports_end__ | ||
|
||
|
||
# __statsmodelpredictor_signature_start__
class StatsmodelPredictor(Predictor):
    """Predictor that scores pandas batches with a fitted statsmodels model."""
    ...
    # __statsmodelpredictor_signature_end__

    # __statsmodelpredictor_init_start__
    def __init__(self, results: Results, preprocessor: Optional[Preprocessor] = None):
        """Keep the fitted results object; the base class owns the preprocessor."""
        self.results = results
        super().__init__(preprocessor)
    # __statsmodelpredictor_init_end__

    # __statsmodelpredictor_predict_pandas_start__
    def _predict_pandas(self, data: pd.DataFrame) -> pd.DataFrame:
        """Score ``data`` and return a single-column ``predictions`` frame."""
        scores: pd.Series = self.results.predict(data)
        return scores.to_frame(name="predictions")
    # __statsmodelpredictor_predict_pandas_end__

    # __statsmodelpredictor_from_checkpoint_start__
    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: Checkpoint,
        filename: str,
        preprocessor: Optional[Preprocessor] = None,
    ) -> Predictor:
        """Load pickled OLS results stored as ``filename`` inside ``checkpoint``."""
        with checkpoint.as_directory() as checkpoint_dir:
            results = OLSResults.load(os.path.join(checkpoint_dir, filename))
        return cls(results, preprocessor)
    # __statsmodelpredictor_from_checkpoint_end__
|
||
|
||
# __statsmodelpredictor_model_start__
# Fit an OLS model on the Guerry dataset (HistData R package).
data: pd.DataFrame = sm.datasets.get_rdataset("Guerry", "HistData").data
results = smf.ols("Lottery ~ Literacy + np.log(Pop1831)", data=data).fit()
# __statsmodelpredictor_model_end__

# __statsmodelpredictor_checkpoint_start__
# Pickle the fitted results and wrap the directory as an AIR Checkpoint so
# `StatsmodelPredictor.from_checkpoint` can reload them by filename.
os.makedirs("checkpoint", exist_ok=True)
results.save("checkpoint/guerry.pickle")
checkpoint = Checkpoint.from_directory("checkpoint")
# __statsmodelpredictor_checkpoint_end__

# __statsmodelpredictor_predict_start__
predictor = BatchPredictor.from_checkpoint(
    checkpoint, StatsmodelPredictor, filename="guerry.pickle"
)
# This is the same data we trained our model on. Don't do this in practice.
dataset = ray.data.from_pandas(data)
predictor.predict(dataset)
# __statsmodelpredictor_predict_end__
# fmt: on
Oops, something went wrong.