diff --git a/CHANGES.rst b/CHANGES.rst
index 1ef278665..fb0bfc565 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -88,6 +88,10 @@ Minor changes
   by converting them to string before type inference.
   :pr:`623` by :user:`Leo Grinsztajn`
 
+* Moved the default storage location of data to the user's home folder.
+  :pr:`652` by :user:`Felix Lefebvre` and
+  :user:`Gael Varoquaux`
+
 * Fixed bug when using :class:`TableVectorizer`'s `transform` method
   on categorical columns with missing values.
   :pr:`644` by :user:`Leo Grinsztajn`
diff --git a/benchmarks/bench_fuzzy_join_sparse_vs_dense.py b/benchmarks/bench_fuzzy_join_sparse_vs_dense.py
index 18f99e689..370bfb06c 100644
--- a/benchmarks/bench_fuzzy_join_sparse_vs_dense.py
+++ b/benchmarks/bench_fuzzy_join_sparse_vs_dense.py
@@ -9,9 +9,12 @@
 """
 import math
+from pathlib import Path
+from utils import default_parser, find_result, monitor
+from utils.join import evaluate, fetch_big_data
+from argparse import ArgumentParser
 import numbers
 import warnings
-from argparse import ArgumentParser
 from collections.abc import Iterable
 from time import perf_counter
 from typing import Literal
@@ -29,8 +32,6 @@
 )
 from sklearn.neighbors import NearestNeighbors
 from sklearn.preprocessing import StandardScaler
-from utils import default_parser, find_result, monitor
-from utils.join import evaluate, fetch_big_data
 
 
 def _numeric_encoding(
@@ -525,8 +526,12 @@ def benchmark(
     dataset_name: str,
     analyzer: Literal["char_wb", "char", "word"],
     ngram_range: tuple,
+    data_home: Path | str | None = None,
+    data_directory: str | None = "benchmarks_data",
 ):
-    left_table, right_table, gt = fetch_big_data(dataset_name)
+    left_table, right_table, gt = fetch_big_data(
+        dataset_name=dataset_name, data_home=data_home, data_directory=data_directory
+    )
 
     start_time = perf_counter()
     joined_fj = fuzzy_join(
diff --git a/benchmarks/utils/join.py b/benchmarks/utils/join.py
index 7329c0d0f..8c2e4b4ea 100644
--- a/benchmarks/utils/join.py
+++ b/benchmarks/utils/join.py
@@ -1,12 +1,16 @@
 import pandas as pd
+from pathlib import Path
 
 from skrub.datasets._utils import get_data_dir
 
 
-def get_local_data(dataset_name: str, data_directory: str = None):
+def get_local_data(
+    dataset_name: str,
+    data_home: Path | str | None = None,
+    data_directory: str | None = None,
+):
     """Get the path to the local datasets."""
-    if data_directory is None:
-        data_directory = get_data_dir("benchmarks_data")
+    data_directory = get_data_dir(data_directory, data_home)
     left_path = str(data_directory) + f"/left_{dataset_name}.parquet"
     right_path = str(data_directory) + f"/right_{dataset_name}.parquet"
     gt_path = str(data_directory) + f"/gt_{dataset_name}.parquet"
@@ -21,7 +25,10 @@ def get_local_data(dataset_name: str, data_directory: str = None):
 
 
 def fetch_data(
-    dataset_name: str, save: bool = True
+    dataset_name: str,
+    save: bool = True,
+    data_home: Path | str | None = None,
+    data_directory: str | None = None,
 ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
     """Fetch datasets from https://github.com/Yeye-He/Auto-Join/tree/master/autojoin-Benchmark  # noqa
 
@@ -33,6 +40,13 @@ def fetch_data(
     save: bool, default=true
         Whether to save the datasets locally.
 
+    data_home: Path or str, optional
+        The path to the root data directory.
+        By default, will point to the skrub data directory.
+
+    data_directory: str, optional
+        The name of the subdirectory in which data is stored.
+
     Returns
     -------
     left: pd.DataFrame
@@ -44,7 +58,9 @@ def fetch_data(
     gt: pd.DataFrame
         Ground truth dataset.
     """
-    left_path, right_path, gt_path, file_paths = get_local_data(dataset_name)
+    left_path, right_path, gt_path, file_paths = get_local_data(
+        dataset_name, data_home, data_directory
+    )
     if len(file_paths) == 0:
         repository = "Yeye-He/Auto-Join"
         dataset_name = dataset_name.replace(" ", "%20")
@@ -70,6 +86,8 @@ def fetch_big_data(
     dataset_name: str,
     data_type: str = "Dirty",
     save: bool = True,
+    data_home: Path | str | None = None,
+    data_directory: str | None = None,
 ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
     """Fetch datasets from https://github.com/anhaidgroup/deepmatcher/blob/master/Datasets.md  # noqa
 
@@ -85,6 +103,13 @@ def fetch_big_data(
     save: bool, default=true
         Whether to save the datasets locally.
 
+    data_home: Path or str, optional
+        The path to the root data directory.
+        By default, will point to the skrub data directory.
+
+    data_directory: str, optional
+        The name of the subdirectory in which data is stored.
+
     Returns
     -------
     left: pd.DataFrame
@@ -97,7 +122,9 @@ def fetch_big_data(
         Ground truth dataset.
     """
     link = "https://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/"
-    left_path, right_path, gt_path, file_paths = get_local_data(dataset_name)
+    left_path, right_path, gt_path, file_paths = get_local_data(
+        dataset_name, data_home, data_directory
+    )
     if len(file_paths) == 0:
         test_idx = pd.read_csv(f"{link}/{data_type}/{dataset_name}/exp_data/test.csv")
         train_idx = pd.read_csv(f"{link}/{data_type}/{dataset_name}/exp_data/train.csv")
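A minimal sketch of how the patched benchmark helpers above are now meant to be called. The dataset name and cache root below are placeholders, and fetch_big_data downloads the tables on first use:

    from pathlib import Path

    from utils.join import fetch_big_data

    # Cached under ~/skrub_data/benchmarks_data/ (the benchmark default) ...
    left, right, gt = fetch_big_data(
        "placeholder_dataset", data_directory="benchmarks_data"
    )

    # ... or under an explicit root when data_home is given.
    left, right, gt = fetch_big_data(
        "placeholder_dataset",
        data_home=Path("/tmp/bench_cache"),
        data_directory="benchmarks_data",
    )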
""" - left_path, right_path, gt_path, file_paths = get_local_data(dataset_name) + left_path, right_path, gt_path, file_paths = get_local_data( + dataset_name, data_home, data_directory + ) if len(file_paths) == 0: repository = "Yeye-He/Auto-Join" dataset_name = dataset_name.replace(" ", "%20") @@ -70,6 +86,8 @@ def fetch_big_data( dataset_name: str, data_type: str = "Dirty", save: bool = True, + data_home: Path | str | None = None, + data_directory: str | None = None, ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """Fetch datasets from https://github.com/anhaidgroup/deepmatcher/blob/master/Datasets.md # noqa @@ -85,6 +103,13 @@ def fetch_big_data( save: bool, default=true Wheter to save the datasets locally. + data_home: Path or str, optional + The path to the root data directory. + By default, will point to the skrub data directory. + + data_directory: str, optional + The name of the subdirectory in which data is stored. + Returns ------- left: pd.DataFrame @@ -97,7 +122,9 @@ def fetch_big_data( Ground truth dataset. """ link = "https://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/" - left_path, right_path, gt_path, file_paths = get_local_data(dataset_name) + left_path, right_path, gt_path, file_paths = get_local_data( + dataset_name, data_home, data_directory + ) if len(file_paths) == 0: test_idx = pd.read_csv(f"{link}/{data_type}/{dataset_name}/exp_data/test.csv") train_idx = pd.read_csv(f"{link}/{data_type}/{dataset_name}/exp_data/train.csv") diff --git a/skrub/datasets/_fetching.py b/skrub/datasets/_fetching.py index 2c3080bca..8bc9e02de 100644 --- a/skrub/datasets/_fetching.py +++ b/skrub/datasets/_fetching.py @@ -152,15 +152,16 @@ def _fetch_openml_dataset( dataset_id: int, data_directory: Path | None = None, ) -> dict[str, Any]: - """Gets a dataset from OpenML (https://www.openml.org). + """ + Gets a dataset from OpenML (https://www.openml.org). Parameters ---------- dataset_id : int The ID of the dataset to fetch. - data_directory : Path, optional - A directory to save the data to. - By default, the skrub data directory. + data_directory : pathlib.Path, optional + The directory where the dataset is stored. + By default, a subdirectory "openml" in the skrub data directory. Returns ------- @@ -176,7 +177,7 @@ def _fetch_openml_dataset( saved as a CSV file. """ if data_directory is None: - data_directory = get_data_dir() + data_directory = get_data_dir(name="openml") # Make path absolute data_directory = data_directory.resolve() @@ -242,9 +243,9 @@ def _fetch_world_bank_data( ---------- indicator_id : str The ID of the indicator's dataset to fetch. - data_directory : Path, optional - A directory to save the data to. - By default, the skrub data directory. + data_directory : pathlib.Path, optional + The directory where the dataset is stored. + By default, a subdirectory "world_bank" in the skrub data directory. Returns ------- @@ -260,7 +261,7 @@ def _fetch_world_bank_data( saved as a CSV file. """ if data_directory is None: - data_directory = get_data_dir() + data_directory = get_data_dir(name="world_bank") csv_path = (data_directory / f"{indicator_id}.csv").resolve() data_directory.mkdir(parents=True, exist_ok=True) @@ -326,8 +327,8 @@ def _fetch_figshare( figshare_id : str The ID of the dataset to fetch. data_directory : pathlib.Path, optional - A directory to save the data to. - By default, the skrub data directory. + The directory where the dataset is stored. + By default, a subdirectory "figshare" in the skrub data directory. 
     Returns
     -------
@@ -347,7 +348,8 @@ def _fetch_figshare(
         pyarrow installed to run correctly.
     """
     if data_directory is None:
-        data_directory = get_data_dir()
+        data_directory = get_data_dir(name="figshare")
+
     parquet_path = (data_directory / f"figshare_{figshare_id}.parquet").resolve()
     data_directory.mkdir(parents=True, exist_ok=True)
     url = f"https://ndownloader.figshare.com/files/{figshare_id}"
@@ -666,7 +668,7 @@ def fetch_employee_salaries(
     load_dataframe: bool = True,
     drop_linked: bool = True,
     drop_irrelevant: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches the employee salaries dataset (regression), available at https://openml.org/d/42125
 
@@ -685,6 +687,9 @@ def fetch_employee_salaries(
         Drops column "full_name", which is usually irrelevant to the
         statistical analysis.
 
+    data_directory: pathlib.Path or str, optional
+        The directory where the dataset is stored.
+
     Returns
     -------
     DatasetAll
@@ -704,7 +709,7 @@ def fetch_employee_salaries(
             "na_values": ["?"],
         },
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
     if load_dataframe:
         if drop_linked:
@@ -720,7 +725,7 @@ def fetch_employee_salaries(
 def fetch_road_safety(
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches the road safety dataset (classification), available at https://openml.org/d/42803
 
@@ -747,14 +752,14 @@ def fetch_road_safety(
             "na_values": ["?"],
         },
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
 
 
 def fetch_medical_charge(
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches the medical charge dataset (regression), available at https://openml.org/d/42720
 
@@ -786,14 +791,14 @@ def fetch_medical_charge(
             "escapechar": "\\",
         },
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
 
 
 def fetch_midwest_survey(
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches the midwest survey dataset (classification), available at https://openml.org/d/42805
 
@@ -818,14 +823,14 @@ def fetch_midwest_survey(
             "escapechar": "\\",
         },
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
 
 
 def fetch_open_payments(
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches the open payments dataset (classification), available at https://openml.org/d/42738
 
@@ -852,14 +857,14 @@ def fetch_open_payments(
             "na_values": ["?"],
         },
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
 
 
 def fetch_traffic_violations(
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches the traffic violations dataset (classification), available at https://openml.org/d/42132
 
@@ -888,14 +893,14 @@ def fetch_traffic_violations(
             "na_values": ["?"],
         },
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
 
 
 def fetch_drug_directory(
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches the drug directory dataset (classification), available at https://openml.org/d/43044
 
@@ -921,7 +926,7 @@ def fetch_drug_directory(
             "escapechar": "\\",
         },
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
 
 
@@ -929,7 +934,7 @@ def fetch_world_bank_indicator(
     indicator_id: str,
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches a dataset of an indicator from the World Bank
     open data platform.
@@ -952,7 +957,7 @@ def fetch_world_bank_indicator(
         dataset_id=indicator_id,
         target=None,
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
 
 
@@ -960,7 +965,7 @@ def fetch_figshare(
     figshare_id: str,
     *,
     load_dataframe: bool = True,
-    directory: Path | str | None = None,
+    data_directory: Path | str | None = None,
 ) -> DatasetAll | DatasetInfoOnly:
     """Fetches a table from figshare.
 
@@ -978,5 +983,5 @@ def fetch_figshare(
         dataset_id=figshare_id,
         target=None,
         load_dataframe=load_dataframe,
-        data_directory=directory,
+        data_directory=data_directory,
     )
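The user-facing effect of the rename above: every public fetcher now takes data_directory instead of directory, and by default caches under a per-source subdirectory of the skrub data home. A short sketch (the first call downloads from OpenML; the custom path is illustrative):

    from pathlib import Path

    from skrub.datasets import fetch_employee_salaries

    # Default: cached under a subdirectory of ~/skrub_data.
    dataset = fetch_employee_salaries()
    print(dataset.X.shape)

    # Same fetch, cached in an explicit directory instead.
    dataset = fetch_employee_salaries(data_directory=Path("/tmp/skrub_cache"))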
diff --git a/skrub/datasets/_utils.py b/skrub/datasets/_utils.py
index aa3e57946..7b9a57ab3 100644
--- a/skrub/datasets/_utils.py
+++ b/skrub/datasets/_utils.py
@@ -1,8 +1,41 @@
-import os
 from pathlib import Path
 
 
-def get_data_dir(name: str | None = None) -> Path:
+def get_data_home(data_home: Path | str | None = None) -> Path:
+    """Returns the path of the skrub data directory.
+
+    This folder is used by some large dataset loaders to avoid downloading the
+    data several times.
+
+    By default the data directory is set to a folder named 'skrub_data' in the
+    user home folder.
+
+    Alternatively, it can be set programmatically by giving an explicit folder
+    path. The '~' symbol is expanded to the user home folder.
+
+    If the folder does not already exist, it is automatically created.
+
+    Parameters
+    ----------
+    data_home : pathlib.Path or string, optional
+        The path to the skrub data directory. If `None`, the default path
+        is `~/skrub_data`.
+
+    Returns
+    -------
+    data_home : pathlib.Path
+        The validated path to the skrub data directory.
+    """
+    if data_home is None:
+        data_home = Path("~").expanduser() / "skrub_data"
+    else:
+        data_home = Path(data_home)
+    data_home = data_home.resolve()
+    data_home.mkdir(parents=True, exist_ok=True)
+    return data_home
+
+
+def get_data_dir(name: str | None = None, data_home: Path | str | None = None) -> Path:
     """
     Returns the directory in which skrub looks for data.
 
@@ -11,13 +44,13 @@
     Parameters
     ----------
-    name: str, optional
+    name : str, optional
         Subdirectory name.
         If omitted, the root data directory is returned.
+    data_home : pathlib.Path or str, optional
+        The path to the skrub data directory. If `None`, the default path
+        is `~/skrub_data`.
     """
-    # Note: we stick to os.path instead of pathlib.Path because
-    # it's easier to test, and the functionality is the same.
-    module_path = Path(os.path.dirname(__file__)).resolve()
-    data_dir = module_path / "data"
+    data_dir = get_data_home(data_home)
     if name is not None:
         data_dir = data_dir / name
     return data_dir
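A quick sketch of how the two helpers above resolve paths (outputs shown for a POSIX home directory; the explicit root is illustrative):

    from skrub.datasets._utils import get_data_dir, get_data_home

    print(get_data_home())              # e.g. /home/user/skrub_data (created if missing)
    print(get_data_dir(name="openml"))  # e.g. /home/user/skrub_data/openml
    print(get_data_home(data_home="/tmp/skrub_data"))  # explicit root, also created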
- """ - from skrub.datasets._utils import get_data_dir +from skrub.datasets._utils import get_data_dir, get_data_home - expected_return_value_default = Path("/user/directory/data").absolute() - mock_os_path_dirname.return_value = expected_return_value_default.parent - assert get_data_dir() == expected_return_value_default +@pytest.mark.parametrize("data_home_type", ["string", "path"]) +def test_get_data_dir(data_home_type): + with tempfile.TemporaryDirectory() as target_dir: + tmp_dir = target_dir if data_home_type == "string" else Path(target_dir) - expected_return_value_custom = expected_return_value_default / "tests" - assert get_data_dir("tests") == expected_return_value_custom + # with a pre-existing folder + assert Path(target_dir).exists() + data_home = get_data_home(data_home=tmp_dir) + assert data_home == Path(target_dir).resolve() + + assert data_home.exists() + + assert ( + get_data_dir(name="tests", data_home=tmp_dir) + == data_home.resolve() / "tests" + ) + + # if the folder is missing it will be created + shutil.rmtree(tmp_dir) + assert not Path(target_dir).exists() + data_home = get_data_home(data_home=tmp_dir) + assert data_home.exists() + + +def test_get_data_home_default(): + """Test function for ``get_data_home()`` with default `data_home`.""" + # We should take care of not deleting the folder if our user + # already cached some data + user_path = Path("~").expanduser() / "skrub_data" + is_already_existing = user_path.exists() + + data_home = get_data_home(data_home=None) + assert data_home == user_path + assert data_home.exists() + + if not is_already_existing: + # Clear the folder if it was not already existing. + shutil.rmtree(user_path) + assert not user_path.exists()