[doc] [dask] Add example on early stopping with Dask (#6501)
Co-authored-by: fis <[email protected]>
jameslamb and trivialfis committed Dec 15, 2020
1 parent 8139849 commit 1e2c3ad
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions doc/tutorials/dask.rst
@@ -273,6 +273,84 @@ actual computation will return a coroutine and hence require awaiting:
# Use `client.compute` instead of the `compute` method from dask collection
print(await client.compute(prediction))
*****************************
Evaluation and Early Stopping
*****************************

.. versionadded:: 1.3.0

The Dask interface allows the use of validation sets that are stored in distributed collections (Dask DataFrame or Dask Array). These can be used for evaluation and early stopping.

To enable early stopping, pass one or more validation sets to ``xgb.dask.train()`` through ``evals``, each as a ``(DaskDMatrix, name)`` pair, and set ``early_stopping_rounds``.

.. code-block:: python

    import dask.array as da
    import xgboost as xgb

    # Sizes must be integers: dask rejects float shapes and chunk sizes.
    num_rows = 1_000_000
    num_features = 100
    num_partitions = 10
    rows_per_chunk = num_rows // num_partitions

    # Random training data, split row-wise into ``num_partitions`` chunks.
    data = da.random.random(
        size=(num_rows, num_features),
        chunks=(rows_per_chunk, num_features)
    )
    labels = da.random.random(
        size=(num_rows, 1),
        chunks=(rows_per_chunk, 1)
    )

    # Random validation data with the same layout.
    X_eval = da.random.random(
        size=(num_rows, num_features),
        chunks=(rows_per_chunk, num_features)
    )
    y_eval = da.random.random(
        size=(num_rows, 1),
        chunks=(rows_per_chunk, 1)
    )

    # ``client`` is the Dask client created earlier in this tutorial.
    dtrain = xgb.dask.DaskDMatrix(
        client=client,
        data=data,
        label=labels
    )
    dvalid = xgb.dask.DaskDMatrix(
        client=client,
        data=X_eval,
        label=y_eval
    )

    # Stop if the metric on "valid1" has not improved for 3 rounds.
    result = xgb.dask.train(
        client=client,
        params={
            "objective": "reg:squarederror",
        },
        dtrain=dtrain,
        num_boost_round=10,
        evals=[(dvalid, "valid1")],
        early_stopping_rounds=3
    )

When validation sets are provided to ``xgb.dask.train()`` in this way, the dictionary it returns contains, under the ``"history"`` key, the evaluation metrics for each validation set across all boosting rounds.

.. code-block:: python

    print(result["history"])
    # {'valid1': OrderedDict([('rmse', [0.28857, 0.28858, 0.288592, 0.288598])])}

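
The history is a plain nested mapping, so it can be inspected without any Dask or XGBoost calls. The following is a minimal sketch that reports the best round for each validation set and metric. The helper name ``best_rounds`` is hypothetical, the ``history`` value is copied from the printed output above, and the code assumes a metric where lower is better (such as ``rmse``).

```python
from collections import OrderedDict

# Copied from the printed output above; in real use this would be result["history"].
history = {
    "valid1": OrderedDict([("rmse", [0.28857, 0.28858, 0.288592, 0.288598])]),
}


def best_rounds(history):
    """Map (validation set, metric) to (0-based best round, best value).

    Hypothetical helper: assumes that a lower metric value is better.
    """
    summary = {}
    for eval_name, metrics in history.items():
        for metric_name, values in metrics.items():
            # Index of the smallest value; ties resolve to the earliest round.
            best_idx = min(range(len(values)), key=values.__getitem__)
            summary[(eval_name, metric_name)] = (best_idx, values[best_idx])
    return summary


print(best_rounds(history))
```

For a metric where higher is better (for example ``auc``), ``min`` would need to be replaced with ``max``.
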
If early stopping is enabled by also passing ``early_stopping_rounds``, you can check the best iteration in the returned booster.

.. code-block:: python

    booster = result["booster"]
    print(booster.best_iteration)
    # ``best_iteration`` is a 0-based index, so add 1 to keep that round in the slice.
    best_model = booster[: booster.best_iteration + 1]

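
To make the stopping rule concrete, the sketch below replays it on a list of metric values in plain Python. It is a simplified model, not XGBoost's implementation: the function name ``simulate_early_stopping`` and the ``patience`` parameter are hypothetical, rounds are counted from zero, and a lower metric is assumed to be better.

```python
from typing import List, Optional, Tuple


def simulate_early_stopping(
    scores: List[float], patience: int
) -> Tuple[int, Optional[int]]:
    """Return (best round, round at which training stops, or None).

    Simplified model: stop once the metric has failed to improve on the
    best value for ``patience`` consecutive rounds.
    """
    best_round = 0
    best_score = float("inf")
    rounds_without_gain = 0
    for round_idx, score in enumerate(scores):
        if score < best_score:
            # New best value: remember it and reset the counter.
            best_score, best_round = score, round_idx
            rounds_without_gain = 0
        else:
            rounds_without_gain += 1
            if rounds_without_gain >= patience:
                return best_round, round_idx
    # The scores ran out before the patience was exhausted.
    return best_round, None


# Metric values taken from the history printed earlier in this section.
rmse = [0.28857, 0.28858, 0.288592, 0.288598]
best_round, stopped_at = simulate_early_stopping(rmse, patience=3)
print(best_round, stopped_at)
```

Under these assumptions the best value occurs in round 0 and the rule fires in round 3, which is consistent with the four-entry history shown above even though ``num_boost_round`` was 10.
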
*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************