Skip to content

Commit

Permalink
Fix ConnectionError for gated datasets and unauthenticated users (#7110)
Browse files Browse the repository at this point in the history
* Test load_dataset raises DatasetNotFoundError for unauthenticated user

* Raise DatasetNotFoundError for gated and unauthenticated

* Rename function
  • Loading branch information
albertvillanova authored Aug 20, 2024
1 parent fb8ae4d commit 90b1d94
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
relative_to_absolute_path,
url_or_path_join,
)
from .utils.hub import hf_dataset_url
from .utils.hub import check_auth, hf_dataset_url
from .utils.info_utils import VerificationMode, is_small_dataset
from .utils.logging import get_logger
from .utils.metadata import MetadataConfigs
Expand Down Expand Up @@ -1585,19 +1585,24 @@ def dataset_module_factory(
requests.exceptions.ConnectionError,
) as e:
raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
except GatedRepoError as e:
message = f"Dataset '{path}' is a gated dataset on the Hub."
if "401 Client Error" in str(e):
message += " You must be authenticated to access it."
elif "403 Client Error" in str(e):
message += f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
raise DatasetNotFoundError(message) from e
except RevisionNotFoundError as e:
raise DatasetNotFoundError(
f"Revision '{revision}' doesn't exist for dataset '{path}' on the Hub."
) from e
except RepositoryNotFoundError as e:
raise DatasetNotFoundError(f"Dataset '{path}' doesn't exist on the Hub or cannot be accessed.") from e
if dataset_info.gated:
try:
check_auth(hf_api, repo_id=path, token=download_config.token)
except GatedRepoError as e:
message = f"Dataset '{path}' is a gated dataset on the Hub."
if "401 Client Error" in str(e):
message += " You must be authenticated to access it."
elif "403 Client Error" in str(e):
message += (
f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
)
raise DatasetNotFoundError(message) from e

if filename in [sibling.rfilename for sibling in dataset_info.siblings]: # contains a dataset script
fs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
Expand Down
8 changes: 8 additions & 0 deletions src/datasets/utils/hub.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from functools import partial

from huggingface_hub import hf_hub_url
from huggingface_hub.utils import get_session, hf_raise_for_status


hf_dataset_url = partial(hf_hub_url, repo_type="dataset")


def check_auth(hf_api, repo_id, token=None):
headers = hf_api._build_hf_headers(token=token)
path = f"{hf_api.endpoint}/api/datasets/{repo_id}/auth-check"
r = get_session().get(path, headers=headers)
hf_raise_for_status(r)
33 changes: 33 additions & 0 deletions tests/fixtures/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import requests
from huggingface_hub.hf_api import HfApi, RepositoryNotFoundError
from huggingface_hub.utils import hf_raise_for_status


CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__"
Expand Down Expand Up @@ -74,6 +75,38 @@ def _temporary_repo(repo_id: Optional[str] = None):
return _temporary_repo


@pytest.fixture(scope="session")
def _hf_gated_dataset_repo_txt_data(hf_api: HfApi, hf_token, text_file_content):
repo_name = f"repo_txt_data-{int(time.time() * 10e6)}"
repo_id = f"{CI_HUB_USER}/{repo_name}"
hf_api.create_repo(repo_id, token=hf_token, repo_type="dataset")
hf_api.upload_file(
token=hf_token,
path_or_fileobj=text_file_content.encode(),
path_in_repo="data/text_data.txt",
repo_id=repo_id,
repo_type="dataset",
)
path = f"{hf_api.endpoint}/api/datasets/{repo_id}/settings"
repo_settings = {"gated": "auto"}
r = requests.put(
path,
headers={"authorization": f"Bearer {hf_token}"},
json=repo_settings,
)
hf_raise_for_status(r)
yield repo_id
try:
hf_api.delete_repo(repo_id, token=hf_token, repo_type="dataset")
except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error
pass


@pytest.fixture()
def hf_gated_dataset_repo_txt_data(_hf_gated_dataset_repo_txt_data, ci_hub_config, ci_hfh_hf_hub_url):
return _hf_gated_dataset_repo_txt_data


@pytest.fixture(scope="session")
def hf_private_dataset_repo_txt_data_(hf_api: HfApi, hf_token, text_file_content):
repo_name = f"repo_txt_data-{int(time.time() * 10e6)}"
Expand Down
13 changes: 13 additions & 0 deletions tests/test_upstream_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from datasets.config import METADATA_CONFIGS_FIELD
from datasets.data_files import get_data_patterns
from datasets.exceptions import DatasetNotFoundError
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
FolderBasedBuilder,
FolderBasedBuilderConfig,
Expand Down Expand Up @@ -953,3 +954,15 @@ def test_get_data_patterns(self, temporary_repo, tmp_path):
assert data_file_patterns == {
"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]
}

@pytest.mark.parametrize("dataset", ["gated", "private"])
def test_load_dataset_raises_for_unauthenticated_user(
self, dataset, hf_gated_dataset_repo_txt_data, hf_private_dataset_repo_txt_data
):
dataset_ids = {
"gated": hf_gated_dataset_repo_txt_data,
"private": hf_private_dataset_repo_txt_data,
}
dataset_id = dataset_ids[dataset]
with pytest.raises(DatasetNotFoundError):
_ = load_dataset(dataset_id, token=False)

0 comments on commit 90b1d94

Please sign in to comment.