-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Use recommended tmp_path instead of tmpdir * Add camvid test * Test datascience bowl * Skip type checking for coverage report * Add test for allenai * Fix dependabot alert * Add test for bdd100k * Add test for cc100 * Add test for conceptual captions * Add test for casting quality * Add test for monthly german tweets * Add imagenet test * Fix conceptual captions dimensions * Add mock to requirements * Update test to work with unittest.mock directly * Remove mock from requirements * Correct mistakes in requirements.txt * Update test/test_datasets/mock_utils.py Co-authored-by: Alp Arıbal <[email protected]> * Update test/test_datasets/mock_utils.py Co-authored-by: Alp Arıbal <[email protected]> * Update test/test_datasets/test_cc100.py Co-authored-by: Alp Arıbal <[email protected]> * Update test/test_datasets/test_conceptual_captions.py Co-authored-by: Alp Arıbal <[email protected]> * Update test/test_datasets/test_conceptual_captions.py Co-authored-by: Alp Arıbal <[email protected]> * Update test/test_datasets/test_imagenet.py Co-authored-by: Alp Arıbal <[email protected]> * Integrate feedback from Alp * Check all examples instead of only one * Update test/test_datasets/test_allenai.py Co-authored-by: Alp Arıbal <[email protected]> * Update test/test_datasets/test_allenai.py Co-authored-by: Alp Arıbal <[email protected]> * Update test/test_datasets/test_conceptual_captions.py Co-authored-by: Alp Arıbal <[email protected]> * Update src/squirrel_datasets_core/datasets/allenai_c4/allenai_c4_multilingual.py Co-authored-by: Alp Arıbal <[email protected]> * Expect value error instead of runtime error * Fix typehints * switch requirements back to python 3.8 * Bump version Co-authored-by: Alp Arıbal <[email protected]>
- Loading branch information
Showing
19 changed files
with
757 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.1.8" | ||
__version__ = "0.1.9" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import gzip | ||
import lzma as xz | ||
import random | ||
import string | ||
from pathlib import Path | ||
from typing import Dict, List, Tuple, Union | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
|
||
def create_random_str(length: int = 10) -> str:
    """Return `length` random characters drawn from uppercase ASCII letters and digits."""
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)
|
||
|
||
def create_image(folder: Path, image_name: str, resolution: Tuple, format: str = "png") -> None:
    """Write one image of random pixel noise into `folder`.

    `folder` (including missing parents) is created first. The pixel array has
    shape ``(resolution[0], resolution[1], 3)``. The file is saved as
    ``<image_name>.<format>``; "png" output is converted to RGBA, every other
    format to RGB.
    """
    folder.mkdir(parents=True, exist_ok=True)

    # Uniform noise scaled towards the 8-bit range, then truncated to uint8.
    pixels = (np.random.rand(resolution[0], resolution[1], 3) * 255).astype("uint8")

    mode = "RGBA" if format == "png" else "RGB"
    target = folder / f"{image_name}.{format}"
    Image.fromarray(pixels).convert(mode).save(target)
|
||
|
||
def create_image_folder(folder: Path, image_names: Union[int, List], resolution: Tuple, format: str = "png") -> None:
    """Populate `folder` with random images.

    `image_names` is either a list of file stems or an integer; for an integer
    that many random stems are generated. Each image is written by
    `create_image` with the given `resolution` and `format`.
    """
    stems = (
        [create_random_str() for _ in range(image_names)]
        if isinstance(image_names, int)
        else image_names
    )

    for stem in stems:
        create_image(folder, stem, resolution, format)
|
||
|
||
def create_random_dict(attributes: List[str]) -> Dict:
    """Build a dict that maps every name in `attributes` to a fresh random string."""
    blob = {}
    for key in attributes:
        blob[key] = create_random_str()
    return blob
|
||
|
||
def save_gzip(file: Path, content: str) -> None:
    """Write `content`, encoded to bytes, into a gzip-compressed file at `file`."""
    with gzip.open(file, mode="wb") as stream:
        payload = content.encode()
        stream.write(payload)
|
||
|
||
def save_xz(file: Path, content: str) -> None:
    """Write `content`, encoded to bytes, as an xz (LZMA) compressed file at `file`."""
    with open(file, mode="wb") as sink:
        raw = content.encode()
        sink.write(xz.compress(raw))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import json | ||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import Iterator, Tuple | ||
|
||
import pytest | ||
from squirrel_datasets_core.datasets.allenai_c4 import C4DatasetDriver | ||
|
||
from mock_utils import create_random_dict, save_gzip | ||
|
||
|
||
def mock_allenai_data(tmp_path: Path) -> Iterator[Tuple[str, str, int, Path]]:
    """Write gzip-compressed mock records for allenai below `tmp_path`.

    One record is stored per (language, split) pair at
    ``<tmp_path>/<language>/<split>/record.json.gz``. It holds `samples`
    newline-separated JSON objects, each with random "text", "timestamp" and
    "url" values.

    Yields:
        ``(language, split, samples, save_path)`` after each record is written.
    """
    layout = [
        ("af", "train", 64),
        ("af", "valid", 1),
        ("am", "train", 3),
        ("am", "valid", 2),
        ("zu", "train", 5),
        ("zu", "valid", 5),
    ]

    for language, split, samples in layout:
        lines = []
        for _ in range(samples):
            lines.append(json.dumps(create_random_dict(["text", "timestamp", "url"])))

        save_path = tmp_path / language / split / "record.json.gz"
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_gzip(save_path, "\n".join(lines))
        yield language, split, samples, save_path
|
||
|
||
def test_allenai(tmp_path: Path) -> None:
    """Check the allenai c4 driver against mocked records.

    For every mocked (language, split) pair the driver must return the
    expected number of samples, each with non-None "text", "timestamp" and
    "url" entries, and must list the language as available. The default
    selection must contain as many samples as all mocked train records
    together, and selecting "en" (absent from the mock) must raise ValueError.
    """
    records = list(mock_allenai_data(tmp_path))

    # Driver config: language -> split -> list of record paths.
    config = defaultdict(dict)
    for lang, part, _, path in records:
        config[lang][part] = [path]

    driver = C4DatasetDriver(config)

    for lang, part, expected, _ in records:
        collected = driver.select(lang, part).get_iter().collect()
        assert len(collected) == expected

        for sample in driver.select(lang, part).get_iter():
            assert sample["text"] is not None
            assert sample["timestamp"] is not None
            assert sample["url"] is not None
        assert lang in driver.available_languages

    train_total = sum(count for _, part, count, _ in records if part == "train")
    assert len(driver.select().get_iter().collect()) == train_total

    with pytest.raises(ValueError):
        driver.select("en").get_iter().collect()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from pathlib import Path | ||
from typing import Iterator, Tuple | ||
|
||
import pytest | ||
from squirrel_datasets_core.datasets.bdd100k import BDD100KDriver | ||
|
||
from mock_utils import create_image_folder, create_random_str | ||
|
||
|
||
def mock_bdd100k_data(samples: int, tmp_path: Path, image_shape: Tuple) -> None:
    """Create a mocked bdd100k folder layout below `tmp_path`.

    `samples` random images of `image_shape` are written to each of
    ``images/10k/train`` and ``images/10k/test``. ``labels/sem_seg/masks/train``
    receives images that reuse the train file names.
    """
    train_names = [create_random_str() for _ in range(samples)]
    test_names = [create_random_str() for _ in range(samples)]

    targets = (
        ("images/10k/train", train_names),
        ("images/10k/test", test_names),
        ("labels/sem_seg/masks/train", train_names),
    )
    for subdir, names in targets:
        create_image_folder(tmp_path / subdir, names, image_shape)
|
||
|
||
@pytest.mark.parametrize("parse_label", [True, False])
@pytest.mark.parametrize("parse_image", [True, False])
@pytest.mark.parametrize("split", ["train", "test"])
def test_bdd100k(tmp_path: Path, parse_label: bool, parse_image: bool, split: str) -> None:
    """Check the bdd100k (segmentation) driver against mocked data.

    Verifies the sample count per split, the split tag and image shape of each
    sample, and, for the train split only, that label and image share a file
    stem and shape. Finally checks that a hook which stops after the first
    element reduces the iterator to one sample for every
    `parse_label` / `parse_image` combination.
    """
    n_samples = 10
    image_shape = (1280, 720)
    # Mock images are saved as PNG, which the helper converts to RGBA (4 channels).
    expected_shape = image_shape + (4,)

    mock_bdd100k_data(n_samples, tmp_path, image_shape)
    driver = BDD100KDriver(tmp_path)

    assert len(driver.get_iter(split).collect()) == n_samples

    for sample in driver.get_iter(split):
        assert sample["split"] == split
        assert sample["image"].shape == expected_shape

        if split == "train":
            # Only the train split has mocked label masks.
            assert Path(sample["image_url"]).stem == Path(sample["label_url"]).stem
            assert sample["label"].shape == expected_shape

    def _test_hook(it: Iterator) -> Iterator:
        """Pass through only the first element of `it`."""
        for element in it:
            yield element
            return

    limited = driver.get_iter(split, hooks=[_test_hook], parse_label=parse_label, parse_image=parse_image)
    assert len(limited.collect()) == 1
Oops, something went wrong.