diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 00079d945c5d..103a5450291e 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -131,7 +131,7 @@ def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs)
     return X, y, w, g_rle, dX, dy, dw, dg
 
 
-def _create_data(objective, n_samples=100, output='array', chunk_size=50):
+def _create_data(objective, n_samples=100, output='array', chunk_size=50, **kwargs):
     if objective.endswith('classification'):
         if objective == 'binary-classification':
             centers = [[-4, -4], [4, 4]]
@@ -142,6 +142,13 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50):
         X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
     elif objective == 'regression':
         X, y = make_regression(n_samples=n_samples, random_state=42)
+    elif objective == 'ranking':
+        return _create_ranking_data(
+            n_samples=n_samples,
+            output=output,
+            chunk_size=chunk_size,
+            **kwargs
+        )
     else:
         raise ValueError("Unknown objective '%s'" % objective)
     rnd = np.random.RandomState(42)
@@ -183,7 +190,7 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50):
     else:
         raise ValueError("Unknown output type '%s'" % output)
 
-    return X, y, weights, dX, dy, dw
+    return X, y, weights, None, dX, dy, dw, None
 
 
 def _r2_score(dy_true, dy_pred):
@@ -225,7 +232,7 @@ def _unpickle(filepath, serializer):
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
 def test_classifier(output, task, client):
-    X, y, w, dX, dy, dw = _create_data(
+    X, y, w, _, dX, dy, dw, _ = _create_data(
         objective=task,
         output=output
     )
@@ -291,7 +298,7 @@ def test_classifier(output, task, client):
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
 def test_classifier_pred_contrib(output, task, client):
-    X, y, w, dX, dy, dw = _create_data(
+    X, y, w, _, dX, dy, dw, _ = _create_data(
         objective=task,
         output=output
     )
@@ -369,7 +376,7 @@ def test_find_random_open_port(client):
 
 
 def test_training_does_not_fail_on_port_conflicts(client):
-    _, _, _, dX, dy, dw = _create_data('binary-classification', output='array')
+    _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
 
     lightgbm_default_port = 12400
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -393,7 +400,7 @@ def test_training_does_not_fail_on_port_conflicts(client):
 
 @pytest.mark.parametrize('output', data_output)
 def test_regressor(output, client):
-    X, y, w, dX, dy, dw = _create_data(
+    X, y, w, _, dX, dy, dw, _ = _create_data(
         objective='regression',
         output=output
     )
@@ -468,7 +475,7 @@ def test_regressor(output, client):
 
 @pytest.mark.parametrize('output', data_output)
 def test_regressor_pred_contrib(output, client):
-    X, y, w, dX, dy, dw = _create_data(
+    X, y, w, _, dX, dy, dw, _ = _create_data(
         objective='regression',
         output=output
     )
@@ -518,7 +525,7 @@ def test_regressor_pred_contrib(output, client):
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('alpha', [.1, .5, .9])
 def test_regressor_quantile(output, client, alpha):
-    X, y, w, dX, dy, dw = _create_data(
+    X, y, w, _, dX, dy, dw, _ = _create_data(
         objective='regression',
         output=output
     )
@@ -567,18 +574,19 @@ def test_regressor_quantile(output, client, alpha):
 @pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
 @pytest.mark.parametrize('group', [None, group_sizes])
 def test_ranker(output, client, group):
     if output == 'dataframe-with-categorical':
-        X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
+        X, y, w, g, dX, dy, dw, dg = _create_data(
+            objective='ranking',
             output=output,
             group=group,
             n_features=1,
             n_informative=1
         )
     else:
-        X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
+        X, y, w, g, dX, dy, dw, dg = _create_data(
+            objective='ranking',
             output=output,
-            group=group,
+            group=group
         )
 
     # rebalance small dask.Array dataset for better performance.
@@ -650,17 +658,11 @@ def test_ranker(output, client, group):
 
 @pytest.mark.parametrize('task', tasks)
 def test_training_works_if_client_not_provided_or_set_after_construction(task, client):
-    if task == 'ranking':
-        _, _, _, _, dX, dy, _, dg = _create_ranking_data(
-            output='array',
-            group=None
-        )
-    else:
-        _, _, _, dX, dy, _ = _create_data(
-            objective=task,
-            output='array',
-        )
-        dg = None
+    _, _, _, _, dX, dy, _, dg = _create_data(
+        objective=task,
+        output='array',
+        group=None
+    )
 
     model_factory = task_to_dask_factory[task]
     params = {
@@ -723,182 +725,166 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c
 
 @pytest.mark.parametrize('set_client', [True, False])
 def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, tmp_path):
-    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1:
-        with Client(cluster1) as client1:
+    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1, Client(cluster1) as client1:
+        # data on cluster1
+        X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_data(
+            objective=task,
+            output='array',
+            group=None
+        )
+
+        with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2, Client(cluster2) as client2:
+            # create identical data on cluster2
+            X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_data(
+                objective=task,
+                output='array',
+                group=None
+            )
 
-            # data on cluster1
-            if task == 'ranking':
-                X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data(
-                    output='array',
-                    group=None
-                )
+            model_factory = task_to_dask_factory[task]
+
+            params = {
+                "time_out": 5,
+                "n_estimators": 1,
+                "num_leaves": 2
+            }
+
+            # at this point, the result of default_client() is client2 since it was the most recently
+            # created. So setting client to client1 here to test that you can select a non-default client
+            assert default_client() == client2
+            if set_client:
+                params.update({"client": client1})
+
+            # unfitted model should survive pickling round trip, and pickling
+            # shouldn't have side effects on the model object
+            dask_model = model_factory(**params)
+            local_model = dask_model.to_local()
+            if set_client:
+                assert dask_model.client == client1
             else:
-                X_1, _, _, dX_1, dy_1, _ = _create_data(
-                    objective=task,
-                    output='array',
-                )
-                dg_1 = None
-
-            with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2:
-                with Client(cluster2) as client2:
-
-                    # create identical data on cluster2
-                    if task == 'ranking':
-                        X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data(
-                            output='array',
-                            group=None
-                        )
-                    else:
-                        X_2, _, _, dX_2, dy_2, _ = _create_data(
-                            objective=task,
-                            output='array',
-                        )
-                        dg_2 = None
-
-                    model_factory = task_to_dask_factory[task]
-
-                    params = {
-                        "time_out": 5,
-                        "n_estimators": 1,
-                        "num_leaves": 2
-                    }
-
-                    # at this point, the result of default_client() is client2 since it was the most recently
-                    # created. So setting client to client1 here to test that you can select a non-default client
-                    assert default_client() == client2
-                    if set_client:
-                        params.update({"client": client1})
-
-                    # unfitted model should survive pickling round trip, and pickling
-                    # shouldn't have side effects on the model object
-                    dask_model = model_factory(**params)
-                    local_model = dask_model.to_local()
-                    if set_client:
-                        assert dask_model.client == client1
-                    else:
-                        assert dask_model.client is None
-
-                    with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
-                        dask_model.client_
-
-                    assert "client" not in local_model.get_params()
-                    assert getattr(local_model, "client", None) is None
-
-                    tmp_file = str(tmp_path / "model-1.pkl")
-                    _pickle(
-                        obj=dask_model,
-                        filepath=tmp_file,
-                        serializer=serializer
-                    )
-                    model_from_disk = _unpickle(
-                        filepath=tmp_file,
-                        serializer=serializer
-                    )
-
-                    local_tmp_file = str(tmp_path / "local-model-1.pkl")
-                    _pickle(
-                        obj=local_model,
-                        filepath=local_tmp_file,
-                        serializer=serializer
-                    )
-                    local_model_from_disk = _unpickle(
-                        filepath=local_tmp_file,
-                        serializer=serializer
-                    )
-
-                    assert model_from_disk.client is None
-
-                    if set_client:
-                        assert dask_model.client == client1
-                    else:
-                        assert dask_model.client is None
-
-                    with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
-                        dask_model.client_
-
-                    # client will always be None after unpickling
-                    if set_client:
-                        from_disk_params = model_from_disk.get_params()
-                        from_disk_params.pop("client", None)
-                        dask_params = dask_model.get_params()
-                        dask_params.pop("client", None)
-                        assert from_disk_params == dask_params
-                    else:
-                        assert model_from_disk.get_params() == dask_model.get_params()
-                    assert local_model_from_disk.get_params() == local_model.get_params()
-
-                    # fitted model should survive pickling round trip, and pickling
-                    # shouldn't have side effects on the model object
-                    if set_client:
-                        dask_model.fit(dX_1, dy_1, group=dg_1)
-                    else:
-                        dask_model.fit(dX_2, dy_2, group=dg_2)
-                    local_model = dask_model.to_local()
-
-                    assert "client" not in local_model.get_params()
-                    with pytest.raises(AttributeError):
-                        local_model.client
-                        local_model.client_
-
-                    tmp_file2 = str(tmp_path / "model-2.pkl")
-                    _pickle(
-                        obj=dask_model,
-                        filepath=tmp_file2,
-                        serializer=serializer
-                    )
-                    fitted_model_from_disk = _unpickle(
-                        filepath=tmp_file2,
-                        serializer=serializer
-                    )
-
-                    local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
-                    _pickle(
-                        obj=local_model,
-                        filepath=local_tmp_file2,
-                        serializer=serializer
-                    )
-                    local_fitted_model_from_disk = _unpickle(
-                        filepath=local_tmp_file2,
-                        serializer=serializer
-                    )
-
-                    if set_client:
-                        assert dask_model.client == client1
-                        assert dask_model.client_ == client1
-                    else:
-                        assert dask_model.client is None
-                        assert dask_model.client_ == default_client()
-                        assert dask_model.client_ == client2
-
-                    assert isinstance(fitted_model_from_disk, model_factory)
-                    assert fitted_model_from_disk.client is None
-                    assert fitted_model_from_disk.client_ == default_client()
-                    assert fitted_model_from_disk.client_ == client2
-
-                    # client will always be None after unpickling
-                    if set_client:
-                        from_disk_params = fitted_model_from_disk.get_params()
-                        from_disk_params.pop("client", None)
-                        dask_params = dask_model.get_params()
-                        dask_params.pop("client", None)
-                        assert from_disk_params == dask_params
-                    else:
-                        assert fitted_model_from_disk.get_params() == dask_model.get_params()
-                    assert local_fitted_model_from_disk.get_params() == local_model.get_params()
-
-                    if set_client:
-                        preds_orig = dask_model.predict(dX_1).compute()
-                        preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
-                        preds_orig_local = local_model.predict(X_1)
-                        preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
-                    else:
-                        preds_orig = dask_model.predict(dX_2).compute()
-                        preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
-                        preds_orig_local = local_model.predict(X_2)
-                        preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
-
-                    assert_eq(preds_orig, preds_loaded_model)
-                    assert_eq(preds_orig_local, preds_loaded_model_local)
+                assert dask_model.client is None
+
+            with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
+                dask_model.client_
+
+            assert "client" not in local_model.get_params()
+            assert getattr(local_model, "client", None) is None
+
+            tmp_file = str(tmp_path / "model-1.pkl")
+            _pickle(
+                obj=dask_model,
+                filepath=tmp_file,
+                serializer=serializer
+            )
+            model_from_disk = _unpickle(
+                filepath=tmp_file,
+                serializer=serializer
+            )
+
+            local_tmp_file = str(tmp_path / "local-model-1.pkl")
+            _pickle(
+                obj=local_model,
+                filepath=local_tmp_file,
+                serializer=serializer
+            )
+            local_model_from_disk = _unpickle(
+                filepath=local_tmp_file,
+                serializer=serializer
+            )
+
+            assert model_from_disk.client is None
+
+            if set_client:
+                assert dask_model.client == client1
+            else:
+                assert dask_model.client is None
+
+            with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
+                dask_model.client_
+
+            # client will always be None after unpickling
+            if set_client:
+                from_disk_params = model_from_disk.get_params()
+                from_disk_params.pop("client", None)
+                dask_params = dask_model.get_params()
+                dask_params.pop("client", None)
+                assert from_disk_params == dask_params
+            else:
+                assert model_from_disk.get_params() == dask_model.get_params()
+            assert local_model_from_disk.get_params() == local_model.get_params()
+
+            # fitted model should survive pickling round trip, and pickling
+            # shouldn't have side effects on the model object
+            if set_client:
+                dask_model.fit(dX_1, dy_1, group=dg_1)
+            else:
+                dask_model.fit(dX_2, dy_2, group=dg_2)
+            local_model = dask_model.to_local()
+
+            assert "client" not in local_model.get_params()
+            with pytest.raises(AttributeError):
+                local_model.client
+                local_model.client_
+
+            tmp_file2 = str(tmp_path / "model-2.pkl")
+            _pickle(
+                obj=dask_model,
+                filepath=tmp_file2,
+                serializer=serializer
+            )
+            fitted_model_from_disk = _unpickle(
+                filepath=tmp_file2,
+                serializer=serializer
+            )
+
+            local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
+            _pickle(
+                obj=local_model,
+                filepath=local_tmp_file2,
+                serializer=serializer
+            )
+            local_fitted_model_from_disk = _unpickle(
+                filepath=local_tmp_file2,
+                serializer=serializer
+            )
+
+            if set_client:
+                assert dask_model.client == client1
+                assert dask_model.client_ == client1
+            else:
+                assert dask_model.client is None
+                assert dask_model.client_ == default_client()
+                assert dask_model.client_ == client2
+
+            assert isinstance(fitted_model_from_disk, model_factory)
+            assert fitted_model_from_disk.client is None
+            assert fitted_model_from_disk.client_ == default_client()
+            assert fitted_model_from_disk.client_ == client2
+
+            # client will always be None after unpickling
+            if set_client:
+                from_disk_params = fitted_model_from_disk.get_params()
+                from_disk_params.pop("client", None)
+                dask_params = dask_model.get_params()
+                dask_params.pop("client", None)
+                assert from_disk_params == dask_params
+            else:
+                assert fitted_model_from_disk.get_params() == dask_model.get_params()
+            assert local_fitted_model_from_disk.get_params() == local_model.get_params()
+
+            if set_client:
+                preds_orig = dask_model.predict(dX_1).compute()
+                preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
+                preds_orig_local = local_model.predict(X_1)
+                preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
+            else:
+                preds_orig = dask_model.predict(dX_2).compute()
+                preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
+                preds_orig_local = local_model.predict(X_2)
+                preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
+
+            assert_eq(preds_orig, preds_loaded_model)
+            assert_eq(preds_orig_local, preds_loaded_model_local)
 
 
 def test_warns_and_continues_on_unrecognized_tree_learner(client):
@@ -971,18 +957,11 @@ def collection_to_single_partition(collection):
             return collection.rechunk(*collection.shape)
         return collection.repartition(npartitions=1)
 
-    if task == 'ranking':
-        X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
-            output=output,
-            group=None
-        )
-    else:
-        X, y, w, dX, dy, dw = _create_data(
-            objective=task,
-            output=output
-        )
-        g = None
-        dg = None
+    X, y, w, g, dX, dy, dw, dg = _create_data(
+        objective=task,
+        output=output,
+        group=None
+    )
 
     dask_model_factory = task_to_dask_factory[task]
     local_model_factory = task_to_local_factory[task]
@@ -1026,19 +1005,12 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
 
     client.wait_for_workers(2)
 
-    if task == 'ranking':
-        _, _, _, _, dX, dy, _, dg = _create_ranking_data(
-            output=output,
-            group=None,
-            chunk_size=10,
-        )
-    else:
-        _, _, _, dX, dy, _ = _create_data(
-            objective=task,
-            output=output,
-            chunk_size=10,
-        )
-        dg = None
+    _, _, _, _, dX, dy, _, dg = _create_data(
+        objective=task,
+        output=output,
+        chunk_size=10,
+        group=None
+    )
 
     dask_model_factory = task_to_dask_factory[task]
 
@@ -1097,19 +1069,12 @@ def test_machines_should_be_used_if_provided(task, output):
         pytest.skip('LGBMRanker is not currently tested on sparse matrices')
 
     with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
-        if task == 'ranking':
-            _, _, _, _, dX, dy, _, dg = _create_ranking_data(
-                output=output,
-                group=None,
-                chunk_size=10,
-            )
-        else:
-            _, _, _, dX, dy, _ = _create_data(
-                objective=task,
-                output=output,
-                chunk_size=10,
-            )
-            dg = None
+        _, _, _, _, dX, dy, _, dg = _create_data(
+            objective=task,
+            output=output,
+            chunk_size=10,
+            group=None
+        )
 
         dask_model_factory = task_to_dask_factory[task]
 
@@ -1205,17 +1170,11 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
     task,
    client,
 ):
-    if task == 'ranking':
-        _, _, _, _, dX, dy, dw, dg = _create_ranking_data(
-            output='dataframe',
-            group=None
-        )
-    else:
-        _, _, _, dX, dy, dw = _create_data(
-            objective=task,
-            output='dataframe',
-        )
-        dg = None
+    _, _, _, _, dX, dy, dw, dg = _create_data(
+        objective=task,
+        output='dataframe',
+        group=None
+    )
 
     model_factory = task_to_dask_factory[task]
 
@@ -1242,17 +1201,11 @@ def test_init_score(task, output, client):
     if task == 'ranking' and output == 'scipy_csr_matrix':
         pytest.skip('LGBMRanker is not currently tested on sparse matrices')
 
-    if task == 'ranking':
-        _, _, _, _, dX, dy, dw, dg = _create_ranking_data(
-            output=output,
-            group=None
-        )
-    else:
-        _, _, _, dX, dy, dw = _create_data(
-            objective=task,
-            output=output,
-        )
-        dg = None
+    _, _, _, _, dX, dy, dw, dg = _create_data(
+        objective=task,
+        output=output,
+        group=None
+    )
 
     model_factory = task_to_dask_factory[task]