[AIR] Improve to_air_checkpoint with path (ray-project#26532)
Signed-off-by: Xiaowei Jiang <[email protected]>
Yard1 authored and xwjiang2010 committed Jul 19, 2022
1 parent 2842ed1 commit 9588b14
Showing 8 changed files with 61 additions and 18 deletions.
22 changes: 18 additions & 4 deletions python/ray/train/lightgbm/utils.py
@@ -17,19 +17,33 @@
 
 @PublicAPI(stability="alpha")
 def to_air_checkpoint(
-    path: str,
     booster: lightgbm.Booster,
+    *,
+    path: os.PathLike,
     preprocessor: Optional["Preprocessor"] = None,
 ) -> Checkpoint:
     """Convert a pretrained model to AIR checkpoint for serve or inference.
+    Example:
+    .. code-block:: python
+        import lightgbm
+        import tempfile
+        from ray.train.lightgbm import to_air_checkpoint, LightGBMPredictor
+        bst = lightgbm.Booster()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            checkpoint = to_air_checkpoint(booster=bst, path=tmpdir)
+            predictor = LightGBMPredictor.from_checkpoint(checkpoint)
     Args:
-        path: The directory path where model and preprocessor steps are stored to.
         booster: A pretrained lightgbm model.
+        path: The directory where the checkpoint will be stored to.
         preprocessor: A fitted preprocessor. The preprocessing logic will
-            be applied to serve/inference.
+            be applied to the inputs for serving/inference.
     Returns:
-        A Ray Air checkpoint.
+        A Ray AIR checkpoint.
     """
     booster.save_model(os.path.join(path, MODEL_KEY))
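
For reference, a slightly fuller sketch of the updated call with a trained booster; the toy DataFrame, labels, and LightGBM parameters below are illustrative assumptions, not part of this diff.

import tempfile

import lightgbm
import pandas as pd

from ray.train.lightgbm import to_air_checkpoint, LightGBMPredictor

# Train a throwaway booster on toy data so the checkpoint holds a real model.
train_df = pd.DataFrame({"A": [1.0, 3.0, 5.0, 7.0], "B": [2.0, 4.0, 6.0, 8.0]})
labels = [0, 1, 0, 1]
bst = lightgbm.train(
    {"objective": "binary", "min_data_in_leaf": 1},
    lightgbm.Dataset(train_df, label=labels),
    num_boost_round=2,
)

with tempfile.TemporaryDirectory() as tmpdir:
    # Model first, then the keyword-only path naming the checkpoint directory.
    checkpoint = to_air_checkpoint(booster=bst, path=tmpdir)
    predictor = LightGBMPredictor.from_checkpoint(checkpoint)
    print(predictor.predict(train_df))
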
20 changes: 17 additions & 3 deletions python/ray/train/sklearn/utils.py
@@ -18,17 +18,31 @@
 
 @PublicAPI(stability="alpha")
 def to_air_checkpoint(
-    path: str,
     estimator: BaseEstimator,
+    *,
+    path: os.PathLike,
     preprocessor: Optional["Preprocessor"] = None,
 ) -> Checkpoint:
     """Convert a pretrained model to AIR checkpoint for serve or inference.
+    Example:
+    .. code-block:: python
+        import tempfile
+        from sklearn.ensemble import RandomForestClassifier
+        from ray.train.sklearn import to_air_checkpoint, SklearnPredictor
+        est = RandomForestClassifier()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            checkpoint = to_air_checkpoint(estimator=est, path=tmpdir)
+            predictor = SklearnPredictor.from_checkpoint(checkpoint)
     Args:
-        path: The directory path where model and preprocessor steps are stored to.
         estimator: A pretrained model.
+        path: The directory where the checkpoint will be stored to.
         preprocessor: A fitted preprocessor. The preprocessing logic will
-            be applied to serve/inference.
+            be applied to the inputs for serving/inference.
     Returns:
         A Ray Air checkpoint.
     """
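
For reference, a runnable sketch in the same spirit but with a fitted estimator; the toy DataFrame and labels are illustrative assumptions, not part of this diff.

import tempfile

import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from ray.train.sklearn import to_air_checkpoint, SklearnPredictor

train_df = pd.DataFrame({"A": [1.0, 2.0, 3.0, 4.0], "B": [4.0, 3.0, 2.0, 1.0]})
est = RandomForestClassifier(n_estimators=5).fit(train_df, [0, 1, 0, 1])

with tempfile.TemporaryDirectory() as tmpdir:
    # estimator stays positional-or-keyword; path (and preprocessor) are keyword-only.
    checkpoint = to_air_checkpoint(estimator=est, path=tmpdir)
    predictor = SklearnPredictor.from_checkpoint(checkpoint)
    print(predictor.predict(train_df))
    # A fitted Ray AIR preprocessor could also be attached by keyword:
    # to_air_checkpoint(estimator=est, path=tmpdir, preprocessor=prep)
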
4 changes: 2 additions & 2 deletions python/ray/train/tensorflow/utils.py
@@ -14,14 +14,14 @@
 
 @PublicAPI(stability="alpha")
 def to_air_checkpoint(
-    model: keras.Model, preprocessor: Optional["Preprocessor"] = None
+    model: keras.Model, *, preprocessor: Optional["Preprocessor"] = None
 ) -> Checkpoint:
     """Convert a pretrained model to AIR checkpoint for serve or inference.
     Args:
         model: A pretrained model.
         preprocessor: A fitted preprocessor. The preprocessing logic will
-            be applied to serve/inference.
+            be applied to the inputs for serving/inference.
     Returns:
         A Ray Air checkpoint.
     """
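
The Keras and Torch helpers keep a single positional model argument; only preprocessor becomes keyword-only. A minimal sketch of the updated Keras call (the one-layer model here is purely illustrative):

from tensorflow import keras

from ray.train.tensorflow import to_air_checkpoint

model = keras.Sequential([keras.layers.Dense(1, input_shape=(2,))])

# Positional use such as to_air_checkpoint(model, prep) now raises a TypeError;
# pass a fitted preprocessor by keyword instead, or omit it entirely.
checkpoint = to_air_checkpoint(model)
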
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_lightgbm_predictor.py
@@ -95,7 +95,7 @@ def test_predict_feature_columns_pandas():
 
 def test_predict_no_preprocessor_no_training():
     with tempfile.TemporaryDirectory() as tmpdir:
-        checkpoint = to_air_checkpoint(tmpdir, booster=model)
+        checkpoint = to_air_checkpoint(booster=model, path=tmpdir)
         predictor = LightGBMPredictor.from_checkpoint(checkpoint)
 
         data_batch = np.array([[1, 2], [3, 4], [5, 6]])
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_sklearn_predictor.py
@@ -146,7 +146,7 @@ def test_batch_prediction_with_set_cpus(ray_start_4_cpus):
 
 def test_sklearn_predictor_no_training():
     with tempfile.TemporaryDirectory() as tmpdir:
-        checkpoint = to_air_checkpoint(path=tmpdir, estimator=model)
+        checkpoint = to_air_checkpoint(estimator=model, path=tmpdir)
         batch_predictor = BatchPredictor.from_checkpoint(checkpoint, SklearnPredictor)
         test_dataset = ray.data.from_pandas(
             pd.DataFrame(dummy_data, columns=["A", "B"])
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_xgboost_predictor.py
@@ -97,7 +97,7 @@ def test_predict_feature_columns_pandas():
 
 def test_predict_no_preprocessor_no_training():
     with tempfile.TemporaryDirectory() as tmpdir:
-        checkpoint = to_air_checkpoint(tmpdir, booster=model)
+        checkpoint = to_air_checkpoint(booster=model, path=tmpdir)
         predictor = XGBoostPredictor.from_checkpoint(checkpoint)
 
         data_batch = np.array([[1, 2], [3, 4], [5, 6]])
4 changes: 2 additions & 2 deletions python/ray/train/torch/utils.py
@@ -14,14 +14,14 @@
 
 @PublicAPI(stability="alpha")
 def to_air_checkpoint(
-    model: torch.nn.Module, preprocessor: Optional["Preprocessor"] = None
+    model: torch.nn.Module, *, preprocessor: Optional["Preprocessor"] = None
 ) -> Checkpoint:
     """Convert a pretrained model to AIR checkpoint for serve or inference.
     Args:
         model: A pretrained model.
         preprocessor: A fitted preprocessor. The preprocessing logic will
-            be applied to serve/inference.
+            be applied to the inputs for serving/inference.
     Returns:
         A Ray Air checkpoint.
     """
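
The Torch helper mirrors the Keras one above. A minimal sketch, assuming TorchPredictor.from_checkpoint can rebuild the predictor directly from a checkpoint that carries the module (the linear model is illustrative only):

import torch

from ray.train.torch import to_air_checkpoint, TorchPredictor

model = torch.nn.Linear(2, 1)

# preprocessor is keyword-only here as well.
checkpoint = to_air_checkpoint(model)
predictor = TorchPredictor.from_checkpoint(checkpoint)
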
23 changes: 19 additions & 4 deletions python/ray/train/xgboost/utils.py
@@ -17,19 +17,34 @@
 
 @PublicAPI(stability="alpha")
 def to_air_checkpoint(
-    path: str,
     booster: xgboost.Booster,
+    *,
+    path: os.PathLike,
     preprocessor: Optional["Preprocessor"] = None,
 ) -> Checkpoint:
     """Convert a pretrained model to AIR checkpoint for serve or inference.
+    Example:
+    .. code-block:: python
+        import xgboost
+        import tempfile
+        from ray.train.xgboost import to_air_checkpoint, XGBoostPredictor
+        bst = xgboost.Booster()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            checkpoint = to_air_checkpoint(booster=bst, path=tmpdir)
+            predictor = XGBoostPredictor.from_checkpoint(checkpoint)
     Args:
-        path: The directory path where model and preprocessor steps are stored to.
         booster: A pretrained xgboost model.
+        path: The directory where the checkpoint will be stored to.
         preprocessor: A fitted preprocessor. The preprocessing logic will
-            be applied to serve/inference.
+            be applied to the inputs for serving/inference.
     Returns:
-        A Ray Air checkpoint.
+        A Ray AIR checkpoint.
     """
     booster.save_model(os.path.join(path, MODEL_KEY))
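
Finally, a batch-prediction sketch using the updated keyword-only call, modeled on the sklearn test above; the toy DataFrame, labels, and regression objective are illustrative assumptions, not part of this diff.

import tempfile

import pandas as pd
import xgboost

import ray
from ray.train.batch_predictor import BatchPredictor
from ray.train.xgboost import to_air_checkpoint, XGBoostPredictor

train_df = pd.DataFrame({"A": [1.0, 3.0, 5.0], "B": [2.0, 4.0, 6.0]})
bst = xgboost.train(
    {"objective": "reg:squarederror"},
    xgboost.DMatrix(train_df, label=[0.0, 1.0, 2.0]),
    num_boost_round=2,
)

with tempfile.TemporaryDirectory() as tmpdir:
    checkpoint = to_air_checkpoint(booster=bst, path=tmpdir)
    batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)
    predictions = batch_predictor.predict(ray.data.from_pandas(train_df))
    predictions.show()
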
