Fix global config default value. (#6470)
trivialfis authored Dec 5, 2020
1 parent d6386e4 commit 703c2d0
Showing 2 changed files with 17 additions and 12 deletions.
include/xgboost/global_config.h (1 addition, 1 deletion)
@@ -15,7 +15,7 @@ namespace xgboost {
 class Json;
 
 struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
-  int verbosity;
+  int verbosity { 1 };
   DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
     DMLC_DECLARE_FIELD(verbosity)
         .set_range(0, 3)
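The one-line header change above is the core fix: without the in-class initializer, the value of `verbosity` seen through the global configuration before any explicit `set_config` call was not guaranteed to be XGBoost's documented default of 1. A minimal sketch of the behavior this guarantees, as observed from Python (assuming xgboost >= 1.3, where the global config API lives):

import xgboost as xgb

# A fresh process should now report the documented default verbosity of 1
# (the field is range-checked to 0..3 by set_range(0, 3) above).
assert xgb.get_config()['verbosity'] == 1

# config_context() temporarily overrides the same global state and
# restores it on exit; set_config() changes it in place.
with xgb.config_context(verbosity=3):
    assert xgb.get_config()['verbosity'] == 3
assert xgb.get_config()['verbosity'] == 1
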
python-package/xgboost/dask.py (16 additions, 11 deletions)
@@ -626,6 +626,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
 
 
 async def _train_async(client,
+                       global_config,
                        params,
                        dtrain: DaskDMatrix,
                        *args,
@@ -639,15 +640,14 @@ async def _train_async(client,
 
     workers = list(_get_workers_from_data(dtrain, evals))
     _rabit_args = await _get_rabit_args(len(workers), client)
-    _global_config = config.get_config()
 
     def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
         '''Perform training on a single worker. A local function prevents pickling.
         '''
         LOGGER.info('Training on %s', str(worker_addr))
         worker = distributed.get_worker()
-        with RabitContext(rabit_args), config.config_context(**_global_config):
+        with RabitContext(rabit_args), config.config_context(**global_config):
             local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
             local_evals = []
             if evals_ref:
@@ -735,8 +735,11 @@ def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
+    # Get global configuration before transferring computation to another thread or
+    # process.
+    global_config = config.get_config()
     return client.sync(
-        _train_async, client, params, dtrain=dtrain, *args, evals=evals,
+        _train_async, client, global_config, params, dtrain=dtrain, *args, evals=evals,
         early_stopping_rounds=early_stopping_rounds, **kwargs)
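Every dask.py change follows the same pattern: the global configuration is captured with config.get_config() in the client process, before client.sync() hands execution to another thread or process where that configuration is not automatically visible, and the captured dict is re-applied on each worker. A usage sketch under stated assumptions (xgboost >= 1.3 with dask installed; toy random data and a throwaway local cluster, for illustration only):

import xgboost as xgb
from xgboost import dask as dxgb
from dask import array as da
from dask.distributed import Client

with Client() as client:  # local cluster, illustration only
    X = da.random.random((1000, 10), chunks=(100, 10))
    y = da.random.random(1000, chunks=100)
    dtrain = dxgb.DaskDMatrix(client, X, y)

    # Silence logging in the client process; with this commit the setting
    # is captured by get_config() and re-entered on every worker through
    # config_context(), so it survives the thread/process hop.
    xgb.set_config(verbosity=0)
    output = dxgb.train(client, {'objective': 'reg:squarederror'},
                        dtrain, num_boost_round=10)
    pred = dxgb.predict(client, output, dtrain)  # honors the same config
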


@@ -760,7 +763,7 @@ async def _direct_predict_impl(client, data, predict_fn):
 
 
 # pylint: disable=too-many-statements
-async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
+async def _predict_async(client, global_config, model, data, missing=numpy.nan, **kwargs):
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
@@ -771,11 +774,9 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
                                 type(data)))
 
-    _global_config = config.get_config()
-
     def mapped_predict(partition, is_df):
         worker = distributed.get_worker()
-        with config.config_context(**_global_config):
+        with config.config_context(**global_config):
             booster.set_param({'nthread': worker.nthreads})
             m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
             predt = booster.predict(m, validate_features=False, **kwargs)
@@ -801,7 +802,7 @@ def mapped_predict(partition, is_df):
     def dispatched_predict(worker_id, list_of_orders, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        with config.config_context(**_global_config):
+        with config.config_context(**global_config):
             worker = distributed.get_worker()
             list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
             predictions = []
@@ -907,11 +908,12 @@ def predict(client, model, data, missing=numpy.nan, **kwargs):
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    return client.sync(_predict_async, client, model, data,
+    global_config = config.get_config()
+    return client.sync(_predict_async, client, global_config, model, data,
                        missing=missing, **kwargs)
 
 
-async def _inplace_predict_async(client, model, data,
+async def _inplace_predict_async(client, global_config, model, data,
                                  iteration_range=(0, 0),
                                  predict_type='value',
                                  missing=numpy.nan):
@@ -927,6 +929,7 @@ async def _inplace_predict_async(client, model, data,
 
     def mapped_predict(data, is_df):
         worker = distributed.get_worker()
+        config.set_config(**global_config)
         booster.set_param({'nthread': worker.nthreads})
         prediction = booster.inplace_predict(
             data,
@@ -976,7 +979,9 @@ def inplace_predict(client, model, data,
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    return client.sync(_inplace_predict_async, client, model=model, data=data,
+    global_config = config.get_config()
+    return client.sync(_inplace_predict_async, client, global_config, model=model,
+                       data=data,
                        iteration_range=iteration_range,
                        predict_type=predict_type,
                        missing=missing)
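
Two re-application styles appear above: _train_async and _predict_async wrap the worker-side work in config.config_context(**global_config), which restores the worker's previous settings on exit, while the inplace-predict path calls config.set_config(**global_config) directly inside mapped_predict. A minimal sketch of the difference (work() is a hypothetical stand-in for the real worker-side body):

import xgboost as xgb

def work():
    pass  # hypothetical stand-in for the worker-side computation

captured = xgb.get_config()  # snapshot taken in the client process

def worker_fn_scoped():
    # Scoped style (train/predict): the snapshot applies only inside the
    # block; the worker's prior configuration is restored afterwards.
    with xgb.config_context(**captured):
        work()

def worker_fn_direct():
    # Direct style (inplace_predict): the snapshot stays applied for the
    # remainder of the worker's lifetime.
    xgb.set_config(**captured)
    work()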
