Disable dense optimization in hist for distributed training. #9272

Merged · 3 commits · Jun 9, 2023
21 changes: 21 additions & 0 deletions python-package/xgboost/testing/dask.py
@@ -1,6 +1,8 @@
"""Tests for dask shared by different test modules."""
import numpy as np
import pandas as pd
from dask import array as da
from dask import dataframe as dd
from distributed import Client

import xgboost as xgb
@@ -52,3 +54,22 @@ def check_init_estimation(tree_method: str, client: Client) -> None:
"""Test init estimation."""
check_init_estimation_reg(tree_method, client)
check_init_estimation_clf(tree_method, client)


def check_uneven_nan(client: Client, tree_method: str, n_workers: int) -> None:
"""Issue #9271, not every worker has missing value."""
assert n_workers >= 2

with client.as_current():
clf = xgb.dask.DaskXGBClassifier(tree_method=tree_method)
X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})
y = pd.Series([*[0] * 5000, *[1] * 5000])

X["a"][:3000:1000] = np.NaN

client.wait_for_workers(n_workers=n_workers)

clf.fit(
dd.from_pandas(X, npartitions=n_workers),
dd.from_pandas(y, npartitions=n_workers),
)
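
The helper above is only meaningful if the missing values end up in a single partition: with two partitions of 5,000 rows each, the three NaNs (rows 0, 1000, 2000) all fall into the first one, so one worker trains on a locally dense shard while the global data is not dense. A small standalone sketch of that layout, assuming only dask, pandas, and numpy (it mirrors the construction in check_uneven_nan but is not part of the PR):

# Sketch: count missing values per partition for data built like the helper above.
import numpy as np
import pandas as pd
from dask import dataframe as dd

X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})
X.loc[[0, 1000, 2000], "a"] = np.nan  # same rows as X["a"][:3000:1000] = np.NaN

ddf = dd.from_pandas(X, npartitions=2)
for i in range(ddf.npartitions):
    part = ddf.partitions[i].compute()
    print(f"partition {i}: {int(part.isna().sum().sum())} NaNs")
# Expected: partition 0 has 3 NaNs, partition 1 has none.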
12 changes: 6 additions & 6 deletions src/tree/updater_gpu_hist.cu
@@ -285,7 +285,7 @@ struct GPUHistMakerDevice {
matrix.feature_segments,
matrix.gidx_fvalue_map,
matrix.min_fvalue,
matrix.is_dense
matrix.is_dense && !collective::IsDistributed()
};
auto split = this->evaluator_.EvaluateSingleSplit(inputs, shared_inputs);
return split;
@@ -299,11 +299,11 @@ struct GPUHistMakerDevice {
std::vector<bst_node_t> nidx(2 * candidates.size());
auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
EvaluateSplitSharedInputs shared_inputs{
GPUTrainingParam{param}, *quantiser, feature_types, matrix.feature_segments,
matrix.gidx_fvalue_map, matrix.min_fvalue,
matrix.is_dense
};
EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param}, *quantiser, feature_types,
matrix.feature_segments, matrix.gidx_fvalue_map,
matrix.min_fvalue,
// is_dense represents the local data
matrix.is_dense && !collective::IsDistributed()};
dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
// Store the feature set ptrs so they dont go out of scope before the kernel is called
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;
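
As the inline comment notes, matrix.is_dense only describes the data shard held by the current worker, so the dense code path is now additionally gated on not running distributed training. A toy model of that guard (my own helper name, not an XGBoost API), just to spell out the decision:

# Sketch: the dense fast path is safe only when the local data is dense AND we
# can be sure no other worker holds missing values, i.e. training is not distributed.
def use_dense_fast_path(local_is_dense: bool, is_distributed: bool) -> bool:
    return local_is_dense and not is_distributed

assert use_dense_fast_path(True, False)       # single node, dense data: fast path
assert not use_dense_fast_path(True, True)    # distributed: stay conservative
assert not use_dense_fast_path(False, False)  # sparse local data: stay conservative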
2 changes: 1 addition & 1 deletion src/tree/updater_quantile_hist.cc
@@ -434,7 +434,7 @@ class HistBuilder {

{
GradientPairPrecise grad_stat;
if (p_fmat->IsDense()) {
if (p_fmat->IsDense() && !collective::IsDistributed()) {
/**
* Specialized code for dense data: For dense data (with no missing value), the sum
* of gradient histogram is equal to snode[nid]
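
The comment in this hunk states the invariant being protected: when the data is fully dense, summing the gradient histogram of any single feature recovers the node's gradient statistics (snode[nid]), so the explicit reduction over rows can be skipped. With missing values, rows that lack a value for that feature never enter its histogram, and after the distributed histogram allreduce those rows may come from other workers, so a locally dense shard is not enough to justify the shortcut. A rough numerical sketch of the broken equality (plain NumPy, my own illustration, not XGBoost code):

# Sketch: with missing values, one feature's histogram no longer sums to the
# node's total gradient.
import numpy as np

rng = np.random.default_rng(0)
grad = rng.normal(size=10)        # per-row gradients for one node
feature = rng.uniform(size=10)    # one feature column
feature[[1, 4]] = np.nan          # e.g. missing entries contributed by another worker

present = ~np.isnan(feature)
edges = np.linspace(0.0, 1.0, 4)
bins = np.digitize(feature[present], edges)
hist = np.zeros(edges.size + 1)
np.add.at(hist, bins, grad[present])  # the histogram only accumulates present rows

print("true node gradient sum:", grad.sum())
print("sum of histogram bins :", hist.sum())  # differs: rows 1 and 4 are absent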
@@ -44,7 +44,7 @@
from dask_cuda import LocalCUDACluster

from xgboost import dask as dxgb
from xgboost.testing.dask import check_init_estimation
from xgboost.testing.dask import check_init_estimation, check_uneven_nan
except ImportError:
pass

@@ -224,6 +224,12 @@ def test_boost_from_prediction(self, local_cuda_client: Client) -> None:
def test_init_estimation(self, local_cuda_client: Client) -> None:
check_init_estimation("gpu_hist", local_cuda_client)

def test_uneven_nan(self) -> None:
n_workers = 2
with LocalCUDACluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
check_uneven_nan(client, "gpu_hist", n_workers)

@pytest.mark.skipif(**tm.no_dask_cudf())
def test_dask_dataframe(self, local_cuda_client: Client) -> None:
run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
11 changes: 9 additions & 2 deletions tests/test_distributed/test_with_dask/test_with_dask.py
@@ -4,7 +4,6 @@
import os
import pickle
import socket
import subprocess
import tempfile
from concurrent.futures import ThreadPoolExecutor
from functools import partial
@@ -41,7 +40,7 @@
from toolz import sliding_window # dependency of dask

from xgboost.dask import DaskDMatrix
from xgboost.testing.dask import check_init_estimation
from xgboost.testing.dask import check_init_estimation, check_uneven_nan

dask.config.set({"distributed.scheduler.allowed-failures": False})

@@ -2014,6 +2013,14 @@ def test_init_estimation(client: Client) -> None:
check_init_estimation("hist", client)


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_uneven_nan(tree_method) -> None:
n_workers = 2
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
check_uneven_nan(client, tree_method, n_workers)


class TestDaskCallbacks:
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping(self, client: "Client") -> None: