Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] Accept other inputs for prediction. #5428

Merged
merged 4 commits into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 44 additions & 27 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def concat(value):

def _xgb_get_client(client):
    '''Validate ``client`` and resolve it to a usable dask client.

    Accepts either ``None`` (in which case the current client obtained
    from ``get_client()`` is used) or an instance of the distributed
    client type; anything else raises ``TypeError``.
    '''
    accepted = (type(get_client()), type(None))
    if not isinstance(client, accepted):
        raise TypeError(_expect(list(accepted), type(client)))
    if client is None:
        return get_client()
    return client

Expand All @@ -112,12 +115,6 @@ def _get_client_workers(client):
return workers


def _assert_client(client):
    '''Raise ``TypeError`` unless ``client`` is None or a distributed client.'''
    accepted = (type(get_client()), type(None))
    if isinstance(client, accepted):
        return
    raise TypeError(_expect(list(accepted), type(client)))


class DaskDMatrix:
# pylint: disable=missing-docstring, too-many-instance-attributes
'''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing
Expand Down Expand Up @@ -155,7 +152,7 @@ def __init__(self,
feature_names=None,
feature_types=None):
_assert_dask_support()
_assert_client(client)
client = _xgb_get_client(client)

self.feature_names = feature_names
self.feature_types = feature_types
Expand All @@ -177,7 +174,6 @@ def __init__(self,
self.has_label = label is not None
self.has_weights = weight is not None

client = _xgb_get_client(client)
client.sync(self.map_local_data, client, data, label, weight)

async def map_local_data(self, client, data, label=None, weights=None):
Expand Down Expand Up @@ -391,13 +387,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):

'''
_assert_dask_support()
_assert_client(client)
client = _xgb_get_client(client)
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')

client = _xgb_get_client(client)
workers = list(_get_client_workers(client).keys())

rabit_args = _get_rabit_args(workers, client)
Expand Down Expand Up @@ -452,7 +447,7 @@ def dispatched_train(worker_addr):
return list(filter(lambda ret: ret is not None, results))[0]


def predict(client, model, data, *args):
def predict(client, model, data, *args, missing=numpy.nan):
'''Run prediction with a trained booster.

.. note::
Expand All @@ -466,32 +461,55 @@ def predict(client, model, data, *args):
returned from dask if it's set to None.
model: A Booster or a dictionary returned by `xgboost.dask.train`.
The trained model.
data: DaskDMatrix
data: DaskDMatrix/dask.dataframe.DataFrame/dask.array.Array
Input data used for prediction.
missing: float
Used when input data is not DaskDMatrix. Specify the value
considered as missing.

Returns
-------
prediction: dask.array.Array
prediction: dask.array.Array/dask.dataframe.Series

'''
_assert_dask_support()
_assert_client(client)
client = _xgb_get_client(client)
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
type(data)))

if not isinstance(data, DaskDMatrix):
raise TypeError(_expect([DaskDMatrix], type(data)))

    def mapped_predict(partition, is_df):
        '''Run prediction on one chunk/partition of the input collection.

        Closes over ``booster``, ``args`` and ``missing`` from the
        enclosing ``predict`` call.  Executed on the dask worker that
        holds the partition.
        '''
        worker = distributed_get_worker()
        # Build a local DMatrix from just this partition, using the
        # worker's thread count for construction.
        m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
        # NOTE(review): validate_features is disabled — presumably because
        # raw chunks carry no feature names to validate; confirm.
        predt = booster.predict(m, *args, validate_features=False)
        if is_df:
            # For DataFrame input, wrap the ndarray into a single-column
            # frame named 'prediction' so dask can track frame metadata.
            predt = DataFrame(predt, columns=['prediction'])
        return predt

if isinstance(data, da.Array):
predictions = client.submit(
da.map_blocks,
mapped_predict, data, False, drop_axis=1,
dtype=numpy.float32
).result()
return predictions
if isinstance(data, dd.DataFrame):
import dask
predictions = client.submit(
dd.map_partitions,
mapped_predict, data, True,
meta=dask.dataframe.utils.make_meta({'prediction': 'f4'})
).result()
return predictions.iloc[:, 0]

# Prediction on dask DMatrix.
worker_map = data.worker_map
client = _xgb_get_client(client)

missing = data.missing
feature_names = data.feature_names
feature_types = data.feature_types

def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
Expand All @@ -502,9 +520,9 @@ def dispatched_predict(worker_id):
booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts:
local_x = DMatrix(part,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
feature_names=data.feature_names,
feature_types=data.feature_types,
missing=data.missing,
nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
Expand All @@ -520,8 +538,7 @@ def dispatched_get_shape(worker_id):
list_of_parts = data.get_worker_x_ordered(worker)
shapes = []
for part, order in list_of_parts:
s = part.shape
shapes.append((s, order))
shapes.append((part.shape, order))
return shapes

def map_function(func):
Expand Down
14 changes: 13 additions & 1 deletion tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ def test_from_dask_dataframe():
xgb.dask.train(
client, {}, dtrain, num_boost_round=2, evals_result={})
# force prediction to be computed
prediction = prediction.compute()
from_dmatrix = prediction.compute()

prediction = xgb.dask.predict(client, model=booster, data=X)
from_df = prediction.compute()

assert isinstance(prediction, dd.Series)
assert np.all(from_dmatrix == from_df.to_numpy())


def test_from_dask_array():
Expand All @@ -84,6 +90,12 @@ def test_from_dask_array():
config = json.loads(booster.save_config())
assert int(config['learner']['generic_param']['nthread']) == 5

from_arr = xgb.dask.predict(
client, model=booster, data=X)

assert isinstance(from_arr, da.Array)
assert np.all(single_node_predt == from_arr.compute())


def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster:
Expand Down