Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python binding for rabit ops. #5743

Merged
merged 3 commits into from
Jun 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python-package/xgboost/rabit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def get_world_size():
return ret


def is_distributed():
    """Report whether rabit is running in distributed mode.

    The answer comes straight from the native library's
    ``RabitIsDistributed`` entry point and is returned unconverted.
    """
    return _LIB.RabitIsDistributed()


def tracker_print(msg):
"""Print message to the tracker.

Expand Down Expand Up @@ -143,6 +149,14 @@ def broadcast(data, root):
}


class Op:  # pylint: disable=too-few-public-methods
    """Integer codes naming the operations rabit supports."""
    # Consecutive codes starting at zero, in the order MAX, MIN, SUM, OR.
    MAX, MIN, SUM, OR = range(4)


def allreduce(data, op, prepare_fun=None):
"""Perform allreduce, return the result.

Expand Down
41 changes: 39 additions & 2 deletions tests/python/test_tracker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import time

from xgboost import RabitTracker
import xgboost as xgb
import pytest
import testing as tm
import numpy as np


def test_rabit_tracker():
Expand All @@ -15,3 +16,39 @@ def test_rabit_tracker():
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
xgb.rabit.finalize()


def run_rabit_ops(client, n_workers):
    """Check rabit allreduce operations across the workers of a dask client.

    Parameters
    ----------
    client :
        Dask client whose workers each run one copy of the check.
    n_workers :
        Number of workers expected to take part in every reduction.
    """
    from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
    from xgboost import rabit

    worker_addrs = list(_get_client_workers(client).keys())
    rabit_args = _get_rabit_args(worker_addrs, client)
    # Before any RabitContext is entered the caller must not look distributed.
    assert not rabit.is_distributed()

    def local_test(worker_id):
        """Run the reductions for one worker and return 1 on success."""
        with RabitContext(rabit_args):
            assert rabit.is_distributed()

            # Each worker contributes a single 1, so SUM yields the worker count.
            total = rabit.allreduce(np.array([1]), rabit.Op.SUM)
            assert total[0] == n_workers

            # Ids are 0 .. len(workers) - 1, so MAX is expected to be the last id.
            largest_id = rabit.allreduce(np.array([worker_id]), rabit.Op.MAX)
            assert largest_id == n_workers - 1

            return 1

    futures = client.map(
        local_test, range(len(worker_addrs)), workers=worker_addrs
    )
    # One successful task per worker adds up to the expected worker count.
    assert sum(client.gather(futures)) == n_workers


@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    """Run the rabit operation checks on a three-worker local dask cluster."""
    from distributed import Client, LocalCluster

    worker_count = 3
    with LocalCluster(n_workers=worker_count) as cluster, Client(cluster) as client:
        run_rabit_ops(client, worker_count)