Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move default location of data to the user home folder #652

Merged
merged 40 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
3e00a8c
Change default location of data_home
flefebv Jul 18, 2023
b786dc9
Removed the `clear_data_home` function
flefebv Jul 18, 2023
9feaef7
Add changes to CHANGES.rst
flefebv Jul 18, 2023
4ea3570
Apply suggestions from code review
flefebv Jul 18, 2023
5b22b55
update _utils.py and test_utils.py to use pathlib instead os.path
flefebv Jul 19, 2023
3f6b8b1
Propagate changes
flefebv Jul 19, 2023
d3254fd
Update skrub/datasets/tests/test_utils.py
flefebv Jul 19, 2023
2e63488
Update skrub/datasets/_utils.py
flefebv Jul 19, 2023
16189e5
Update skrub/datasets/_utils.py
flefebv Jul 19, 2023
78ba357
Update skrub/datasets/tests/test_utils.py
flefebv Jul 19, 2023
c686850
Update skrub/datasets/_utils.py
flefebv Jul 19, 2023
01e8227
Update skrub/datasets/_utils.py
flefebv Jul 19, 2023
eeeb2ca
Update skrub/datasets/_utils.py
flefebv Jul 19, 2023
c0251bc
Update skrub/datasets/_utils.py
flefebv Jul 19, 2023
ccb9f56
Update skrub/datasets/tests/test_utils.py
flefebv Jul 19, 2023
79bac36
Update skrub/datasets/tests/test_utils.py
flefebv Jul 19, 2023
56bbb15
allow data_home to be a string
flefebv Jul 19, 2023
c538354
Improved test for `get_data_home`
flefebv Jul 19, 2023
45afc07
Apply suggestions from code review
LilianBoulard Jul 19, 2023
7fd24e8
Simplified test for `get_data_home`
flefebv Jul 19, 2023
435b2e7
change `data_home` for `data_directory` when needed
flefebv Jul 19, 2023
9a0722d
Merge branch 'main' of github.com:flefebv/skrub into data_home
flefebv Jul 19, 2023
0b031b2
Fixes bug with user directory
flefebv Jul 19, 2023
64d46c1
Correct type hint
flefebv Jul 19, 2023
19c661f
Update skrub/datasets/_utils.py
GaelVaroquaux Jul 20, 2023
7886313
Apply suggestions from code review
LilianBoulard Jul 20, 2023
3fafe8d
Merge branch 'main' into pr_652
GaelVaroquaux Jul 21, 2023
5764738
make black happy
GaelVaroquaux Jul 21, 2023
eaa8c7d
Make test work on my machine
GaelVaroquaux Jul 21, 2023
dee3b98
Refactor tests
GaelVaroquaux Jul 21, 2023
34e29ae
make black happy
GaelVaroquaux Jul 21, 2023
e987fc0
make black happy
GaelVaroquaux Jul 21, 2023
f81593b
Tests working on all platforms
GaelVaroquaux Jul 21, 2023
5ce6e68
Tests more robust
GaelVaroquaux Jul 21, 2023
9f640dc
Last commit was an error
GaelVaroquaux Jul 21, 2023
5d67623
Try to make tests more robust
GaelVaroquaux Jul 21, 2023
feab5f9
tired
GaelVaroquaux Jul 21, 2023
19471d8
wired
GaelVaroquaux Jul 21, 2023
8f9559f
Try something
GaelVaroquaux Jul 21, 2023
c7dbeab
Cleaner code
GaelVaroquaux Jul 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ Minor changes
* Moved the default storage location of data to the user's home folder.
:pr:`652` by :user:`Felix Lefebvre <flefebv>`

* Fixed bug when using :class:`TableVectorizer`'s `transform` method on
categorical columns with missing values.
:pr:`644` by :user:`Leo Grinsztajn <LeoGrin>`

* :class:`TableVectorizer` never outputs a sparse matrix by default. This can be changed by
increasing the `sparse_threshold` parameter. :pr:`646` by :user:`Leo Grinsztajn <LeoGrin>`

Before skrub: dirty_cat
========================

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/bench_fuzzy_join_sparse_vs_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,8 @@ def benchmark(
dataset_name: str,
analyzer: Literal["char_wb", "char", "word"],
ngram_range: tuple,
data_home: Path | str = None,
data_directory: str = "benchmarks_data",
data_home: Path | str | None = None,
data_directory: str | None = "benchmarks_data",
):
left_table, right_table, gt = fetch_big_data(
dataset_name=dataset_name, data_home=data_home, data_directory=data_directory
Expand Down
14 changes: 7 additions & 7 deletions benchmarks/utils/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def get_local_data(
dataset_name: str, data_home: Path | str = None, data_directory: str = None
dataset_name: str, data_home: Path | str | None = None, data_directory: str | None = None
):
"""Get the path to the local datasets."""
data_directory = get_data_dir(data_directory, data_home)
Expand All @@ -24,8 +24,8 @@ def get_local_data(
def fetch_data(
dataset_name: str,
save: bool = True,
data_home: Path | str = None,
data_directory: str = None,
data_home: Path | str | None = None,
data_directory: str | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""Fetch datasets from https://github.com/Yeye-He/Auto-Join/tree/master/autojoin-Benchmark

Expand All @@ -41,7 +41,7 @@ def fetch_data(
The path to the root data directory.
By default, will point to the skrub data directory.

data_directory: str, default=None
data_directory: str, optional
The name of the subdirectory in which data is stored.

Returns
Expand Down Expand Up @@ -83,8 +83,8 @@ def fetch_big_data(
dataset_name: str,
data_type: str = "Dirty",
save: bool = True,
data_home: Path | str = None,
data_directory: str = None,
data_home: Path | str | None = None,
data_directory: str | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""Fetch datasets from https://github.com/anhaidgroup/deepmatcher/blob/master/Datasets.md

Expand All @@ -104,7 +104,7 @@ def fetch_big_data(
The path to the root data directory.
By default, will point to the skrub data directory.

data_directory: str, default=None
data_directory: str, optional
The name of the subdirectory in which data is stored.

Returns
Expand Down
92 changes: 50 additions & 42 deletions skrub/datasets/_fetching.py
LilianBoulard marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,18 @@ class DatasetInfoOnly:

def _fetch_openml_dataset(
dataset_id: int,
data_home: Path | str | None = None,
data_directory: Path | None = None,
) -> dict[str, Any]:
"""Gets a dataset from OpenML (https://www.openml.org).
"""
Gets a dataset from OpenML (https://www.openml.org).

Parameters
----------
dataset_id : int
The ID of the dataset to fetch.
data_home : Path or str, optional
The path to the root data directory.
By default, the skrub data directory.
data_directory : Path or str, optional
LilianBoulard marked this conversation as resolved.
Show resolved Hide resolved
The directory where the dataset is stored.
By default, a subdirectory "openml" in the skrub data directory.

Returns
-------
Expand All @@ -164,7 +165,8 @@ def _fetch_openml_dataset(
The local path leading to the dataset,
saved as a CSV file.
"""
data_directory = get_data_dir("openml", data_home)
if data_directory is None:
data_directory = get_data_dir(name="openml")

# Make path absolute
data_directory = data_directory.resolve()
Expand Down Expand Up @@ -222,17 +224,17 @@ def _fetch_openml_dataset(

def _fetch_world_bank_data(
indicator_id: str,
data_home: Path | str | None = None,
data_directory: Path | None = None,
) -> dict[str, Any]:
"""Gets a dataset from World Bank open data platform (https://data.worldbank.org/).

Parameters
----------
indicator_id : str
The ID of the indicator's dataset to fetch.
data_home : Path or str, optional
The path to the root data directory.
By default, the skrub data directory.
data_directory : Path or str, optional
LilianBoulard marked this conversation as resolved.
Show resolved Hide resolved
The directory where the dataset is stored.
By default, a subdirectory "world_bank" in the skrub data directory.

Returns
-------
Expand All @@ -247,7 +249,8 @@ def _fetch_world_bank_data(
The local path leading to the dataset,
saved as a CSV file.
"""
data_directory = get_data_dir("world_bank", data_home)
if data_directory is None:
data_directory = get_data_dir(name="world_bank")

csv_path = (data_directory / f"{indicator_id}.csv").resolve()
data_directory.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -304,17 +307,17 @@ def _fetch_world_bank_data(

def _fetch_figshare(
figshare_id: str,
data_home: Path | str | None = None,
data_directory: Path | None = None,
) -> dict[str, Any]:
"""Fetch a dataset from figshare using the download ID number.

Parameters
----------
figshare_id : str
The ID of the dataset to fetch.
data_home : Path or str, optional
The path to the root data directory.
By default, the skrub data directory.
data_directory : Path or str, optional
LilianBoulard marked this conversation as resolved.
Show resolved Hide resolved
The directory where the dataset is stored.
By default, a subdirectory "figshare" in the skrub data directory.

Returns
-------
Expand All @@ -333,7 +336,9 @@ def _fetch_figshare(
The files are read and returned in parquet format, this function needs
pyarrow installed to run correctly.
"""
data_directory = get_data_dir("figshare", data_home)
if data_directory is None:
data_directory = get_data_dir(name="figshare")

parquet_path = (data_directory / f"figshare_{figshare_id}.parquet").resolve()
data_directory.mkdir(parents=True, exist_ok=True)
url = f"https://ndownloader.figshare.com/files/{figshare_id}"
Expand Down Expand Up @@ -371,6 +376,7 @@ def _fetch_figshare(
"pyarrow", extra="pyarrow is required for parquet support."
)
from pyarrow.parquet import ParquetFile

try:
filehandle, _ = urllib.request.urlretrieve(url)
df = ParquetFile(filehandle)
Expand Down Expand Up @@ -555,7 +561,7 @@ def _fetch_dataset_as_dataclass(
dataset_id: int | str,
target: str | None,
load_dataframe: bool,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
read_csv_kwargs: dict | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches a dataset from a source, and returns it as a dataclass.
Expand All @@ -567,7 +573,7 @@ def _fetch_dataset_as_dataclass(
pass `load_dataframe=False`.

To save/load the dataset to/from a specific directory,
pass `data_home`. If `None`, uses the default skrub
pass `data_directory`. If `None`, uses the default skrub
data directory.

If the dataset doesn't have a target (unsupervised learning or inapplicable),
Expand All @@ -581,12 +587,15 @@ def _fetch_dataset_as_dataclass(
:obj:`DatasetInfoOnly`
If `load_dataframe=False`
"""
if isinstance(data_directory, str):
data_directory = Path(data_directory)

if source == "openml":
info = _fetch_openml_dataset(dataset_id, data_home)
info = _fetch_openml_dataset(dataset_id, data_directory)
elif source == "world_bank":
info = _fetch_world_bank_data(dataset_id, data_home)
info = _fetch_world_bank_data(dataset_id, data_directory)
elif source == "figshare":
info = _fetch_figshare(dataset_id, data_home)
info = _fetch_figshare(dataset_id, data_directory)
else:
raise ValueError(f"Unknown source {source!r}")

Expand Down Expand Up @@ -635,7 +644,7 @@ def fetch_employee_salaries(
load_dataframe: bool = True,
drop_linked: bool = True,
drop_irrelevant: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches the employee salaries dataset (regression), available at https://openml.org/d/42125

Expand All @@ -654,9 +663,8 @@ def fetch_employee_salaries(
Drops column "full_name", which is usually irrelevant to the
statistical analysis.

data_home: Path or str, default=None
The path to the root data directory.
By default, will point to the skrub data directory.
data_directory: Path or str, optional
LilianBoulard marked this conversation as resolved.
Show resolved Hide resolved
The directory where the dataset is stored.

Returns
-------
Expand All @@ -677,7 +685,7 @@ def fetch_employee_salaries(
"na_values": ["?"],
},
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)
if load_dataframe:
if drop_linked:
Expand All @@ -693,7 +701,7 @@ def fetch_employee_salaries(
def fetch_road_safety(
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches the road safety dataset (classification), available at https://openml.org/d/42803

Expand All @@ -720,14 +728,14 @@ def fetch_road_safety(
"na_values": ["?"],
},
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)


def fetch_medical_charge(
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches the medical charge dataset (regression), available at https://openml.org/d/42720

Expand Down Expand Up @@ -759,14 +767,14 @@ def fetch_medical_charge(
"escapechar": "\\",
},
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)


def fetch_midwest_survey(
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches the midwest survey dataset (classification), available at https://openml.org/d/42805

Expand All @@ -791,14 +799,14 @@ def fetch_midwest_survey(
"escapechar": "\\",
},
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)


def fetch_open_payments(
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches the open payments dataset (classification), available at https://openml.org/d/42738

Expand All @@ -825,14 +833,14 @@ def fetch_open_payments(
"na_values": ["?"],
},
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)


def fetch_traffic_violations(
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches the traffic violations dataset (classification), available at https://openml.org/d/42132

Expand Down Expand Up @@ -861,14 +869,14 @@ def fetch_traffic_violations(
"na_values": ["?"],
},
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)


def fetch_drug_directory(
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches the drug directory dataset (classification), available at https://openml.org/d/43044

Expand All @@ -894,15 +902,15 @@ def fetch_drug_directory(
"escapechar": "\\",
},
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)


def fetch_world_bank_indicator(
indicator_id: str,
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches a dataset of an indicator from the World Bank open data platform.

Expand All @@ -925,15 +933,15 @@ def fetch_world_bank_indicator(
dataset_id=indicator_id,
target=None,
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)


def fetch_figshare(
figshare_id: str,
*,
load_dataframe: bool = True,
data_home: Path | str | None = None,
data_directory: Path | str | None = None,
) -> DatasetAll | DatasetInfoOnly:
"""Fetches a table from figshare.

Expand All @@ -951,5 +959,5 @@ def fetch_figshare(
dataset_id=figshare_id,
target=None,
load_dataframe=load_dataframe,
data_home=data_home,
data_directory=data_directory,
)
Loading