Skip to content

Commit

Permalink
dask inplace predict.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 25, 2020
1 parent 5d0e1df commit 60bae8e
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 15 deletions.
7 changes: 5 additions & 2 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def ctypes2numpy(cptr, length, dtype):
def ctypes2cupy(cptr, length, dtype):
"""Convert a ctypes pointer array to a cupy array."""
import cupy
mem = cupy.zeros((length.value, ), dtype=dtype, order='C')
mem = cupy.zeros(length.value, dtype=dtype, order='C')
addr = ctypes.cast(cptr, ctypes.c_void_p).value
# pylint: disable=c-extension-no-member,no-member
cupy.cuda.runtime.memcpy(
Expand Down Expand Up @@ -487,6 +487,7 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
data, feature_names, feature_types = _convert_dataframes(
data, feature_names, feature_types
)
missing = np.nan if missing is None else missing

if isinstance(data, (STRING_TYPES, os_PathLike)):
handle = ctypes.c_void_p()
Expand Down Expand Up @@ -622,6 +623,7 @@ def _init_from_dt(self, data, nthread):
def _init_from_array_interface_columns(self, df, missing, nthread):
"""Initialize DMatrix from columnar memory format."""
interfaces_str = _cudf_array_interfaces(df)
nthread = nthread if nthread is not None else 1
handle = ctypes.c_void_p()
_check_call(
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
Expand Down Expand Up @@ -1560,7 +1562,8 @@ def reshape_output(predt, rows):
preds = ctypes.POINTER(ctypes.c_float)()

# once caching is supported, we can pass id(data) as cache id.

if isinstance(data, DataFrame):
data = data.values
if isinstance(data, np.ndarray):
assert data.flags.c_contiguous
arr = np.array(data.reshape(data.size), copy=False,
Expand Down
46 changes: 33 additions & 13 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .compat import sparse, scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
from .compat import lazy_isinstance

from .core import DMatrix, Booster, _expect
from .training import train as worker_train
Expand Down Expand Up @@ -98,6 +99,9 @@ def concat(value):
return pandas_concat(value, axis=0)
if CUDF_INSTALLED and isinstance(value[0], (CUDF_DataFrame, CUDF_Series)):
return CUDF_concat(value, axis=0)
if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
import cupy # pylint: disable=import-error
return cupy.concatenate(value, axis=0)
return dd.multi.concat(list(value), axis=0)


Expand Down Expand Up @@ -426,7 +430,6 @@ def dispatched_train(worker_addr):
local_param['n_jobs'] is not None and \
local_param['n_jobs'] != worker.nthreads:
msg = '`n_jobs` is specified. ' + msg
print('local_param[n_jobs]', local_param['n_jobs'])
LOGGER.warning(msg)
else:
local_param['nthread'] = worker.nthreads
Expand Down Expand Up @@ -502,11 +505,10 @@ def mapped_predict(partition, is_df):
).result()
return predictions
if isinstance(data, dd.DataFrame):
import dask
predictions = client.submit(
dd.map_partitions,
mapped_predict, data, True,
meta=dask.dataframe.utils.make_meta({'prediction': 'f4'})
meta=dd.utils.make_meta({'prediction': 'f4'})
).result()
return predictions.iloc[:, 0]

Expand Down Expand Up @@ -607,27 +609,45 @@ def inplace_predict(client, model, data,
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

def dispatched_predict(data):
def mapped_predict(data, is_df):
    """Run in-place prediction on one block/partition of the input.

    Parameters
    ----------
    data :
        The chunk handed over by ``map_blocks`` / ``map_partitions``.
    is_df : bool
        When true, the raw prediction is wrapped into a single-column
        dataframe named ``prediction`` with dtype float32.

    Returns
    -------
    The output of ``booster.inplace_predict`` as is, or a one-column
    dataframe holding it when ``is_df`` is true.
    """
    # Use as many threads as the worker running this task provides.
    booster.set_param({'nthread': distributed_get_worker().nthreads})
    raw = booster.inplace_predict(
        data,
        iteration_range=iteration_range,
        predict_type=predict_type,
        missing=missing)

    if not is_df:
        return raw

    if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
        import cudf
        # NOTE: cudf currently fails with `concat_cudf` receiving an
        # unexpected argument `ignore_index`, so this branch is only a
        # placeholder and is not working yet.
        return cudf.DataFrame({'prediction': raw}, dtype=numpy.float32)

    # If it's from pandas, the partition is a numpy array
    return DataFrame(raw, columns=['prediction'], dtype=numpy.float32)

msg = 'Only dask array is supported for inplace prediction'
assert isinstance(data, da.Array), msg

def map_blocks():
predictions = da.map_blocks(dispatched_predict, data, drop_axis=1)
if isinstance(data, da.Array):
predictions = client.submit(
da.map_blocks,
mapped_predict, data, False, drop_axis=1,
dtype=numpy.float32
).result()
return predictions

predictions = client.submit(map_blocks)
import dask
return dask.delayed(predictions).compute()
if isinstance(data, dd.DataFrame):
predictions = client.submit(
dd.map_partitions,
mapped_predict, data, True,
meta=dd.utils.make_meta({'prediction': 'f4'})
).result()
return predictions.iloc[:, 0]


def _evaluation_matrices(client, validation_set, sample_weights, missing):
Expand Down
38 changes: 38 additions & 0 deletions tests/python-gpu/test_gpu_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import numpy as np
import unittest
import xgboost

if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
Expand Down Expand Up @@ -29,6 +30,7 @@ class TestDistributedGPU(unittest.TestCase):
def test_dask_dataframe(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
import cupy
X, y = generate_array()

X = dd.from_dask_array(X)
Expand All @@ -49,6 +51,42 @@ def test_dask_dataframe(self):
predictions = dxgb.predict(client, out, dtrain).compute()
assert isinstance(predictions, np.ndarray)

# There's an error with cudf saying `concat_cudf` got an
# unexpected argument `ignore_index`. So the test here is just
# a placeholder.

# series_predictions = dxgb.inplace_predict(client, out, X)
# assert isinstance(series_predictions, dd.Series)

single_node = out['booster'].predict(
xgboost.DMatrix(X.compute()))
cupy.testing.assert_allclose(single_node, predictions)

@pytest.mark.skipif(**tm.no_cupy())
def test_dask_array(self):
    """Compare the prediction paths for cupy-backed dask arrays.

    Trains a two-round ``gpu_hist`` model, then checks that the
    single-node prediction agrees with both the DaskDMatrix based
    prediction and the in-place prediction.
    """
    with LocalCUDACluster() as cluster, Client(cluster) as client:
        import cupy

        features, labels = generate_array()
        # Convert every block of the dask arrays with cupy.asarray.
        features = features.map_blocks(cupy.asarray)
        labels = labels.map_blocks(cupy.asarray)

        dtrain = dxgb.DaskDMatrix(client, features, labels)
        params = {'tree_method': 'gpu_hist'}
        out = dxgb.train(client, params,
                         dtrain=dtrain,
                         evals=[(dtrain, 'X')],
                         num_boost_round=2)

        from_dmatrix = dxgb.predict(client, out, dtrain).compute()
        inplace_predictions = dxgb.inplace_predict(
            client, out, features).compute()

        local_matrix = xgboost.DMatrix(features.compute())
        single_node = out['booster'].predict(local_matrix)

        np.testing.assert_allclose(single_node, from_dmatrix)
        cupy.testing.assert_allclose(
            cupy.array(single_node),
            inplace_predictions)


@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
Expand Down
6 changes: 6 additions & 0 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,14 @@ def test_from_dask_dataframe():
from_df = prediction.compute()

assert isinstance(prediction, dd.Series)
assert np.all(prediction.compute().values == from_dmatrix)
assert np.all(from_dmatrix == from_df.to_numpy())

series_predictions = xgb.dask.inplace_predict(client, booster, X)
assert isinstance(series_predictions, dd.Series)
np.testing.assert_allclose(series_predictions.compute().values,
from_dmatrix)


def test_from_dask_array():
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
Expand Down

0 comments on commit 60bae8e

Please sign in to comment.