diff --git a/skrub/datasets/_fetching.py b/skrub/datasets/_fetching.py
index ed8ee9d9e..6195d30b1 100644
--- a/skrub/datasets/_fetching.py
+++ b/skrub/datasets/_fetching.py
@@ -24,6 +24,7 @@ import pandas as pd
 from sklearn import __version__ as sklearn_version
 from sklearn.datasets import fetch_openml
+from sklearn.datasets._base import _sha256
 
 from skrub._utils import import_optional_dependency, parse_version
 from skrub.datasets._utils import get_data_dir
 
@@ -58,6 +59,20 @@
 TRAFFIC_VIOLATIONS_ID: int = 42132
 DRUG_DIRECTORY_ID: int = 43044
 
+# A dictionary storing the sha256 hashes of the figshare files
+figshare_id_to_hash = {
+    39142985: "47d73381ef72b050002a8642194c6718a4954ec9e6c556f4c4ddc6ed84ceec92",
+    39149066: "e479cf9741a90c40401697e7fa54409e3b9cfa09f27502877382e64e86fbfcd0",
+    39149069: "7b0dcdb15d3aeecba6022c929665ee064f6fb4b8b94186a6e89b6fbc781b3775",
+    39149072: "4f58f15168bb8a6cc8b152bd48995bc7d1a4d4d89a9e22d87aa51ccf30118122",
+    39149075: "7037603362af1d4bf73551d50644c0957cb91d2b4892e75413f57f415962029a",
+    39254360: "531130c714ba6ee9902108d4010f42388aa9c0b3167d124cd57e2c632df3e05a",
+    39266300: "37b23b2c37a1f7ff906bc7951cbed4be15d8417dad0762092282f7b491cf8c21",
+    39266678: "4e041322985e078de8b08acfd44b93a5ce347c1e501e9d869651e753de747ba1",
+    40019230: "4d43fed75dba1e59a5587bf31c1addf2647a1f15ebea66e93177ccda41e18f2f",
+    40019788: "67ae86496c8a08c6cc352f573160a094f605e7e0da022eb91c603abb7edf3747",
+}
+
 
 @dataclass(unsafe_hash=True)
 class Details:
@@ -373,6 +388,19 @@ def _fetch_figshare(
     try:
         filehandle, _ = urllib.request.urlretrieve(url)
+
+        # Verify the SHA256 checksum of the downloaded file
+        checksum = _sha256(filehandle)
+        if figshare_id in figshare_id_to_hash:
+            expected_checksum = figshare_id_to_hash[figshare_id]
+            if checksum != expected_checksum:
+                raise OSError(
+                    f"{filehandle!r} SHA256 checksum differs from "
+                    f"expected ({checksum} != {expected_checksum}); "
+                    f"the file is probably corrupted. Please try again. "
+                    f"If the error persists, please open an issue on GitHub."
+                )
+
         df = ParquetFile(filehandle)
         record = df.iter_batches(
             batch_size=1_000_000,
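
For reference, below is a minimal standalone sketch of the verification flow this patch adds. It assumes the standard figshare download endpoint (`https://ndownloader.figshare.com/files/<id>`), and the hypothetical `sha256_of_file` helper stands in for sklearn's private `_sha256` (same contract: file path in, hex digest out), reimplemented with `hashlib` since the private helper is not public API. The id and hash are copied from the `figshare_id_to_hash` table in the diff.

```python
# A sketch of the checksum flow under the assumptions stated above; not
# the exact skrub implementation.
import hashlib
import urllib.request


def sha256_of_file(path: str, chunk_size: int = 8192) -> str:
    """Hash the file in fixed-size chunks so large files never sit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


# Id and expected hash copied from figshare_id_to_hash in the diff.
figshare_id = 39142985
expected_checksum = "47d73381ef72b050002a8642194c6718a4954ec9e6c556f4c4ddc6ed84ceec92"
url = f"https://ndownloader.figshare.com/files/{figshare_id}"  # assumed endpoint

filehandle, _ = urllib.request.urlretrieve(url)
checksum = sha256_of_file(filehandle)
if checksum != expected_checksum:
    raise OSError(
        f"{filehandle!r} SHA256 checksum differs from expected "
        f"({checksum} != {expected_checksum}); the file is probably corrupted."
    )
```

Hashing in chunks rather than hashing `f.read()` in one call keeps memory use constant, which matters here because the fetched parquet files are large enough to be read back in million-row batches.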