Patch summary (text before the first "diff --git" line is ignored by patch tools):
  * python-package/xgboost/rabit.py: add is_distributed(), which returns the
    result of _LIB.RabitIsDistributed(), and an Op class naming the reduction
    operation codes (MAX, MIN, SUM, OR).
  * tests/python/test_tracker.py: drop the "time" import and add a dask-based
    test that runs allreduce with Op.SUM and Op.MAX on every worker.

diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py
index 9b39279cc712..e8abe74dfcf3 100644
--- a/python-package/xgboost/rabit.py
+++ b/python-package/xgboost/rabit.py
@@ -56,6 +56,12 @@ def get_world_size():
     return ret
 
 
+def is_distributed():
+    '''If rabit is distributed.'''
+    is_dist = _LIB.RabitIsDistributed()
+    return is_dist
+
+
 def tracker_print(msg):
     """Print message to the tracker.
 
@@ -143,6 +149,14 @@ def broadcast(data, root):
 }
 
 
+class Op:  # pylint: disable=too-few-public-methods
+    '''Supported operations for rabit.'''
+    MAX = 0
+    MIN = 1
+    SUM = 2
+    OR = 3
+
+
 def allreduce(data, op, prepare_fun=None):
     """Perform allreduce, return the result.
 
diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py
index 9cff5752539b..bb316810fb30 100644
--- a/tests/python/test_tracker.py
+++ b/tests/python/test_tracker.py
@@ -1,7 +1,8 @@
-import time
-
 from xgboost import RabitTracker
 import xgboost as xgb
+import pytest
+import testing as tm
+import numpy as np
 
 
 def test_rabit_tracker():
@@ -15,3 +16,39 @@ def test_rabit_tracker():
     ret = xgb.rabit.broadcast('test1234', 0)
     assert str(ret) == 'test1234'
     xgb.rabit.finalize()
+
+
+def run_rabit_ops(client, n_workers):
+    from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
+    from xgboost import rabit
+
+    workers = list(_get_client_workers(client).keys())
+    rabit_args = _get_rabit_args(workers, client)
+    assert not rabit.is_distributed()
+
+    def local_test(worker_id):
+        with RabitContext(rabit_args):
+            a = 1
+            assert rabit.is_distributed()
+            a = np.array([a])
+            reduced = rabit.allreduce(a, rabit.Op.SUM)
+            assert reduced[0] == n_workers
+
+            worker_id = np.array([worker_id])
+            reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
+            assert reduced == n_workers - 1
+
+            return 1
+
+    futures = client.map(local_test, range(len(workers)), workers=workers)
+    results = client.gather(futures)
+    assert sum(results) == n_workers
+
+
+@pytest.mark.skipif(**tm.no_dask())
+def test_rabit_ops():
+    from distributed import Client, LocalCluster
+    n_workers = 3
+    with LocalCluster(n_workers=n_workers) as cluster:
+        with Client(cluster) as client:
+            run_rabit_ops(client, n_workers)