Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for catalog entries that point to public data #51

Merged
merged 11 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include src/squirrel_datasets_core/py.typed
2 changes: 1 addition & 1 deletion src/squirrel_datasets_core/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.10"
__version__ = "0.1.11"
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

import numpy as np
import requests
from PIL import Image

from PIL import Image, UnidentifiedImageError
from squirrel.driver import IterDriver
from squirrel.iterstream.source import IterableSource

Expand Down Expand Up @@ -79,6 +78,9 @@ def _map_fn(record: t.Dict[str, str]) -> t.Dict[str, t.Union[str, np.ndarray]]:
except (urllib.error.HTTPError, urllib.error.URLError) as e:
logger.info(f"Cannot access url {url} due to Error {e}")
record["error"] = True
except UnidentifiedImageError as e:
logger.info(f"Cannot open image at {url} due to Error {e}")
record["error"] = True

return record

Expand Down
3 changes: 3 additions & 0 deletions src/squirrel_datasets_core/py.typed
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
This marker file declares that the package supports type checking. For details, you can refer to:
- PEP561: https://www.python.org/dev/peps/pep-0561/
- mypy docs: https://mypy.readthedocs.io/en/stable/installed_packages.html
7 changes: 7 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
"""

import pytest
from squirrel.catalog import Catalog


@pytest.fixture()
def unit_test_fixture() -> str:
"""Fixture that is only used in the unit test scope."""
return "unit test string"


@pytest.fixture()
def plugin_catalog() -> Catalog:
"""Create catalog from plugins."""
return Catalog.from_plugins()
16 changes: 15 additions & 1 deletion test/test_datasets/test_allenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import Iterator, Tuple

import pytest
from squirrel_datasets_core.datasets.allenai_c4 import C4DatasetDriver
from squirrel.catalog import Catalog

from mock_utils import create_random_dict, save_gzip
from squirrel_datasets_core.datasets.allenai_c4 import C4DatasetDriver
from squirrel_datasets_core.datasets.allenai_c4.constants import C4_MULTILINGUAL_CONFIG


def mock_allenai_data(tmp_path: Path) -> Iterator[Tuple[str, str, int, Path]]:
Expand Down Expand Up @@ -52,3 +54,15 @@ def test_allenai(tmp_path: Path) -> None:
)
with pytest.raises(ValueError):
driver.select("en").get_iter().collect()


@pytest.mark.skip(reason="Dataset is on public storage.")
def test_allenai_public_data(plugin_catalog: Catalog) -> None:
"""Test loading a single language from the C4 corpus."""
driver: C4DatasetDriver = plugin_catalog["c4"].get_driver()
assert sorted(driver.available_languages) == sorted(C4_MULTILINGUAL_CONFIG.keys())

TAKE = 10
it = driver.select("af", "valid").get_iter(shuffle_key_buffer=1, shuffle_item_buffer=1, prefetch_buffer=1)
data = it.take(TAKE).tqdm().collect()
assert len(data) == TAKE
16 changes: 15 additions & 1 deletion test/test_datasets/test_cc100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import numpy as np
import pytest
from squirrel_datasets_core.datasets.cc100 import CC100Driver
from squirrel.catalog import Catalog

from mock_utils import create_random_str, save_xz
from squirrel_datasets_core.datasets.cc100 import CC100Driver
from squirrel_datasets_core.datasets.cc100.constants import CC_100_CONFIG


def mock_cc100_data(samples: int, tmp_path: Path) -> Path:
Expand Down Expand Up @@ -47,3 +49,15 @@ def test_cc100(tmp_path: Path) -> None:

with pytest.raises(KeyError):
driver.select("en").get_iter().collect()


@pytest.mark.skip(reason="Dataset is on public storage.")
def test_cc100_public_data(plugin_catalog: Catalog) -> None:
"""Test loading a single language from the CC100 corpus."""
driver: CC100Driver = plugin_catalog["cc100"].get_driver()
assert sorted(driver.available_languages) == sorted(CC_100_CONFIG.keys())

TAKE = 10
it = driver.select("as").get_iter(shuffle_key_buffer=1, shuffle_item_buffer=1, prefetch_buffer=1)
data = it.take(TAKE).tqdm().collect()
assert len(data) == TAKE
15 changes: 14 additions & 1 deletion test/test_datasets/test_conceptual_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from unittest.mock import patch

import numpy as np
import pytest
import requests
from PIL import Image
from squirrel_datasets_core.datasets.conceptual_captions.driver import CC12MDriver
from squirrel.catalog import Catalog

from mock_utils import create_random_str
from squirrel_datasets_core.datasets.conceptual_captions.driver import CC12MDriver

SHAPE = (10, 10, 3)

Expand Down Expand Up @@ -87,3 +89,14 @@ def test_conceptual_captions_driver() -> None:
assert sample["url"] is not None
assert not sample["error"]
assert sample["image"].shape == SHAPE


@pytest.mark.skip(reason="Dataset is on public storage.")
def test_conceptual_captions_public_data(plugin_catalog: Catalog) -> None:
"""Test the conceptual captions loader"""
driver: CC12MDriver = plugin_catalog["conceptual-captions-12m"].get_driver()
TAKE = 10
it = driver.get_iter(shuffle_key_buffer=1, shuffle_item_buffer=1, prefetch_buffer=1).take(TAKE)
data = it.collect()
assert sorted(data[0].keys()) == sorted(["url", "error", "image", "caption"])
assert len(data) == TAKE
39 changes: 39 additions & 0 deletions test/test_datasets/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Tuple

import pytest
from squirrel.catalog import Catalog

from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver

TAKE = 10


@pytest.mark.skip(reason="Dataset is on public storage.")
def test_mnist_public_data(plugin_catalog: Catalog) -> None:
"""Test loading MNIST hosted by Hugging Face"""
driver: HuggingfaceDriver = plugin_catalog["mnist"].get_driver()
data = driver.get_iter(split="train").take(TAKE).collect()
assert len(data) == TAKE


@pytest.mark.parametrize("cifar_set", ["cifar10", "cifar100"])
@pytest.mark.parametrize("split", ["train", "test"])
@pytest.mark.skip(reason="Dataset is on public storage.")
def test_cifar_public_data(plugin_catalog: Catalog, cifar_set: str, split: str) -> None:
"""Test loading CIFAR10 hosted by Hugging Face."""
# version==1 is huggingface
driver: HuggingfaceDriver = plugin_catalog[(cifar_set, 1)].get_driver()
data = driver.get_iter(split=split).take(TAKE).collect()
assert len(data) == TAKE


@pytest.mark.parametrize(
"catalog_key", [("wikitext-103-raw", 1), ("wikitext-103", 1), ("wikitext-2-raw", 1), ("wikitext-2", 1)]
)
@pytest.mark.parametrize("split", ["train", "test"])
@pytest.mark.skip(reason="Dataset is on public storage.")
def test_wikitext_public_data(plugin_catalog: Catalog, catalog_key: Tuple[str, int], split: str) -> None:
"""Test loading wikitext datasets hosted by Hugging Face."""
driver: HuggingfaceDriver = plugin_catalog[catalog_key].get_driver()
data = driver.get_iter(split=split).take(TAKE).collect()
assert len(data) == TAKE
34 changes: 34 additions & 0 deletions test/test_datasets/test_torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
from squirrel.catalog import Catalog

TAKE = 1


@pytest.mark.skip(reason="Dataset is on public storage.")
def test_caltech_public_data(plugin_catalog: Catalog) -> None:
"""Test loading Caltech101 via torchvision."""
try:
plugin_catalog["caltech"].get_driver().get_iter().take(TAKE).join()
except RuntimeError as e:
if str(e).startswith("The daily quota of the file 101_ObjectCategories.tar.gz is exceeded"):
pytest.skip(msg=f"Failed to download data due to quota limit. original error: {str(e)}")
else:
raise e


@pytest.mark.parametrize("cifar_set", ["cifar10", "cifar100"])
@pytest.mark.parametrize("split", ["train", "test"])
@pytest.mark.skip(reason="Dataset is on public storage.")
def test_cifar_public_data(plugin_catalog: Catalog, cifar_set: str, split: str) -> None:
"""Test loading CIFAR-10 via torchvision."""
# version==2 is torchvision
version = 2
# torchvision.dataset.CIFAR10 and CIFAR100 use `train: bool` instead of `split: str`
plugin_catalog[cifar_set][version].get_driver().get_iter(train=(split == "train")).take(TAKE).join()


@pytest.mark.parametrize("split", ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"])
@pytest.mark.skip(reason="Dataset is on public storage.")
def test_emnist_public_data(plugin_catalog: Catalog, split: str) -> None:
"""Test loading EMNIST via torchvision."""
plugin_catalog["emnist"].get_driver().get_iter(split=split).take(TAKE).join()