Add tests for drivers (#42)
* Use recommended tmp_path instead of tmpdir (see the sketch below this list)

* Add camvid test

* Test datascience bowl

* Skip type checking for coverage report

* Add test for allenai

* Fix dependabot alert

* Add test for bdd100k

* Add test for cc100

* Add test for conceptual captions

* Add test for casting quality

* Add test for monthly german tweets

* Add imagenet test

* Fix conceptual captions dimensions

* Add mock to requirements

* Update test to work with unittest.mock directly

* Remove mock from requirements

* Correct mistakes in requirements.txt

* Update test/test_datasets/mock_utils.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update test/test_datasets/mock_utils.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update test/test_datasets/test_cc100.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update test/test_datasets/test_conceptual_captions.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update test/test_datasets/test_conceptual_captions.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update test/test_datasets/test_imagenet.py

Co-authored-by: Alp Arıbal <[email protected]>

* Integrate feedback from Alp

* Check all examples instead of only one

* Update test/test_datasets/test_allenai.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update test/test_datasets/test_allenai.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update test/test_datasets/test_conceptual_captions.py

Co-authored-by: Alp Arıbal <[email protected]>

* Update src/squirrel_datasets_core/datasets/allenai_c4/allenai_c4_multilingual.py

Co-authored-by: Alp Arıbal <[email protected]>

* Expect value error instead of runtime error

* Fix typehints

* Switch requirements back to Python 3.8

* Bump version

Co-authored-by: Alp Arıbal <[email protected]>
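
On the first bullet: pytest's tmp_path fixture hands each test a fresh pathlib.Path directory, replacing the legacy tmpdir fixture (a py.path.local object), and therefore composes directly with the Path-based helpers added in test/test_datasets/mock_utils.py below. A minimal sketch of the pattern, not taken from this commit's diff:

from pathlib import Path


def test_write_and_read_back(tmp_path: Path) -> None:
    # tmp_path is a unique, empty directory that pytest creates for this test
    target = tmp_path / "sample.txt"
    target.write_text("hello")
    assert target.read_text() == "hello"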
Winfried Lötzsch and AlpAribal committed Aug 9, 2022
1 parent 23dd089 commit d71ace2
Showing 19 changed files with 757 additions and 36 deletions.
5 changes: 4 additions & 1 deletion .coveragerc
@@ -10,4 +10,7 @@ exclude_lines =
^\s*raise ValueError

# skip lines that are only "pass"
^\s*pass\s*$
^\s*pass\s*$

# skip type checking
if TYPE_CHECKING:
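
This exclusion is added because a block guarded by "if TYPE_CHECKING:" is only evaluated by static type checkers and never executes at runtime, so the test suite can never cover it. A minimal sketch of the idiom the exclusion targets (the driver modules below use the same pattern):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; skipped at runtime, so coverage would
    # otherwise report this import as a missed line.
    from squirrel.iterstream import Composable


def first_item(stream: Composable) -> object:
    """Return the first element of a stream."""
    return next(iter(stream))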
3 changes: 0 additions & 3 deletions requirements.txt
@@ -121,9 +121,6 @@ imagesize==1.3.0 \
--hash=sha256:1db2f82529e53c3e929e8926a1fa9235aa82d0bd0c580359c67ec31b2fddaa8c
importlib-metadata==4.11.2 \
--hash=sha256:d16e8c1deb60de41b8e8ed21c1a7b947b0bc62fab7e1d470bcdf331cea2e6735
importlib-resources==5.8.0 \
--hash=sha256:568c9f16cb204f9decc8d6d24a572eeea27dacbb4cee9e6b03a8025736769751 \
--hash=sha256:7952325ffd516c05a8ad0858c74dff2c3343f136fe66a6002b2623dd1d43f223
iniconfig==1.1.1 \
--hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3
ipykernel==6.9.1 \
2 changes: 1 addition & 1 deletion src/squirrel_datasets_core/__init__.py
@@ -1 +1 @@
__version__ = "0.1.8"
__version__ = "0.1.9"
src/squirrel_datasets_core/datasets/allenai_c4/allenai_c4_multilingual.py
@@ -1,12 +1,14 @@
from __future__ import annotations

import typing as t
from typing import TYPE_CHECKING

import fsspec

from squirrel.driver import MapDriver
from squirrel.serialization import JsonSerializer

if t.TYPE_CHECKING:
if TYPE_CHECKING:
from squirrel.iterstream import Composable


@@ -50,6 +52,11 @@ def _init_source(self, lang: t.List[str], split: str) -> None:
"""
self._source = []
for iso in lang:
if iso not in self._subsets:
raise ValueError(f"The language {iso} does not exist")
if split not in self._subsets[iso]:
raise ValueError(f"The split {split} does not exist")

self._source += self._subsets[iso][split]

def select(self, lang: t.Optional[t.Union[str, t.List[str]]] = None, split: str = "train") -> C4DatasetDriver:
@@ -64,6 +71,8 @@ def select(self, lang: t.Optional[t.Union[str, t.List[str]]] = None, split: str
lang = [lang]

self._lang = lang
else:
self._lang = list(self._subsets.keys())

self._split = split
self._init_source(self._lang, self._split)
11 changes: 6 additions & 5 deletions src/squirrel_datasets_core/datasets/bdd100k/driver.py
@@ -5,15 +5,16 @@
import os
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

from squirrel.driver import IterDriver
from squirrel.iterstream import FilePathGenerator, IterableSource

from squirrel_datasets_core.io.io import load_image

if TYPE_CHECKING:
from squirrel.iterstream import Composable

from squirrel_datasets_core.io.io import load_image


class BDD100KDriver(IterDriver):
"""Driver that can iterate over the samples of the `Berkeley Deep Drive Semantic Segmentation (BDD100K) dataset
@@ -61,7 +62,7 @@ def load_sample(sample: Dict, parse_image: bool = True, parse_label: bool = True
"""
if parse_image:
sample["image"] = load_image(sample["image_url"])
if parse_label:
if parse_label and "label_url" in sample:
sample["label"] = load_image(sample["label_url"])
return sample

@@ -97,7 +98,7 @@ def get_iter(
imgs_dir = os.path.join(self.url, "images/10k", split)

if split == "test":
gen = FilePathGenerator(imgs_dir)
gen = FilePathGenerator(imgs_dir).map(lambda x: dict(image_url=x, split=split))
else:
labels_dir = os.path.join(self.url, "labels/sem_seg/masks", split)
gen = FilePathGenerator(imgs_dir).map(
19 changes: 6 additions & 13 deletions src/squirrel_datasets_core/datasets/cc100/cc100.py
@@ -1,13 +1,14 @@
from __future__ import annotations

import typing as t
from typing import TYPE_CHECKING

import fsspec

from squirrel.driver import MapDriver
from squirrel.iterstream import Composable
from squirrel.iterstream.source import IterableSource

if t.TYPE_CHECKING:
if TYPE_CHECKING:
from squirrel.catalog import Catalog


@@ -32,14 +33,6 @@ def __init__(

self.compression = compression

@property
def _file_path_generator(self) -> IterableSource:
"""Initializes the source for the CC100Driver to iterate over."""
if self._source is None:
self._init_source(self._lang)

return IterableSource(self._source)

def _init_source(self, lang: t.List[str]) -> None:
"""
Lazy initialization for the source generator.
@@ -61,14 +54,14 @@ def select(self, lang: t.Optional[t.Union[str, t.List[str]]] = None) -> CC100Dri
lang = [lang]

self._lang = lang
else:
self._lang = list(self._subsets.keys())

self._init_source(self._lang)
return self

def keys(self, **kwargs) -> t.List[str]:
"""Returns the list of urls for the archive file of the selected language(s)."""
if self._source is None:
self._init_source(self._lang)

return self._source

@staticmethod
(conceptual captions driver)
@@ -4,14 +4,17 @@
import sys
import typing as t
import urllib
from typing import TYPE_CHECKING
from urllib.request import Request

import numpy as np
import requests
from PIL import Image

from squirrel.driver import IterDriver
from squirrel.iterstream.source import IterableSource

if t.TYPE_CHECKING:
if TYPE_CHECKING:
from squirrel.catalog import Catalog
from squirrel.iterstream import Composable

@@ -58,9 +61,8 @@ def _download_image(url: str) -> np.ndarray:
url: location of the image
"""
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
resp = urllib.request.urlopen(req)
image = np.asarray(bytearray(resp.read()), dtype="uint8")
return image
resp = urllib.request.urlopen(req, timeout=1)
return np.array(Image.open(resp))

@staticmethod
def _map_fn(record: t.Dict[str, str]) -> t.Dict[str, t.Union[str, np.ndarray]]:
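
The _download_image change above decodes the HTTP response with PIL instead of returning the raw byte buffer, which is what "Fix conceptual captions dimensions" refers to: the result now has shape (height, width, channels). A hedged sketch of how such a method can be exercised without network access by patching urlopen via unittest.mock; the helper name here is hypothetical and this is not the content of test/test_datasets/test_conceptual_captions.py:

import io
import urllib.request
from unittest import mock

import numpy as np
from PIL import Image


def make_png_response(width: int = 4, height: int = 3) -> io.BytesIO:
    """Encode a small random RGB image as PNG and return it as a file-like object."""
    arr = (np.random.rand(height, width, 3) * 255).astype("uint8")
    buf = io.BytesIO()
    Image.fromarray(arr).save(buf, format="PNG")
    buf.seek(0)
    return buf


# Patch urlopen so the decoding logic runs against the fake response body
# instead of performing a real HTTP request.
with mock.patch.object(urllib.request, "urlopen", return_value=make_png_response()):
    resp = urllib.request.urlopen("http://example.com/img.png", timeout=1)
    image = np.array(Image.open(resp))
    assert image.shape == (3, 4, 3)  # (height, width, channels)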
(monthly German tweets driver)
@@ -5,7 +5,8 @@
from __future__ import annotations

import json
from typing import Iterable, Iterator, List, TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Iterator, List, Union

from squirrel.driver import MapDriver
from squirrel.fsspec.fs import get_fs_from_url
@@ -18,13 +19,13 @@
class MonthlyGermanTweetsDriver(MapDriver):
name = "raw_monthly_german_tweets"

def __init__(self, folder: str, **kwargs) -> None:
def __init__(self, folder: Union[str, Path], **kwargs) -> None:
"""Init the MonthlyGermanTweets driver.
Args:
folder: Path to the unzipped data dump
"""
self.folder = folder
self.folder = str(folder)
self.compression = "gzip"
self.parse_error_count = 0

@@ -36,7 +37,8 @@ def parse_archive(self, url: str) -> List[str]:
json_bytes = list(f)

dec = "".join([b.decode("utf-8").strip() for b in json_bytes])[1:-1]
samples = ["{" + elem for elem in dec.split(",{") if not elem.startswith("{")]
samples = [(elem if elem.startswith("{") else "{" + elem) for elem in dec.split(",{")]

return samples

def get(self, url: str) -> Iterator:
@@ -56,7 +58,7 @@ def get_iter(self, flatten: bool = True, **kwargs) -> Composable:
dataset.
Args:
flatten (bool): Whether to flatten the returned iterable. Defaults to False.
flatten (bool): Whether to flatten the returned iterable. Defaults to True.
**kwargs: Other keyword arguments passed to :py:meth:`MapDriver.get_iter`.
"""

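
The parse_archive change above fixes how the first record of each archive is recovered: splitting the concatenated dump on ",{" strips the leading brace from every element except the first, and the old comprehension discarded any element that already started with "{", silently dropping that first record. A tiny worked example, assuming a two-record dump:

dec = '{"id": 1},{"id": 2}'    # archive contents with the outer brackets already stripped
parts = dec.split(",{")        # ['{"id": 1}', '"id": 2}']

old = ["{" + p for p in parts if not p.startswith("{")]       # ['{"id": 2}']  (first record lost)
new = [(p if p.startswith("{") else "{" + p) for p in parts]  # ['{"id": 1}', '{"id": 2}']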
54 changes: 54 additions & 0 deletions test/test_datasets/mock_utils.py
@@ -0,0 +1,54 @@
import gzip
import lzma as xz
import random
import string
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
from PIL import Image


def create_random_str(length: int = 10) -> str:
"""Generate a random file name"""
return "".join(random.choices(string.ascii_uppercase + string.digits, k=length))


def create_image(folder: Path, image_name: str, resolution: Tuple, format: str = "png") -> None:
"""Create a random image in a given directory"""
folder.mkdir(exist_ok=True, parents=True)

img = np.random.rand(resolution[0], resolution[1], 3) * 255

if format == "png":
img = Image.fromarray(img.astype("uint8")).convert("RGBA")
else:
img = Image.fromarray(img.astype("uint8")).convert("RGB")

img.save(folder / f"{image_name}.{format}")


def create_image_folder(folder: Path, image_names: Union[int, List], resolution: Tuple, format: str = "png") -> None:
"""Create an example directory filled with random images"""
if isinstance(image_names, int):
image_names = [create_random_str() for _ in range(image_names)]

for name in image_names:
create_image(folder, name, resolution, format)


def create_random_dict(attributes: List[str]) -> Dict:
"""Create random blob of json"""
return {a: create_random_str() for a in attributes}


def save_gzip(file: Path, content: str) -> None:
"""Save string as gz"""
with gzip.open(file, "wb") as f:
f.write(content.encode())


def save_xz(file: Path, content: str) -> None:
"""Save string as gz with xz compression"""
with open(file, "wb") as f:
f.write(xz.compress(content.encode()))
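
A short usage sketch for the compression helpers, assuming a pytest tmp_path directory: save_xz writes the default lzma container format (xz), so lzma.open can read it back. This is an illustration, not one of the committed tests:

import lzma
from pathlib import Path

from mock_utils import save_xz


def test_save_xz_roundtrip(tmp_path: Path) -> None:
    target = tmp_path / "data.txt.xz"
    save_xz(target, "hello world")
    with lzma.open(target, "rb") as f:
        assert f.read().decode("utf-8") == "hello world"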
54 changes: 54 additions & 0 deletions test/test_datasets/test_allenai.py
@@ -0,0 +1,54 @@
import json
from collections import defaultdict
from pathlib import Path
from typing import Iterator, Tuple

import pytest
from squirrel_datasets_core.datasets.allenai_c4 import C4DatasetDriver

from mock_utils import create_random_dict, save_gzip


def mock_allenai_data(tmp_path: Path) -> Iterator[Tuple[str, str, int, Path]]:
"""Create a dataset mock for allenai"""
for language, split, samples in [
("af", "train", 64),
("af", "valid", 1),
("am", "train", 3),
("am", "valid", 2),
("zu", "train", 5),
("zu", "valid", 5),
]:
contents = [json.dumps(create_random_dict(["text", "timestamp", "url"])) for _ in range(samples)]
record = "\n".join(contents)

save_path = tmp_path / language / split / "record.json.gz"
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
save_gzip(save_path, record)
yield (language, split, samples, save_path)


def test_allenai(tmp_path: Path) -> None:
"""Unit test for allenai c4 driver with mocked data"""
mock_data = list(mock_allenai_data(tmp_path))
config = defaultdict(dict)

for language, split, _, save_path in mock_data:
config[language][split] = [save_path]

driver = C4DatasetDriver(config)

for language, split, samples, _ in mock_data:
assert len(driver.select(language, split).get_iter().collect()) == samples

for sample in driver.select(language, split).get_iter():
assert sample["text"] is not None
assert sample["timestamp"] is not None
assert sample["url"] is not None
assert language in driver.available_languages

assert len(driver.select().get_iter().collect()) == sum(
samples for _, split, samples, _ in mock_data if split == "train"
)
with pytest.raises(ValueError):
driver.select("en").get_iter().collect()
49 changes: 49 additions & 0 deletions test/test_datasets/test_bdd100k.py
@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Iterator, Tuple

import pytest
from squirrel_datasets_core.datasets.bdd100k import BDD100KDriver

from mock_utils import create_image_folder, create_random_str


def mock_bdd100k_data(samples: int, tmp_path: Path, image_shape: Tuple) -> None:
"""Create a dataset mock for bdd100k"""
image_names_train = [create_random_str() for _ in range(samples)]
image_names_test = [create_random_str() for _ in range(samples)]

create_image_folder(tmp_path / "images/10k/train", image_names_train, image_shape)
create_image_folder(tmp_path / "images/10k/test", image_names_test, image_shape)
create_image_folder(tmp_path / "labels/sem_seg/masks/train", image_names_train, image_shape)


@pytest.mark.parametrize("parse_label", [True, False])
@pytest.mark.parametrize("parse_image", [True, False])
@pytest.mark.parametrize("split", ["train", "test"])
def test_bdd100k(tmp_path: Path, parse_label: bool, parse_image: bool, split: str) -> None:
"""Unit test for bdd100k driver (segmentation) with mocked data"""
N = 10
image_shape = (1280, 720)

mock_bdd100k_data(N, tmp_path, image_shape)
driver = BDD100KDriver(tmp_path)

assert len(driver.get_iter(split).collect()) == N

for sample in driver.get_iter(split):
assert sample["split"] == split
assert sample["image"].shape == image_shape + (4,)

if split == "train":
# labels exist
assert Path(sample["image_url"]).stem == Path(sample["label_url"]).stem
assert sample["label"].shape == image_shape + (4,)

def _test_hook(it: Iterator) -> Iterator:
for element in it:
yield element
break

assert (
len(driver.get_iter(split, hooks=[_test_hook], parse_label=parse_label, parse_image=parse_image).collect()) == 1
)