From 832a85d83e2fbc3c92408bb83a8f867f42210346 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 3 Feb 2022 11:50:09 +0100 Subject: [PATCH 01/19] apply readthedocs template --- .gitignore | 156 ++++++++++++++++++++++++++++++++++++++++++ .readthedocs.yaml | 29 ++++++++ README.rst | 7 ++ docs/Makefile | 20 ++++++ docs/source/conf.py | 36 ++++++++++ docs/source/index.rst | 4 ++ requirements.in | 1 + 7 files changed, 253 insertions(+) create mode 100644 .gitignore create mode 100644 .readthedocs.yaml create mode 100644 docs/Makefile create mode 100644 docs/source/conf.py create mode 100644 docs/source/index.rst create mode 100644 requirements.in diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..19252fc --- /dev/null +++ b/.gitignore @@ -0,0 +1,156 @@ + +# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks +# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook + +# IPython + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
\ No newline at end of file
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 0000000..f579983
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,29 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+  os: ubuntu-20.04
+  tools:
+    python: "3.9"
+  # You can also specify other tool versions:
+  # nodejs: "16"
+  # rust: "1.55"
+  # golang: "1.17"
+
+# Build documentation in the docs/source/ directory with Sphinx
+sphinx:
+  configuration: docs/source/conf.py
+
+# If using Sphinx, optionally build your docs in additional formats such as PDF
+# formats:
+#    - pdf
+
+# Optionally declare the Python requirements required to build your docs
+# python:
+#    install:
+#    - requirements: docs/requirements.txt
diff --git a/README.rst b/README.rst
index 73c59cf..f0a755f 100644
--- a/README.rst
+++ b/README.rst
@@ -1,3 +1,10 @@
 Squirrel Datasets
 =================
 
+``squirrel-datasets-core`` is a hub where the user can 1) explore existing datasets registered in the data mesh by other users and 2) preprocess their datasets and share them with other users. As an end user, you will
+be able to load many publicly available datasets with ease and speed with the help of ``squirrel``, or load and preprocess
+your own datasets with the tools we provide here.
+
+For preprocessing, we currently support Spark as the main tool to carry out the task.
+
+Please see our `documentation <https://squirrel-datasets-core.readthedocs.io>`_ for further details.
\ No newline at end of file
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..269cadc
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = source
+BUILDDIR      = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000..89d4f8e
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,36 @@
+# Configuration file for the Sphinx documentation builder.
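+#
+# A local build of these docs (assuming sphinx and the sphinx-rtd-theme listed in requirements.in
+# are installed) can be run with "make html" inside docs/, or equivalently with
+# "python -m sphinx -M html docs/source docs/build" from the repository root.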
+import datetime + +# -- Project information + +project = 'Squirrel Datasets' +copyright = f'{datetime.datetime.now().year}, Merantix Labs GmbH' +author = 'Merantix Labs GmbH' + +release = '0.1' +version = '0.1.0' + +# -- General configuration + +extensions = [ + 'sphinx.ext.duration', + 'sphinx.ext.doctest', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.intersphinx', +] + +intersphinx_mapping = { + 'python': ('https://docs.python.org/3/', None), + 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), +} +intersphinx_disabled_domains = ['std'] + +templates_path = ['_templates'] + +# -- Options for HTML output + +html_theme = 'sphinx_rtd_theme' + +# -- Options for EPUB output +epub_show_urls = 'footnote' \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..e270bb1 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,4 @@ +Squirrel Datasets Documentation +=============================== + + diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..4170c03 --- /dev/null +++ b/requirements.in @@ -0,0 +1 @@ +sphinx-rtd-theme \ No newline at end of file From 3eb9697087f4d2ef59d6d365282bafa7e59ca02d Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 3 Feb 2022 12:02:57 +0100 Subject: [PATCH 02/19] PR and issue templates --- .github/ISSUE_TEMPLATE/bug_report.md | 19 +++++++++++++++++ .github/ISSUE_TEMPLATE/feature_request.md | 18 ++++++++++++++++ .github/PULL_REQUEST_TEMPLATE.md | 25 +++++++++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..a0a32d2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,19 @@ +--- +name: Bug report +about: Create a report to help us improve +title: "[BUG]" +labels: bug +assignees: '' + +--- + +Before you open an issue, please check if a similar issue already exists or has been closed before. + + ### When reporting a bug, please be sure to include the following: + - [ ] A descriptive title + - [ ] An *isolated* way to reproduce the behavior (example: GitHub repository with code isolated to the issue that anyone can clone to observe the problem) + - [ ] What version of `squirrel-datasets` and `squirrel` you're using, and the platform(s) you're running it on + - [ ] What packages or other dependencies you're using + - [ ] The behavior you expect to see and the actual behavior + + See [contributing guideline]() for more detail on what is expected of a bug report. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..c79f399 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,18 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: "[FEATURE]" +labels: enhancement +assignees: '' + +--- + +Before you open an issue, please check if a similar issue already exists or has been closed before. 
+ +### When you open an issue for a feature request, please add as much detail as possible: + - [ ] A descriptive title + - [ ] A description of the problem you're trying to solve, including *why* you think this is a problem + - [ ] An overview of the suggested solution + - [ ] If the feature changes current behavior, reasons why your solution is better + + See [contributing guideline]() for more detail on what is expected of a feature request. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..3edec8d --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,25 @@ +# Description + +Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. +List any dependencies that are required for this change. + +Fixes # issue + +## Type of change + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] Documentation update +- [ ] Refactoring including code style reformatting +- [ ] Other (please describe): + +# Checklist: + +- [ ] I have read the [contributing guideline doc] () (external only) +- [ ] I have signed the [CLA] () (external only) +- [ ] Lint and unit tests pass locally with my changes +- [ ] I have kept the PR small so that it can be easily reviewed +- [ ] I have made corresponding changes to the documentation +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] All dependency changes have been reflected in the pip requirement files. From 1269e279cd501d15140905412e7d05db4b902934 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 3 Feb 2022 12:54:41 +0100 Subject: [PATCH 03/19] copy imagenet+kaggle drivers+dependencies --- .../datasets/__init__.py | 0 .../datasets/imagenet/__init__.py | 3 + .../datasets/imagenet/driver.py | 232 ++++++++++++++++++ .../datasets/imagenet/preprocessing.py | 39 +++ .../kaggle_casting_quality/__init__.py | 0 .../datasets/kaggle_casting_quality/driver.py | 39 +++ .../kaggle_casting_quality/preprocessing.py | 38 +++ src/squirrel_datasets_core/driver/__init__.py | 0 src/squirrel_datasets_core/driver/fsspec.py | 73 ++++++ .../preprocessing/__init__.py | 0 .../preprocessing/save_shards.py | 74 ++++++ src/squirrel_datasets_core/spark/__init__.py | 3 + .../spark/setup_spark.py | 101 ++++++++ 13 files changed, 602 insertions(+) create mode 100644 src/squirrel_datasets_core/datasets/__init__.py create mode 100644 src/squirrel_datasets_core/datasets/imagenet/__init__.py create mode 100644 src/squirrel_datasets_core/datasets/imagenet/driver.py create mode 100644 src/squirrel_datasets_core/datasets/imagenet/preprocessing.py create mode 100644 src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py create mode 100644 src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py create mode 100644 src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py create mode 100644 src/squirrel_datasets_core/driver/__init__.py create mode 100644 src/squirrel_datasets_core/driver/fsspec.py create mode 100644 src/squirrel_datasets_core/preprocessing/__init__.py create mode 100644 src/squirrel_datasets_core/preprocessing/save_shards.py create mode 100644 src/squirrel_datasets_core/spark/__init__.py create mode 100644 src/squirrel_datasets_core/spark/setup_spark.py diff --git 
a/src/squirrel_datasets_core/datasets/__init__.py b/src/squirrel_datasets_core/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/squirrel_datasets_core/datasets/imagenet/__init__.py b/src/squirrel_datasets_core/datasets/imagenet/__init__.py
new file mode 100644
index 0000000..5c84f16
--- /dev/null
+++ b/src/squirrel_datasets_core/datasets/imagenet/__init__.py
@@ -0,0 +1,3 @@
+from squirrel_datasets_core.datasets.imagenet.driver import RawImageNet
+
+__all__ = ["RawImageNet"]
diff --git a/src/squirrel_datasets_core/datasets/imagenet/driver.py b/src/squirrel_datasets_core/datasets/imagenet/driver.py
new file mode 100644
index 0000000..705aba6
--- /dev/null
+++ b/src/squirrel_datasets_core/datasets/imagenet/driver.py
@@ -0,0 +1,232 @@
+import os
+from pathlib import Path
+from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Tuple
+
+from squirrel.driver import RecordIteratorDriver
+from squirrel.driver.file.dataloader import FileLoader
+from squirrel.iterstream import Composable, FilePathGenerator
+
+from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver
+
+
+class RawImageNet(RecordIteratorDriver, TwoDImageFileDriver):
+    name = "raw_imagenet"
+
+    def __init__(
+        self,
+        url: str,
+        cls_mapping_url: Optional[str] = None,
+        loc_train_mapping_url: Optional[str] = None,
+        loc_val_mapping_url: Optional[str] = None,
+        cls_val_mapping_url: Optional[str] = None,
+        val_blacklist_url: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        """Init RawImageNet driver.
+
+        Args:
+            url (str): Path to the root directory of the dataset.
+            cls_mapping_url (str, optional): Path to the text file that contains the class mapping. Each line of this
+                file should be formatted as "{class_id} {class_idx} {class_name}". If not provided, samples will not
+                have label information. Defaults to None.
+            loc_train_mapping_url (str, optional): Path to the text file that contains the bounding boxes for the train
+                set. If not provided, train samples will not have bounding box information. Defaults to None.
+            loc_val_mapping_url (str, optional): Path to the text file that contains the bounding boxes for the
+                validation set. If not provided, validation samples will not have bounding box information. Defaults to None.
+            cls_val_mapping_url (str, optional): Path to the text file that contains the groundtruth class indices for
+                the validation set. Line i (0-indexed) of this file corresponds to the i-th validation sample. If not
+                provided, validation samples will not have class label information. Defaults to None.
+            val_blacklist_url (str, optional): Path to the text file that contains the blacklisted validation samples.
+                Each line of this file should contain a sample index. If not provided, validation samples will not be
+                filtered. Defaults to None.
+        """
+        self.path = url
+        self.cls_mapping_path = cls_mapping_url
+        self.loc_train_mapping_path = loc_train_mapping_url
+        self.loc_val_mapping_path = loc_val_mapping_url
+        self.cls_val_mapping_path = cls_val_mapping_url
+        self.val_blacklist_path = val_blacklist_url
+
+    @staticmethod
+    def parse_gt_train(
+        samples: Iterable[Dict[str, Any]], cls_map: Dict[str, Tuple[int, str]]
+    ) -> Generator[Dict[str, Any], None, None]:
+        """Add class id, class index, and class label to training set samples.
+        The class id, index, and label are stored under the keys "class_id", "classification_label",
+        "classification_label_name", respectively.
+ """ + for sample in samples: + class_id = sample["url"].split("/")[-2] + label, label_name = cls_map[class_id] + sample["classification_label"] = label + sample["classification_label_name"] = label_name + sample["class_id"] = class_id + yield sample + + @staticmethod + def parse_bbox( + samples: Iterable[Dict[str, Any]], cls_map: Dict[str, Tuple[int, str]], loc_map: Dict[str, List[Dict[str, Any]]] + ) -> Generator[Dict[str, Any], None, None]: + """Add bounding boxes to samples. + The list of bounding box dictionaries are stored under the key "bboxes". + """ + for sample in samples: + # split after last / and before last . + filename = Path(sample["url"]).stem + sample["bboxes"] = [] + if filename in loc_map: + a_bboxes = loc_map[filename] + for bbox in a_bboxes: + label, label_name = cls_map[bbox["class_id"]] + bbox["classification_label"] = label + bbox["classification_label_name"] = label_name + sample["bboxes"].append(bbox) + yield sample + + @staticmethod + def parse_gt_val( + samples: Iterable[Dict[str, Any]], clsidx_map: Dict[int, Tuple[str, str]], clsidx_val_list: List[int] + ) -> Generator[Dict[str, Any], None, None]: + """Add class id, class index, and class label to validation set samples. + The class id, index, and label are stored under the keys "class_id", "classification_label", + "classification_label_name", respectively. + """ + for sample in samples: + # split after last / and before last . + filename = Path(sample["url"]).stem + # extract label id from filename + a_id = int(filename.split("_")[-1]) - 1 + label = clsidx_val_list[a_id] + class_id, label_name = clsidx_map[label] + sample["class_id"] = class_id + sample["classification_label_name"] = label_name + sample["classification_label"] = label + yield sample + + @staticmethod + def filter_val_samples_by_idx( + samples: Iterable[Dict[str, Any]], idx_blacklist: Set[int] + ) -> Generator[Dict[str, Any], None, None]: + """Yield samples that are not in the blacklist.""" + for sample in samples: + a_id = int(Path(sample["url"]).stem.split("_")[-1]) - 1 + if a_id not in idx_blacklist: + yield sample + + def get_loc_mapper(self, url: str) -> Dict[str, List[Dict[str, Any]]]: + """Get dict that maps from an imagenet filename to a list of bbox dicts.""" + loc_map = dict() + with FileLoader(url).get_store(mode="rt") as f: + f.readline() # skip the header + for row in f.readlines(): + file_name, bboxes_raw = row.strip().split(",") + bboxes = [] + for bbox_raw in bboxes_raw.split("n")[1:]: + class_id, *x = bbox_raw.strip().split(" ") + assert len(x) == 4 + class_id = f"n{class_id}" + bboxes.append({"class_id": class_id, "loc": [int(ax) for ax in x]}) + loc_map[file_name] = bboxes + return loc_map + + def get_val_clsidx_list(self) -> List[int]: + """Get list of class indices for validation samples.""" + with FileLoader(self.cls_val_mapping_path).get_store(mode="rt") as f: + val_clsidx = [int(row.strip()) - 1 for row in f.readlines()] + return val_clsidx + + def get_val_blacklist_indices(self) -> Set[int]: + """Get a set of blacklisted validation sample indices.""" + with FileLoader(self.val_blacklist_path).get_store(mode="rt") as f: + blacklist_idx = {int(row.strip()) - 1 for row in f.readlines()} + return blacklist_idx + + def get_id_to_idx_and_name_mapper(self) -> Dict[str, Tuple[int, str]]: + """Get dict that maps from an imagenet foldername (i.e. class id) to a (class index, class name) tuple. + WARNING: Imagenet idx start from 1. Here we start from 0. 
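+
+        The mapping file is expected to contain one "{class_id} {class_idx} {class_name}" entry per
+        line, e.g. (values below are purely illustrative)::
+
+            n01440764 1 tench
+            n01443537 2 goldfish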
+ """ + gt_map = dict() + with FileLoader(self.cls_mapping_path).get_store(mode="rt") as f: + for row in f.readlines(): + row = row.strip() + class_id, class_idx, class_name = row.split(" ") + gt_map[class_id] = (int(class_idx) - 1, class_name) + return gt_map + + def get_idx_to_id_and_name_mapper(self) -> Dict[int, Tuple[str, str]]: + """Get dict that maps from an imagenet class index to a (class id, class name) tuple. + WARNING: Imagenet idx start from 1. Here we start from 0. + """ + cls_map = self.get_id_to_idx_and_name_mapper() + return {class_idx: (class_id, class_name) for class_id, (class_idx, class_name) in cls_map.items()} + + @staticmethod + def load_sample(sample: Dict[str, Any]) -> Dict[str, Any]: + """Load sample from dict containing url to sample.""" + sample["image"] = RawImageNet.load_image(sample["url"]) + return sample + + def get_iter( + self, + split: str, + hooks: Optional[List[Iterable]] = None, + parse: bool = True, + shuffle: bool = True, + buffer_size: int = 100_000, + **kwargs, + ) -> Composable: + """Create iterstream for the given split. + + Args: + split (str): Split name. Must be one of ("train", "val", "test"). + hooks (List[Iterable], optional): Hooks to apply. Hooks are applied before parsing the samples. Defaults to + None. + parse (bool, optional): Whether to load the image into sample dictionary. Image will be stored under the + key "image". Defaults to True. + shuffle (bool, optional): Whether to shuffle the samples. Defaults to True. + buffer_size (int, optional): Buffer size used for shuffling. Defaults to 100_000. + + Returns: + Composable: Composable containing the samples. + """ + assert split in ["train", "val", "test"] + url = os.path.join(self.path, split) + if hooks is None: + hooks = [] + + it = FilePathGenerator(url, nested=True) + + if shuffle: + it = it.shuffle(size=buffer_size, initial=buffer_size) + + it = it.map(lambda x: {"url": x}) + if self.cls_mapping_path is not None and split != "test": + cls_map = self.get_id_to_idx_and_name_mapper() + + if split == "train": + it = it.to(RawImageNet.parse_gt_train, cls_map=cls_map) + + if self.loc_train_mapping_path is not None: + loc_map = self.get_loc_mapper(self.loc_train_mapping_path) + it = it.to(RawImageNet.parse_bbox, cls_map=cls_map, loc_map=loc_map) + + elif split == "val": + if self.val_blacklist_path is not None: + val_blacklist = self.get_val_blacklist_indices() + it = it.to(RawImageNet.filter_val_samples_by_idx, idx_blacklist=val_blacklist) + + if self.cls_val_mapping_path is not None: + val_clsidx_map = self.get_val_clsidx_list() + clsidx_map = self.get_idx_to_id_and_name_mapper() + it = it.to(RawImageNet.parse_gt_val, clsidx_map=clsidx_map, clsidx_val_list=val_clsidx_map) + + if self.loc_val_mapping_path is not None: + loc_map = self.get_loc_mapper(self.loc_val_mapping_path) + it = it.to(RawImageNet.parse_bbox, cls_map=cls_map, loc_map=loc_map) + + for h in hooks: + it = it.to(h) + + if not parse: + return it + return it.map(RawImageNet.load_sample) diff --git a/src/squirrel_datasets_core/datasets/imagenet/preprocessing.py b/src/squirrel_datasets_core/datasets/imagenet/preprocessing.py new file mode 100644 index 0000000..17fa1e5 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/imagenet/preprocessing.py @@ -0,0 +1,39 @@ +from dataclasses import asdict, dataclass + +import hydra +from hydra.core.config_store import ConfigStore + +from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards +from squirrel_datasets_core.spark.setup_spark import 
get_spark
+
+
+@dataclass
+class RawImagenetIterKwargs:
+    split: str = "train"
+    parse: bool = False
+    shuffle: bool = True
+    buffer_size: int = 100_000
+
+
+@dataclass
+class RawImagenetShardConfig(SaveShardsConfig):
+    identifier: str = "imagenet"
+    version: int = 1
+    num_shards: int = 100
+    output_data_url: str = "imagenet/train"
+    output_catalog_url: str = "imagenet/train.yaml"
+    iter_kwargs: RawImagenetIterKwargs = RawImagenetIterKwargs()
+
+
+cs = ConfigStore.instance()
+cs.store(name="imagenet_config", node=RawImagenetShardConfig)
+
+
+@hydra.main(config_path=None, config_name="imagenet_config")
+def main(cfg: RawImagenetShardConfig) -> None:
+    """When running via cli create spark session yourself"""
+    save_shards(cfg=cfg, session=get_spark("imagenet-preprocessing"), iter_kwargs=asdict(cfg.iter_kwargs))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py
new file mode 100644
index 0000000..7655909
--- /dev/null
+++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py
@@ -0,0 +1,39 @@
+import os
+from itertools import chain
+from typing import Callable, Dict, List, Optional
+
+from squirrel.driver import RecordIteratorDriver
+from squirrel.iterstream import Composable, FilePathGenerator, IterableSource
+
+from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver
+
+
+class RawKaggleCastingQuality(RecordIteratorDriver, TwoDImageFileDriver):
+    name = "raw_kaggle_casting_quality"
+
+    def __init__(self, url: str, **kwargs) -> None:
+        """Init driver."""
+        self.path = url
+
+    @staticmethod
+    def load_sample(sample: Dict) -> Dict:
+        """Load sample from dict containing url to sample."""
+        sample["image"] = RawKaggleCastingQuality.load_image(sample["url"])
+        return sample
+
+    def get_iter(self, split: str, hooks: Optional[List[Callable]] = None, parse: bool = True, **kwargs) -> Composable:
+        """Create iterstream based on dataset split (train, test). Applies hooks before loading samples."""
+        assert split in ["train", "test"]  # kaggle casting quality datasets only have train and test split.
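For reference, a minimal usage sketch for the raw ImageNet driver added earlier in this patch; the dataset root and class-mapping file below are placeholder paths, and the mapping file is assumed to follow the "{class_id} {class_idx} {class_name}" format described in the driver docstring.

.. code-block:: python

    from squirrel_datasets_core.datasets.imagenet import RawImageNet

    # Placeholder paths: an ImageNet copy with train/, val/, test/ sub-folders and a class-mapping file.
    driver = RawImageNet(
        url="/data/imagenet",
        cls_mapping_url="/data/imagenet/class_mapping.txt",
    )

    # parse=False skips image decoding, so samples only carry the URL plus label metadata.
    for sample in driver.get_iter(split="train", shuffle=False, parse=False).take(3):
        print(sample["url"], sample["class_id"], sample["classification_label_name"])

With ``parse=True`` (the default), each sample additionally carries the decoded image as a numpy array under the ``"image"`` key.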
+ if hooks is None: + hooks = [] + + samples_def = FilePathGenerator(os.path.join(self.path, split, "def_front")).map( + lambda x: {"url": x, "label": 0} + ) + samples_ok = FilePathGenerator(os.path.join(self.path, split, "ok_front")).map(lambda x: {"url": x, "label": 1}) + it = IterableSource(chain(samples_ok, samples_def)).shuffle(size=1_000_000, initial=1_000_000) + for h in hooks: + it = it.to(h) + if not parse: + return it + return it.map(RawKaggleCastingQuality.load_sample) diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py new file mode 100644 index 0000000..b217c68 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +import hydra +from hydra.core.config_store import ConfigStore + +from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQuality +from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards +from squirrel_datasets_core.spark.setup_spark import get_spark + + +@dataclass +class KaggleCastingShardConfig(SaveShardsConfig): + identifier: str = "kaggle_casting_quality" + version: int = 1 + num_shards: int = 100 + output_data_url: str = "kaggle_casting_quality/train" + output_catalog_url: str = "kaggle_casting_quality/train.yaml" + split: str = "train" + parse: bool = False + + +cs = ConfigStore.instance() +cs.store(name="kaggle_casting_config", node=KaggleCastingShardConfig) + + +@hydra.main(config_path=None, config_name="kaggle_casting_config") +def main(cfg: KaggleCastingShardConfig) -> None: + """When running via cli create spark session yourself""" + save_shards( + cfg=cfg, + session=get_spark("kagglecastingquality-preprocessing"), + iter_kwargs=dict(split=cfg.split, parse=cfg.parse), + hooks=[RawKaggleCastingQuality.load_sample], + ) + + +if __name__ == "__main__": + main() diff --git a/src/squirrel_datasets_core/driver/__init__.py b/src/squirrel_datasets_core/driver/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/squirrel_datasets_core/driver/fsspec.py b/src/squirrel_datasets_core/driver/fsspec.py new file mode 100644 index 0000000..255d3e5 --- /dev/null +++ b/src/squirrel_datasets_core/driver/fsspec.py @@ -0,0 +1,73 @@ +"""Base driver class that utilizes fsspec to access files.""" +import os +from typing import Any, Dict, Optional + +import numpy as np +from PIL import Image +from squirrel.driver.driver import FileDriver +from squirrel.fsspec.fs import get_fs_from_url + + +class FsspecDriver(FileDriver): + """Generic driver that uses an fsspec filesystem object to load its samples.""" + + def __init__(self, url: str, **kwargs) -> None: + """Initialize FsspecDriver. + + Args: + url (str): Path to the directory containing the dataset. + **kwargs: Additional keyword arguments passed to the super class initializer. 
+ """ + super().__init__(**kwargs) + self.url = url + + @staticmethod + def _get_store_helper(path: str, storage_options: Optional[Dict] = None, open_kwargs: Optional[Dict] = None) -> Any: + """Return a handler for a file.""" + if storage_options is None: + storage_options = {} + if open_kwargs is None: + open_kwargs = {} + + fs = get_fs_from_url(path, **storage_options) + return fs.open(path, **open_kwargs) + + def get_store(self, path: str, storage_options: Optional[Dict] = None, open_kwargs: Optional[Dict] = None) -> Any: + """Return a handler for a file in the dataset. + + Args: + path (str): File path relative to `self.url`. + storage_options (Dict, optional): Storage options passed to :py:meth:`squirrel.fsspec.fs.get_fs_from_url`. + open_kwargs (Dict, optional): Keyword arguments passed to `fs.open()`. By default, only mode="r" is set. + + Returns: + File handler corresponding to the given file. + """ + if open_kwargs is None: + open_kwargs = {} + open_kwargs["mode"] = open_kwargs.get("mode", "r") + + fp = os.path.join(self.url, path) + return self._get_store_helper(path=fp, storage_options=storage_options, open_kwargs=open_kwargs) + + +class TwoDImageFileDriver(FsspecDriver): + @staticmethod + def load_image(path: str, storage_options: Optional[Dict] = None, open_kwargs: Optional[Dict] = None) -> np.ndarray: + """Load a 2D RGB image. + + Args: + path (str): Image path. + storage_options (Dict, optional): Storage options passed to :py:meth:`squirrel.fsspec.fs.get_fs_from_url`. + Defaults to None. + open_kwargs (Dict, optional): Keyword arguments passed to `fs.open()`. By default, only mode="rb" is set. + + Returns: + np.ndarray: Image as a numpy array. + """ + if open_kwargs is None: + open_kwargs = {} + open_kwargs["mode"] = open_kwargs.get("mode", "rb") + + with FsspecDriver._get_store_helper(path, storage_options=storage_options, open_kwargs=open_kwargs) as fh: + return np.array(Image.open(fh)) diff --git a/src/squirrel_datasets_core/preprocessing/__init__.py b/src/squirrel_datasets_core/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/squirrel_datasets_core/preprocessing/save_shards.py b/src/squirrel_datasets_core/preprocessing/save_shards.py new file mode 100644 index 0000000..2e48b98 --- /dev/null +++ b/src/squirrel_datasets_core/preprocessing/save_shards.py @@ -0,0 +1,74 @@ +import os +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, Iterable, List, Optional + +from pyspark.sql import SparkSession +from squirrel.catalog import Catalog, Source +from squirrel.driver.messagepack.msgstore import MessagepackStore +from squirrel.driver.spec import Sample, Shard + + +@dataclass +class SaveShardsConfig: + identifier: str # identifier used in the catalog for the dataset + version: int # dataset version to use + num_shards: int + output_data_url: str # path where the shard will be written + output_catalog_url: str # path where the catalog will be written + num_samples: Optional[int] = None # number of samples to take from the dataset, if None, all samples will be taken + + +def save_iterable_as_shard(pt: Iterable, output_url: str) -> None: + """Helper to save a shard into a messagepack store using squirrel""" + pt = list(pt) + if len(pt) == 0: + return + store = MessagepackStore(output_url) + store.set(Shard(pt)) + + +def save_shards( + cfg: SaveShardsConfig, + session: SparkSession, + iter_kwargs: Dict, + hooks: Optional[List[Callable[[Dict], Dict]]] = None, +) -> None: + """Process a source and 
save it as messagepack. + + Args: + cfg (SaveShardsConfig): Config to use. + session (SparkSession): Spark session. + iter_kwargs (Dict): Keyword arguments passed to :py:meth:`Driver.get_iter`. + hooks (Optional[List[Callable[[Dict], Dict]]], optional): Methods to map to the samples before creating shards. + """ + if hooks is None: + hooks = [] + + # get raw data + cat = Catalog.from_plugins() + d = cat[cfg.identifier][cfg.version] + src_it = d.load.get_iter(**iter_kwargs) + if cfg.num_samples is not None: + src_it = src_it.take(cfg.num_samples) + + # resolve relative paths + out_url = os.path.abspath(cfg.output_data_url) + + # run preprocessing with spark + pipe = session.sparkContext.parallelize(src_it) + for h in hooks: + pipe = pipe.map(h) + pipe = pipe.map(Sample) + pipe = pipe.coalesce(cfg.num_shards) + _ = pipe.foreachPartition(partial(save_iterable_as_shard, output_url=out_url)) + + ncat = Catalog() + # set version via DummyCatalogSource + cat_source = ncat[cfg.identifier] + cat_source[cfg.version + 1] = Source( + driver_name="messagepack", + metadata=d.metadata, + driver_kwargs={"url": out_url}, + ) + ncat.to_file(cfg.output_catalog_url) diff --git a/src/squirrel_datasets_core/spark/__init__.py b/src/squirrel_datasets_core/spark/__init__.py new file mode 100644 index 0000000..a071e7d --- /dev/null +++ b/src/squirrel_datasets_core/spark/__init__.py @@ -0,0 +1,3 @@ +from squirrel_datasets_core.spark.setup_spark import get_spark, stop_spark + +__all__ = ["get_spark", "stop_spark"] diff --git a/src/squirrel_datasets_core/spark/setup_spark.py b/src/squirrel_datasets_core/spark/setup_spark.py new file mode 100644 index 0000000..d9840ac --- /dev/null +++ b/src/squirrel_datasets_core/spark/setup_spark.py @@ -0,0 +1,101 @@ +import os +import typing as t + +from pyspark import SparkConf +from pyspark.sql import SparkSession + + +def get_spark( + app_name: str, + dataset_conf: t.Dict[str, str] = None, + add_gcs_connector: bool = True, + add_hadoop: bool = False, + mem_size: str = "4g", +) -> SparkSession: + """ + Get a spark session for your operation. + + Args: + app_name: The name of your spark session. + dataset_conf: A dictionary containing configs for your runtime spark configs. Must be configs that can be + altered after the creation of the spark session, otherwise will be complained by spark. A detailed list of + viable options can be found here: https://spark.apache.org/docs/latest/configuration.html. + add_gcs_connector: If true, will add shaded gcs-connector for hadoop3 and other related jar files to spark. + It must be True in order for cloudbuild to work. + add_hadoop: If true, will add hadoop native libs to spark lib path, such that spark can use native hadoop lib + tools to compress and decompress zst files. + mem_size: Memory size for the driver node. + + Returns: + SparkSession + """ + conf = SparkConf() + conf.set("spark.executor.memory", mem_size) + conf.set("spark.jars.repositories", "https://maven.google.com/") # add google maven repo as extra jars lookup link. + # add already installed jar files into path. 
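The preprocessing entry points in this patch wire a driver, ``save_shards`` and a Spark session together. A standalone sketch of that flow, with placeholder output paths and assuming a ``kaggle_casting_quality`` source (version 1) is registered in the plugin catalog, could look roughly like this:

.. code-block:: python

    from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQuality
    from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards
    from squirrel_datasets_core.spark import get_spark, stop_spark

    # Placeholder output locations and shard count; identifier/version must match the catalog entry.
    cfg = SaveShardsConfig(
        identifier="kaggle_casting_quality",
        version=1,
        num_shards=10,
        output_data_url="/data/kaggle_casting_shards",
        output_catalog_url="/data/kaggle_casting_shards/shards.yaml",
        num_samples=1_000,  # optional: subsample for a quick smoke test
    )

    # get_spark expects SPARK_HOME to be set in the environment (see _get_home).
    session = get_spark("kaggle-casting-sharding")
    try:
        # Keep the source iterator lazy (parse=False); the load_sample hook decodes images on the executors.
        save_shards(
            cfg=cfg,
            session=session,
            iter_kwargs=dict(split="train", parse=False),
            hooks=[RawKaggleCastingQuality.load_sample],
        )
    finally:
        stop_spark(session)

The resulting messagepack shards are then registered as version 2 in the catalog file written to ``output_catalog_url``.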
+ _SPARK_HOME = _get_home("SPARK_HOME") + _EXTRA_LIB_PATH = f"{_SPARK_HOME}/jars" + conf.set("spark.driver.extraLibraryPath", _EXTRA_LIB_PATH) + conf.set("spark.executor.extraLibraryPath", _EXTRA_LIB_PATH) + + if add_gcs_connector: + conf = _add_gcs_connector(conf) + + if add_hadoop: + conf = _add_native_hadoop(conf, extra_lib_path=_EXTRA_LIB_PATH) + + assert isinstance(app_name, str), ValueError("`app_name` accept string only.") + spark = SparkSession.builder.appName(app_name).config(conf=conf).getOrCreate() + + if dataset_conf: + for key, value in dataset_conf.items(): + spark.conf.set(key, str(value)) + + return spark + + +def _get_home(env_var: str) -> str: + """Get home directories for e.g. spark, hadoop or other java lib homes.""" + _HOME = os.environ.get(env_var) + assert _HOME is not None, EnvironmentError(f"Could not find '{env_var}' in current OS env.") + return _HOME + + +def _add_gcs_connector(conf: SparkConf) -> SparkConf: + """Add gcs-connector and related dependencies (jar files) to spark configuration.""" + _SPARK_HOME = _get_home("SPARK_HOME") + GCS_CONNECTOR_JARS = { + "spark.jars.packages": "com.google.cloud.bigdataoss:gcs-connector:hadoop3-2.2.2", + "spark.hadoop.fs.AbstractFileSystem.gs.impl": "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS", + "spark.hadoop.fs.gs.impl": "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem", + "spark.hadoop.google.cloud.auth.service.account.enable": "true", + "spark.jars": f"{_SPARK_HOME}/jars/gcs-connector-hadoop3-latest.jar", + } + for key, val in GCS_CONNECTOR_JARS.items(): + conf.set(key, val) + return conf + + +def _add_native_hadoop(conf: SparkConf, extra_lib_path: str = None) -> SparkConf: + """Link extra hadoop native libs to spark configuration, such that spark can use these libraries to do some specific + IO works. Must be added when one does I/O with zstandard compressed files in spark. (The actual hadoop native libs + are pre-installed in docker image `infrastructure/docker/Dockerfile.spark`. 
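+
+    Example (illustrative)::
+
+        spark = get_spark("zstd-io", add_hadoop=True)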
+ """ + _HADOOP_HOME = _get_home("HADOOP_HOME") + if extra_lib_path is not None: + extra_lib_path = f"{extra_lib_path}:{_HADOOP_HOME}/lib/native" + else: + extra_lib_path = f"{_HADOOP_HOME}/lib/native" + conf.set("spark.driver.extraLibraryPath", extra_lib_path) + conf.set("spark.executor.extraLibraryPath", extra_lib_path) + return conf + + +def stop_spark(spark: SparkSession) -> None: + """Stop a spark session in a graceful way.""" + spark.sparkContext._gateway.shutdown_callback_server() + spark.stop() + + +if __name__ == "__main__": + spark = get_spark("test") From d4588f9a19a0387fa59af8f9f08403ef2d971463 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 3 Feb 2022 19:35:21 +0100 Subject: [PATCH 04/19] adding dataset cards --- docs/source/dataset_links/imagenet.rst | 3 + .../dataset_links/kaggle_casting_quality.rst | 3 + docs/source/index.rst | 8 ++ .../datasets/imagenet/README.rst | 28 ++++ .../kaggle_casting_quality/README.rst | 129 ++++++++++++++++++ 5 files changed, 171 insertions(+) create mode 100644 docs/source/dataset_links/imagenet.rst create mode 100644 docs/source/dataset_links/kaggle_casting_quality.rst create mode 100644 src/squirrel_datasets_core/datasets/imagenet/README.rst create mode 100644 src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst diff --git a/docs/source/dataset_links/imagenet.rst b/docs/source/dataset_links/imagenet.rst new file mode 100644 index 0000000..9c89696 --- /dev/null +++ b/docs/source/dataset_links/imagenet.rst @@ -0,0 +1,3 @@ +Dataset Card - Imagenet +===================================== +.. include:: ../../../src/squirrel_datasets_core/datasets/imagenet/README.rst \ No newline at end of file diff --git a/docs/source/dataset_links/kaggle_casting_quality.rst b/docs/source/dataset_links/kaggle_casting_quality.rst new file mode 100644 index 0000000..8540ae1 --- /dev/null +++ b/docs/source/dataset_links/kaggle_casting_quality.rst @@ -0,0 +1,3 @@ +Dataset Card - Kaggle Casting Quality +===================================== +.. include:: ../../../src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index e270bb1..f601daa 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -2,3 +2,11 @@ Squirrel Datasets Documentation =============================== +Datasets +-------- + +.. toctree:: + :maxdepth: 2 + + dataset_links/kaggle_casting_quality + dataset_links/imagenet \ No newline at end of file diff --git a/src/squirrel_datasets_core/datasets/imagenet/README.rst b/src/squirrel_datasets_core/datasets/imagenet/README.rst new file mode 100644 index 0000000..d47dac8 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/imagenet/README.rst @@ -0,0 +1,28 @@ +.. list-table:: Dataset properties + :header-rows: 1 + + * - pretty_name + - annotations_creators + - language_creators + - languages + - licenses + - multilinguality + - size_categories + - source_datasets + - task_categories + - task_ids + - paperswithcode_id + * - Imagenet Dataset + - + - + - [] + - + - [] + - + - + - + - image-classification + - + +Dataset Description +################### diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst new file mode 100644 index 0000000..9da91f5 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst @@ -0,0 +1,129 @@ +.. 
list-table:: Dataset properties + :header-rows: 1 + + * - pretty_name + - annotations_creators + - language_creators + - languages + - licenses + - multilinguality + - size_categories + - source_datasets + - task_categories + - task_ids + - paperswithcode_id + * - Kaggle Casting Quality + - + - + - [] + - CC BY-NC-ND 4.0 + - [] + - 1K`_ +* Licenses: `Attribution-NonCommercial-NoDerivatives 4.0 International `_ + +Dataset Summary +*************** + +Dataset for the automatic identification of casting defects. + +Download and prepare data +************************* + +Download the data directly from `Kaggle `_ and extract it. Replace {PATH_TO_DATA} below with the location of the data. Use the following code to load it: + +.. code-block:: python + + from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQuality + iter_train = RawKaggleCastingQuality("{PATH_TO_DATA}/casting_data/casting_data").get_iter("train") + iter_test = RawKaggleCastingQuality("{PATH_TO_DATA}/casting_data/casting_data").get_iter("test") + + +Dataset Structure +################### + +Data Instances +************** + +A sample from the training set is provided below: + +.. code-block:: + + {'url': '/home/winfried/projects/squirrel-datasets-core/src/archive/casting_data/casting_data/test/ok_front/cast_ok_0_9996.jpeg', 'label': 1, 'image': array(...)} + +Data Fields +*********** + +- `img`: A numpy array containing the 300x300 RGB image. +- `label`: `1` for ok front and `0` for defect front. + +Data Splits +*********** + ++--------------+-----+----+ +| name |train|test| ++--------------+-----+----+ +|kaggle-casting|6633 |715 | ++--------------+-----+----+ + +Dataset Creation +################ + +Curation Rationale +****************** + +[More Information Needed] + +Source Data +*********** + +Initial Data Collection and Normalization + +[More Information Needed] + +Annotations +*********** + +Annotation process + +[More Information Needed] + +Who are the annotators? + +[More Information Needed] + +Personal and Sensitive Information +********************************** + +[More Information Needed] + +Considerations for Using the Data +#################################### + +Social Impact of Dataset +********************************** + +[More Information Needed] + +Discussion of Biases +********************************** + +[More Information Needed] + +Other Known Limitations +********************************** + +[More Information Needed] + +Citation Information +********************************** + +[More Information Needed] From 7a12841ce8cd29f1e7e9463fba322021608c0077 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 3 Feb 2022 19:55:19 +0100 Subject: [PATCH 05/19] edit data card --- .../datasets/kaggle_casting_quality/README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst index 9da91f5..fcec85c 100644 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst @@ -59,8 +59,8 @@ A sample from the training set is provided below: {'url': '/home/winfried/projects/squirrel-datasets-core/src/archive/casting_data/casting_data/test/ok_front/cast_ok_0_9996.jpeg', 'label': 1, 'image': array(...)} -Data Fields -*********** +Dataset Schema +************** - `img`: A numpy array containing the 300x300 RGB image. 
- `label`: `1` for ok front and `0` for defect front. From d6126e4693679b311fd4d6244a2ba7391a37824a Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 7 Feb 2022 16:09:30 +0100 Subject: [PATCH 06/19] remove docu versioning --- docs/source/conf.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 89d4f8e..1e1928c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,10 +6,6 @@ project = 'Squirrel Datasets' copyright = f'{datetime.datetime.now().year}, Merantix Labs GmbH' author = 'Merantix Labs GmbH' - -release = '0.1' -version = '0.1.0' - # -- General configuration extensions = [ From bcaa76171ded797ff0cff895d56f4d17575029bc Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 7 Feb 2022 16:09:53 +0100 Subject: [PATCH 07/19] dataset card - clean up --- .../datasets/imagenet/README.rst | 37 ++--- .../kaggle_casting_quality/README.rst | 134 +++++++++--------- 2 files changed, 89 insertions(+), 82 deletions(-) diff --git a/src/squirrel_datasets_core/datasets/imagenet/README.rst b/src/squirrel_datasets_core/datasets/imagenet/README.rst index d47dac8..2f7cfcb 100644 --- a/src/squirrel_datasets_core/datasets/imagenet/README.rst +++ b/src/squirrel_datasets_core/datasets/imagenet/README.rst @@ -1,28 +1,29 @@ -.. list-table:: Dataset properties +.. list-table:: :header-rows: 1 * - pretty_name - - annotations_creators - - language_creators - - languages - - licenses - - multilinguality - - size_categories - - source_datasets - - task_categories - - task_ids - - paperswithcode_id - * - Imagenet Dataset - - - - - - [] - - - - [] + - Imagenet Dataset + * - annotations_creators + - + * - language_creators + - + * - languages - + * - licenses - + * - multilinguality + - + * - size_categories - + * - source_datasets + - + * - task_categories - image-classification - - + * - task_ids + - + * - paperswithcode_id + - + Dataset Description ################### diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst index fcec85c..71279dd 100644 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst @@ -1,28 +1,28 @@ -.. list-table:: Dataset properties +.. list-table:: :header-rows: 1 * - pretty_name - - annotations_creators - - language_creators - - languages - - licenses - - multilinguality - - size_categories - - source_datasets - - task_categories - - task_ids - - paperswithcode_id - * - Kaggle Casting Quality - - - - + - Kaggle Casting Quality + * - annotations_creators + - + * - language_creators + - + * - languages - [] + * - licenses - CC BY-NC-ND 4.0 - - [] + * - multilinguality + - + * - size_categories - 1K Date: Mon, 7 Feb 2022 16:26:00 +0100 Subject: [PATCH 08/19] add relevant public datasets. 
Standardize naming --- .../datasets/camvid/__init__.py | 3 + .../datasets/camvid/driver.py | 93 +++++++++++++++++++ .../datasets/camvid/preprocessing.py | 46 +++++++++ .../datasets/ds_bowl_2018/__init__.py | 3 + .../datasets/ds_bowl_2018/driver.py | 76 +++++++++++++++ .../datasets/ds_bowl_2018/preprocessing.py | 46 +++++++++ .../datasets/hugging_face/__init__.py | 3 + .../datasets/hugging_face/driver.py | 30 ++++++ .../datasets/imagenet/__init__.py | 4 +- .../datasets/imagenet/driver.py | 18 ++-- .../kaggle_casting_quality/README.rst | 6 +- .../kaggle_casting_quality/__init__.py | 3 + .../datasets/kaggle_casting_quality/driver.py | 6 +- .../kaggle_casting_quality/preprocessing.py | 4 +- .../monthly_german_tweets/__init__.py | 3 + .../datasets/monthly_german_tweets/driver.py | 81 ++++++++++++++++ 16 files changed, 406 insertions(+), 19 deletions(-) create mode 100644 src/squirrel_datasets_core/datasets/camvid/__init__.py create mode 100644 src/squirrel_datasets_core/datasets/camvid/driver.py create mode 100644 src/squirrel_datasets_core/datasets/camvid/preprocessing.py create mode 100644 src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py create mode 100644 src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py create mode 100644 src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py create mode 100644 src/squirrel_datasets_core/datasets/hugging_face/__init__.py create mode 100644 src/squirrel_datasets_core/datasets/hugging_face/driver.py create mode 100644 src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py create mode 100644 src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py diff --git a/src/squirrel_datasets_core/datasets/camvid/__init__.py b/src/squirrel_datasets_core/datasets/camvid/__init__.py new file mode 100644 index 0000000..55a3f15 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/camvid/__init__.py @@ -0,0 +1,3 @@ +from squirrel_datasets_core.datasets.camvid.driver import CamvidDriver + +__all__ = ["CamvidDriver"] diff --git a/src/squirrel_datasets_core/datasets/camvid/driver.py b/src/squirrel_datasets_core/datasets/camvid/driver.py new file mode 100644 index 0000000..fa8e6c1 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/camvid/driver.py @@ -0,0 +1,93 @@ +"""Driver that can parse the Cambridge-driving Labeled Video Database (CamVid) dataset.""" +import os +from functools import partial +from typing import Callable, Dict, List, Optional + +from squirrel.driver.driver import RecordIteratorDriver +from squirrel.iterstream import Composable, FilePathGenerator, IterableSource + +from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver + +_classes = [ + "sky", + "building", + "pole", + "road", + "pavement", + "tree", + "sign_symbol", + "fence", + "car", + "pedestrian", + "bicyclist", + "unlabelled", +] +CAMVID_NAME_TO_LABEL_ID = dict(zip(_classes, range(12))) +CAMVID_LABEL_ID_TO_NAME = {idx: cls for cls, idx in CAMVID_NAME_TO_LABEL_ID.items()} + + +class CamvidDriver(RecordIteratorDriver, TwoDImageFileDriver): + """Driver that can iterate over the samples of the `Cambridge-driving Labeled Video Database (CamVid) + `_ dataset. + + The driver expects the image and label formats, directory structure and the data split used by Alex Kendall's + `SegNet tutorial https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid>`_. 
+ + The dataset contains the following classes: sky (0), building (1), pole (2), road (3), pavement (4), tree (5), + sign/symbol (6), fence (7), car (8), pedestrian (9), bicyclist (10), unlabelled (11). + """ + + name = "camvid" + + @staticmethod + def load_sample(sample: Dict, parse_image: bool = True, parse_label: bool = True) -> Dict: + """Parse and load image and/or label into the sample dictionary. + Image and label are stored under the keys "image" and "label", respectively. + """ + if parse_image: + sample["image"] = CamvidDriver.load_image(sample["image_url"]) + if parse_label: + sample["label"] = CamvidDriver.load_image(sample["label_url"]) + return sample + + def get_iter( + self, + split: str, + hooks: Optional[List[Callable]] = None, + parse_image: bool = True, + parse_label: bool = True, + shuffle_size: int = 800, + shuffle_initial: int = 800, + ) -> Composable: + """Create iterstream for the given split. + + Args: + split (str): Split name. Must be one of ("train", "val", "test"). + hooks (List[Iterable], optional): Hooks to apply. Hooks are applied before parsing the samples. Defaults to + None. + parse_image (bool, optional): Whether to load the image into sample dictionary. Image will be stored under + the key "image". Defaults to True. + parse_label (bool, optional): Whether to load the label into sample dictionary. Label will be stored under + the key "label". Defaults to True. + shuffle_size (int, optional): Buffer size used for shuffling. Defaults to 800. + shuffle_initial (int, optional): Initial buffer size before starting to iterate. Defaults to 800. + + Returns: + Composable: Composable containing the samples. + """ + assert split in {"train", "val", "test"} + if hooks is None: + hooks = [] + + imgs_dir = os.path.join(self.url, split) + labels_dir = os.path.join(self.url, f"{split}annot") + gen = FilePathGenerator(imgs_dir).map( + lambda x: dict(image_url=x, label_url=os.path.join(labels_dir, os.path.basename(x)), split=split) + ) + it = IterableSource(gen).shuffle(size=shuffle_size, initial=shuffle_initial) + + for h in hooks: + it = it.to(h) + + load = partial(self.load_sample, parse_image=parse_image, parse_label=parse_label) + return it.map(load) diff --git a/src/squirrel_datasets_core/datasets/camvid/preprocessing.py b/src/squirrel_datasets_core/datasets/camvid/preprocessing.py new file mode 100644 index 0000000..2f39945 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/camvid/preprocessing.py @@ -0,0 +1,46 @@ +from dataclasses import asdict, dataclass + +import hydra +from hydra.core.config_store import ConfigStore + +from squirrel_datasets_core.datasets.camvid.driver import CamvidDriver +from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards +from squirrel_datasets_core.spark.setup_spark import get_spark + + +@dataclass +class IterKwargs: + split: str = "train" + parse_image: bool = False + parse_label: bool = False + shuffle_size: int = 800 + shuffle_initial: int = 800 + + +@dataclass +class CamvidShardConfig(SaveShardsConfig): + identifier: str = "camvid" + version: int = 1 + num_shards: int = 5 + output_data_url: str = "/data/camvid_shards" + output_catalog_url: str = "/data/camvid_shards/shards.yaml" + iter_kwargs: IterKwargs = IterKwargs() + + +cs = ConfigStore.instance() +cs.store(name="camvid_config", node=CamvidShardConfig) + + +@hydra.main(config_path=None, config_name="camvid_config") +def main(cfg: CamvidShardConfig) -> None: + """When running via cli create spark session yourself""" + save_shards( + 
cfg=cfg, + session=get_spark("camvid-preprocessing"), + iter_kwargs=asdict(cfg.iter_kwargs), + hooks=[CamvidDriver.load_sample], + ) + + +if __name__ == "__main__": + main() diff --git a/src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py b/src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py new file mode 100644 index 0000000..e7d4333 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py @@ -0,0 +1,3 @@ +from squirrel_datasets_core.datasets.ds_bowl_2018.driver import DataScienceBowl2018Driver + +__all__ = ["DataScienceBowl2018Driver"] diff --git a/src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py b/src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py new file mode 100644 index 0000000..5e7296b --- /dev/null +++ b/src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py @@ -0,0 +1,76 @@ +"""Driver that can parse the Data Science Bowl 2018 dataset.""" +import os +from functools import partial +from typing import Callable, Dict, List, Optional + +from squirrel.driver.driver import RecordIteratorDriver +from squirrel.iterstream import Composable, FilePathGenerator, IterableSource + +from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver + + +class DataScienceBowl2018Driver(RecordIteratorDriver, TwoDImageFileDriver): + """Driver that can iterate over the samples of the `2018 Data Science Bowl + `_ dataset. + + The driver expects the image and label formats, directory structure and the data split used in the competition. + Drivers allows three dataset splits: stage1_train, stage1_test, stage2_test. + """ + + name = "ds_bowl_18" + + @staticmethod + def load_sample(sample: Dict, parse_image: bool = True, parse_mask: bool = True) -> Dict: + """Parse and load image and/or groundtruth mask into the sample dictionary. + Image and masks are stored under the keys "image" and "masks", respectively. + """ + # labels are list of arrays (masks) and list of bbox coordinates + sample_url = sample["sample_url"] + if parse_image: + stem = os.path.splitext(os.path.basename(sample_url))[0] + url = os.path.join(sample_url, "images", f"{stem}.png") + sample["image"] = DataScienceBowl2018Driver.load_image(url) + + if sample["split"] == "stage1_train" and parse_mask: + gen = FilePathGenerator(os.path.join(sample_url, "masks")) + sample["masks"] = [DataScienceBowl2018Driver.load_image(url).astype("bool") for url in gen] + + return sample + + def get_iter( + self, + split: str, + hooks: Optional[List[Callable]] = None, + parse_image: bool = True, + parse_mask: bool = True, + shuffle_size: int = 800, + shuffle_initial: int = 800, + ) -> Composable: + """Create iterstream for the given split. + + Args: + split (str): Split name. Must be one of ("stage1_train", "stage1_test", "stage2_test"). + hooks (List[Iterable], optional): Hooks to apply. Hooks are applied before parsing the samples. Defaults to + None. + parse_image (bool, optional): Whether to load the image into sample dictionary. Image will be stored under + the key "image". Defaults to True. + parse_mask (bool, optional): Whether to load the instance masks into sample dictionary. Masks will be + stored under the key "masks". Only in effect for split="stage1_train". Defaults to True. + shuffle_size (int, optional): Buffer size used for shuffling. Defaults to 800. + shuffle_initial (int, optional): Initial buffer size before starting to iterate. Defaults to 800. + + Returns: + Composable: Composable containing the samples. 
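A short usage sketch for the CamVid driver defined above; the dataset root is a placeholder, the layout is assumed to follow the SegNet-tutorial structure referenced in the class docstring, and the driver is constructed directly with ``url=...`` just as the catalog-based preprocessing config would instantiate it.

.. code-block:: python

    import numpy as np

    from squirrel_datasets_core.datasets.camvid import CamvidDriver
    from squirrel_datasets_core.datasets.camvid.driver import CAMVID_LABEL_ID_TO_NAME

    # Placeholder root containing train/, trainannot/, val/, valannot/, test/, testannot/.
    driver = CamvidDriver(url="/data/camvid")

    for sample in driver.get_iter(split="val", shuffle_size=1, shuffle_initial=1).take(2):
        # "image" holds the RGB frame, "label" a per-pixel class-id mask.
        present = [CAMVID_LABEL_ID_TO_NAME[int(i)] for i in np.unique(sample["label"])]
        print(sample["image"].shape, sample["label"].shape, present)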
+ """ + assert split in {"stage1_train", "stage1_test", "stage2_test"} + if hooks is None: + hooks = [] + + gen = FilePathGenerator(os.path.join(self.url, split)).map(lambda x: dict(sample_url=x, split=split)) + it = IterableSource(gen).shuffle(size=shuffle_size, initial=shuffle_initial) + + for h in hooks: + it = it.to(h) + + load = partial(self.load_sample, parse_image=parse_image, parse_mask=parse_mask) + return it.map(load) diff --git a/src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py b/src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py new file mode 100644 index 0000000..73af3f1 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py @@ -0,0 +1,46 @@ +from dataclasses import asdict, dataclass + +import hydra +from hydra.core.config_store import ConfigStore + +from squirrel_datasets_core.datasets.ds_bowl_2018.driver import DataScienceBowl2018Driver +from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards +from squirrel_datasets_core.spark.setup_spark import get_spark + + +@dataclass +class IterKwargs: + split: str = "stage1_train" + parse_image: bool = False + parse_mask: bool = False + shuffle_size: int = 50 + shuffle_initial: int = 50 + + +@dataclass +class DsBowlShardConfig(SaveShardsConfig): + identifier: str = "ds_bowl_18" + version: int = 1 + num_shards: int = 100 + output_data_url: str = "/data/ds_bowl_shards" + output_catalog_url: str = "/data/ds_bowl_shards/shards.yaml" + iter_kwargs: IterKwargs = IterKwargs() + + +cs = ConfigStore.instance() +cs.store(name="ds_bowl_18_config", node=DsBowlShardConfig) + + +@hydra.main(config_path=None, config_name="ds_bowl_18_config") +def main(cfg: DsBowlShardConfig) -> None: + """When running via cli create spark session yourself""" + save_shards( + cfg=cfg, + session=get_spark("ds-bowl-18-preprocessing"), + iter_kwargs=asdict(cfg.iter_kwargs), + hooks=[DataScienceBowl2018Driver.load_sample], + ) + + +if __name__ == "__main__": + main() diff --git a/src/squirrel_datasets_core/datasets/hugging_face/__init__.py b/src/squirrel_datasets_core/datasets/hugging_face/__init__.py new file mode 100644 index 0000000..8be3ed0 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/hugging_face/__init__.py @@ -0,0 +1,3 @@ +from squirrel_datasets_core.datasets.hugging_face.driver import HuggingfaceDriver + +__all__ = ["HuggingfaceDriver"] diff --git a/src/squirrel_datasets_core/datasets/hugging_face/driver.py b/src/squirrel_datasets_core/datasets/hugging_face/driver.py new file mode 100644 index 0000000..8c9d271 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/hugging_face/driver.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Any, Optional, TYPE_CHECKING + +from datasets import load_dataset +from squirrel.driver import RecordIteratorDriver +from squirrel.iterstream import IterableSource + +if TYPE_CHECKING: + from squirrel.catalog.catalog import Catalog + from squirrel.iterstream import Composable + + +class HuggingfaceDriver(RecordIteratorDriver): + name = "huggingface" + + def __init__( + self, name: str, subset: str = None, streaming: bool = True, catalog: Optional[Catalog] = None, **kwargs + ) -> None: + """Load huggingface dataset""" + super().__init__(catalog=catalog) + self._dataset = load_dataset(path=name, name=subset, streaming=streaming, **kwargs) + + def get_iter(self, split: str, **kwargs) -> Composable: + """Create iterstream based on dataset split.""" + return IterableSource(self._dataset[split]) + + def 
get_store(self, **kwargs) -> Any: + """Get the huggingface dataset object""" + return self._dataset diff --git a/src/squirrel_datasets_core/datasets/imagenet/__init__.py b/src/squirrel_datasets_core/datasets/imagenet/__init__.py index 5c84f16..82b3dd1 100644 --- a/src/squirrel_datasets_core/datasets/imagenet/__init__.py +++ b/src/squirrel_datasets_core/datasets/imagenet/__init__.py @@ -1,3 +1,3 @@ -from squirrel_datasets_core.datasets.imagenet.driver import RawImageNet +from squirrel_datasets_core.datasets.imagenet.driver import RawImageNetDriver -__all__ = ["RawImageNet"] +__all__ = ["RawImageNetDriver"] diff --git a/src/squirrel_datasets_core/datasets/imagenet/driver.py b/src/squirrel_datasets_core/datasets/imagenet/driver.py index 705aba6..b7188e8 100644 --- a/src/squirrel_datasets_core/datasets/imagenet/driver.py +++ b/src/squirrel_datasets_core/datasets/imagenet/driver.py @@ -9,7 +9,7 @@ from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver -class RawImageNet(RecordIteratorDriver, TwoDImageFileDriver): +class RawImageNetDriver(RecordIteratorDriver, TwoDImageFileDriver): name = "raw_imagenet" def __init__( @@ -22,7 +22,7 @@ def __init__( val_blacklist_url: Optional[str] = None, **kwargs, ) -> None: - """Init RawImageNet driver. + """Init RawImageNetDriver. Args: url (str): Path to the root directory of the dataset. @@ -163,7 +163,7 @@ def get_idx_to_id_and_name_mapper(self) -> Dict[int, Tuple[str, str]]: @staticmethod def load_sample(sample: Dict[str, Any]) -> Dict[str, Any]: """Load sample from dict containing url to sample.""" - sample["image"] = RawImageNet.load_image(sample["url"]) + sample["image"] = RawImageNetDriver.load_image(sample["url"]) return sample def get_iter( @@ -204,29 +204,29 @@ def get_iter( cls_map = self.get_id_to_idx_and_name_mapper() if split == "train": - it = it.to(RawImageNet.parse_gt_train, cls_map=cls_map) + it = it.to(RawImageNetDriver.parse_gt_train, cls_map=cls_map) if self.loc_train_mapping_path is not None: loc_map = self.get_loc_mapper(self.loc_train_mapping_path) - it = it.to(RawImageNet.parse_bbox, cls_map=cls_map, loc_map=loc_map) + it = it.to(RawImageNetDriver.parse_bbox, cls_map=cls_map, loc_map=loc_map) elif split == "val": if self.val_blacklist_path is not None: val_blacklist = self.get_val_blacklist_indices() - it = it.to(RawImageNet.filter_val_samples_by_idx, idx_blacklist=val_blacklist) + it = it.to(RawImageNetDriver.filter_val_samples_by_idx, idx_blacklist=val_blacklist) if self.cls_val_mapping_path is not None: val_clsidx_map = self.get_val_clsidx_list() clsidx_map = self.get_idx_to_id_and_name_mapper() - it = it.to(RawImageNet.parse_gt_val, clsidx_map=clsidx_map, clsidx_val_list=val_clsidx_map) + it = it.to(RawImageNetDriver.parse_gt_val, clsidx_map=clsidx_map, clsidx_val_list=val_clsidx_map) if self.loc_val_mapping_path is not None: loc_map = self.get_loc_mapper(self.loc_val_mapping_path) - it = it.to(RawImageNet.parse_bbox, cls_map=cls_map, loc_map=loc_map) + it = it.to(RawImageNetDriver.parse_bbox, cls_map=cls_map, loc_map=loc_map) for h in hooks: it = it.to(h) if not parse: return it - return it.map(RawImageNet.load_sample) + return it.map(RawImageNetDriver.load_sample) diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst index 71279dd..0aa7be9 100644 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst @@ -43,10 +43,10 @@ 
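The HuggingfaceDriver added above wraps datasets.load_dataset, so any public Hugging Face dataset can be pulled into an iterstream. A small, hypothetical sketch ("ag_news" is only an example identifier; any name accepted by load_dataset works):

from itertools import islice

from squirrel_datasets_core.datasets.hugging_face import HuggingfaceDriver

# Stream the dataset instead of downloading it fully; records arrive as plain dicts.
driver = HuggingfaceDriver(name="ag_news", streaming=True)
for record in islice(driver.get_iter("train"), 3):
    print(record)
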
Download the data directly from `Kaggle None: @@ -18,7 +18,7 @@ def __init__(self, url: str, **kwargs) -> None: @staticmethod def load_sample(sample: Dict) -> Dict: """Load sample from dict containing url to sample.""" - sample["image"] = RawKaggleCastingQuality.load_image(sample["url"]) + sample["image"] = RawKaggleCastingQualityDriver.load_image(sample["url"]) return sample def get_iter(self, split: str, hooks: Optional[List[Callable]] = None, parse: bool = True, **kwargs) -> Composable: @@ -36,4 +36,4 @@ def get_iter(self, split: str, hooks: Optional[List[Callable]] = None, parse: bo it = it.to(h) if not parse: return it - return it.map(RawKaggleCastingQuality.load_sample) + return it.map(RawKaggleCastingQualityDriver.load_sample) diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py index b217c68..1982766 100644 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py @@ -3,7 +3,7 @@ import hydra from hydra.core.config_store import ConfigStore -from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQuality +from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQualityDriver from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards from squirrel_datasets_core.spark.setup_spark import get_spark @@ -30,7 +30,7 @@ def main(cfg: KaggleCastingShardConfig) -> None: cfg=cfg, session=get_spark("kagglecastingquality-preprocessing"), iter_kwargs=dict(split=cfg.split, parse=cfg.parse), - hooks=[RawKaggleCastingQuality.load_sample], + hooks=[RawKaggleCastingQualityDriver.load_sample], ) diff --git a/src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py b/src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py new file mode 100644 index 0000000..6f4c6a8 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py @@ -0,0 +1,3 @@ +from squirrel_datasets_core.datasets.monthly_german_tweets.driver import MonthlyGermanTweetsDriver + +__all__ = ["MonthlyGermanTweetsDriver"] diff --git a/src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py b/src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py new file mode 100644 index 0000000..c26db90 --- /dev/null +++ b/src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py @@ -0,0 +1,81 @@ +""" +This driver can be used to parse the Zenodo Monthly German Tweet dataset obtainable at + https://zenodo.org/record/3633935#.YcMujb1Khqt +""" +import json +from typing import Callable, Generator, List, Optional + +from squirrel.driver import RecordIteratorDriver +from squirrel.fsspec.fs import get_fs_from_url +from squirrel.iterstream import Composable, FilePathGenerator + + +class MonthlyGermanTweetsDriver(RecordIteratorDriver): + name = "raw_monthly_german_tweets" + + def __init__(self, folder: str, **kwargs) -> None: + """ + Init the iterator. + + Args: + folder: Path to the unzipped data dump + """ + self.folder = folder + self.compression = "gzip" + self.parse_error_count = 0 + + def parse_archive(self, url: str) -> List[str]: + """Parse a single archive. Note the custom parsing function due to the + special file layout. 
+ """ + + fs = get_fs_from_url(url) + with fs.open(url, "rb", compression=self.compression) as f: + json_bytes = list(f) + + dec = "".join([b.decode("utf-8").strip() for b in json_bytes])[1:-1] + samples = ["{" + elem for elem in dec.split(",{") if not elem.startswith("{")] + return samples + + def iterate_single_archive(self, url: str) -> Generator: + """Iterate over all samples in a single archive""" + + samples = self.parse_archive(url) + + for s in samples: + try: + sample = json.loads(s) + yield sample + except json.JSONDecodeError: + self.parse_error_count += 1 + + def get_iter( + self, + key_hooks: Optional[List[Callable]] = None, + prefetch_buffer: int = 1_000, + max_workers: Optional[int] = None, + **kwargs, + ) -> Composable: + """ + Returns a composable that iterates over the raw data in a full unzipped shard of the + Monthly German Tweets dataset. + + Args: + key_hooks: List of Callables to modify the iteration over the archives in the shard + prefetch_buffer: buffer size of the samples buffer + max_workers: number of workers to use in the async_map loading. + """ + + it = FilePathGenerator(self.folder) + + if key_hooks: + for hook in key_hooks: + it = it.to(hook) + + _map = ( + it.map(self.iterate_single_archive) + if max_workers == 0 + else it.async_map(self.iterate_single_archive, prefetch_buffer, max_workers) + ) + + return _map.flatten() From a0439ba81f726f4f9e1089df805bb7720eecfbc0 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 21 Feb 2022 12:06:54 +0100 Subject: [PATCH 09/19] update with changes from squirrel-datasets --- setup.py | 0 src/squirrel_datasets_core/__init__.py | 1 + .../datasets/camvid/__init__.py | 3 +- .../datasets/camvid/driver.py | 29 +++++-- .../datasets/ds_bowl_2018/__init__.py | 3 +- .../datasets/ds_bowl_2018/driver.py | 29 +++++-- .../datasets/hugging_face/__init__.py | 3 - .../datasets/hugging_face/driver.py | 30 ------- .../datasets/imagenet/__init__.py | 3 +- .../datasets/imagenet/driver.py | 26 ++++--- .../kaggle_casting_quality/__init__.py | 3 +- .../datasets/kaggle_casting_quality/driver.py | 31 +++++--- .../monthly_german_tweets/__init__.py | 3 +- .../datasets/monthly_german_tweets/driver.py | 58 ++++++-------- src/squirrel_datasets_core/driver/__init__.py | 5 ++ src/squirrel_datasets_core/driver/fsspec.py | 73 ----------------- src/squirrel_datasets_core/driver/hub.py | 47 +++++++++++ .../driver/huggingface.py | 34 ++++++++ .../driver/torchvision.py | 78 +++++++++++++++++++ src/squirrel_datasets_core/io/__init__.py | 3 + src/squirrel_datasets_core/io/io.py | 32 ++++++++ .../preprocessing/save_shards.py | 26 ++++--- src/squirrel_datasets_core/squirrel_plugin.py | 26 +++++++ src/test.ipynb | 42 ++++++++++ 24 files changed, 396 insertions(+), 192 deletions(-) create mode 100644 setup.py create mode 100644 src/squirrel_datasets_core/__init__.py delete mode 100644 src/squirrel_datasets_core/datasets/hugging_face/__init__.py delete mode 100644 src/squirrel_datasets_core/datasets/hugging_face/driver.py delete mode 100644 src/squirrel_datasets_core/driver/fsspec.py create mode 100644 src/squirrel_datasets_core/driver/hub.py create mode 100644 src/squirrel_datasets_core/driver/huggingface.py create mode 100644 src/squirrel_datasets_core/driver/torchvision.py create mode 100644 src/squirrel_datasets_core/io/__init__.py create mode 100644 src/squirrel_datasets_core/io/io.py create mode 100644 src/squirrel_datasets_core/squirrel_plugin.py create mode 100644 src/test.ipynb diff --git a/setup.py b/setup.py new file mode 100644 index 
0000000..e69de29 diff --git a/src/squirrel_datasets_core/__init__.py b/src/squirrel_datasets_core/__init__.py new file mode 100644 index 0000000..b3c06d4 --- /dev/null +++ b/src/squirrel_datasets_core/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" \ No newline at end of file diff --git a/src/squirrel_datasets_core/datasets/camvid/__init__.py b/src/squirrel_datasets_core/datasets/camvid/__init__.py index 55a3f15..032e6b0 100644 --- a/src/squirrel_datasets_core/datasets/camvid/__init__.py +++ b/src/squirrel_datasets_core/datasets/camvid/__init__.py @@ -1,3 +1,4 @@ from squirrel_datasets_core.datasets.camvid.driver import CamvidDriver -__all__ = ["CamvidDriver"] +__all__ = ["CamvidDriver", "DRIVERS"] +DRIVERS = [CamvidDriver] diff --git a/src/squirrel_datasets_core/datasets/camvid/driver.py b/src/squirrel_datasets_core/datasets/camvid/driver.py index fa8e6c1..8c34bab 100644 --- a/src/squirrel_datasets_core/datasets/camvid/driver.py +++ b/src/squirrel_datasets_core/datasets/camvid/driver.py @@ -1,12 +1,17 @@ """Driver that can parse the Cambridge-driving Labeled Video Database (CamVid) dataset.""" +from __future__ import annotations + import os from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, TYPE_CHECKING + +from squirrel.driver import IterDriver +from squirrel.iterstream import FilePathGenerator, IterableSource -from squirrel.driver.driver import RecordIteratorDriver -from squirrel.iterstream import Composable, FilePathGenerator, IterableSource +from squirrel_datasets_core.io.io import load_image -from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver +if TYPE_CHECKING: + from squirrel.iterstream import Composable _classes = [ "sky", @@ -26,7 +31,7 @@ CAMVID_LABEL_ID_TO_NAME = {idx: cls for cls, idx in CAMVID_NAME_TO_LABEL_ID.items()} -class CamvidDriver(RecordIteratorDriver, TwoDImageFileDriver): +class CamvidDriver(IterDriver): """Driver that can iterate over the samples of the `Cambridge-driving Labeled Video Database (CamVid) `_ dataset. @@ -39,15 +44,25 @@ class CamvidDriver(RecordIteratorDriver, TwoDImageFileDriver): name = "camvid" + def __init__(self, url: str, **kwargs) -> None: + """Initializes the CamvidDriver. + + Args: + url (str): Path to the directory containing the dataset. + **kwargs: Other keyword arguments passes to super class initializer. + """ + super().__init__(**kwargs) + self.url = url + @staticmethod def load_sample(sample: Dict, parse_image: bool = True, parse_label: bool = True) -> Dict: """Parse and load image and/or label into the sample dictionary. Image and label are stored under the keys "image" and "label", respectively. 
""" if parse_image: - sample["image"] = CamvidDriver.load_image(sample["image_url"]) + sample["image"] = load_image(sample["image_url"]) if parse_label: - sample["label"] = CamvidDriver.load_image(sample["label_url"]) + sample["label"] = load_image(sample["label_url"]) return sample def get_iter( diff --git a/src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py b/src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py index e7d4333..40569b1 100644 --- a/src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py +++ b/src/squirrel_datasets_core/datasets/ds_bowl_2018/__init__.py @@ -1,3 +1,4 @@ from squirrel_datasets_core.datasets.ds_bowl_2018.driver import DataScienceBowl2018Driver -__all__ = ["DataScienceBowl2018Driver"] +__all__ = ["DataScienceBowl2018Driver", "DRIVERS"] +DRIVERS = [DataScienceBowl2018Driver] diff --git a/src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py b/src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py index 5e7296b..25ea7b9 100644 --- a/src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py +++ b/src/squirrel_datasets_core/datasets/ds_bowl_2018/driver.py @@ -1,15 +1,20 @@ """Driver that can parse the Data Science Bowl 2018 dataset.""" +from __future__ import annotations + import os from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, TYPE_CHECKING + +from squirrel.driver import IterDriver +from squirrel.iterstream import FilePathGenerator, IterableSource -from squirrel.driver.driver import RecordIteratorDriver -from squirrel.iterstream import Composable, FilePathGenerator, IterableSource +from squirrel_datasets_core.io import load_image -from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver +if TYPE_CHECKING: + from squirrel.iterstream import Composable -class DataScienceBowl2018Driver(RecordIteratorDriver, TwoDImageFileDriver): +class DataScienceBowl2018Driver(IterDriver): """Driver that can iterate over the samples of the `2018 Data Science Bowl `_ dataset. @@ -19,6 +24,16 @@ class DataScienceBowl2018Driver(RecordIteratorDriver, TwoDImageFileDriver): name = "ds_bowl_18" + def __init__(self, url: str, **kwargs) -> None: + """Initializes the DataScienceBowl2018Driver. + + Args: + url (str): Path to the directory containing the dataset. + **kwargs: Other keyword arguments passes to super class initializer. + """ + super().__init__(**kwargs) + self.url = url + @staticmethod def load_sample(sample: Dict, parse_image: bool = True, parse_mask: bool = True) -> Dict: """Parse and load image and/or groundtruth mask into the sample dictionary. 
@@ -29,11 +44,11 @@ def load_sample(sample: Dict, parse_image: bool = True, parse_mask: bool = True) if parse_image: stem = os.path.splitext(os.path.basename(sample_url))[0] url = os.path.join(sample_url, "images", f"{stem}.png") - sample["image"] = DataScienceBowl2018Driver.load_image(url) + sample["image"] = load_image(url) if sample["split"] == "stage1_train" and parse_mask: gen = FilePathGenerator(os.path.join(sample_url, "masks")) - sample["masks"] = [DataScienceBowl2018Driver.load_image(url).astype("bool") for url in gen] + sample["masks"] = [load_image(url).astype("bool") for url in gen] return sample diff --git a/src/squirrel_datasets_core/datasets/hugging_face/__init__.py b/src/squirrel_datasets_core/datasets/hugging_face/__init__.py deleted file mode 100644 index 8be3ed0..0000000 --- a/src/squirrel_datasets_core/datasets/hugging_face/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from squirrel_datasets_core.datasets.hugging_face.driver import HuggingfaceDriver - -__all__ = ["HuggingfaceDriver"] diff --git a/src/squirrel_datasets_core/datasets/hugging_face/driver.py b/src/squirrel_datasets_core/datasets/hugging_face/driver.py deleted file mode 100644 index 8c9d271..0000000 --- a/src/squirrel_datasets_core/datasets/hugging_face/driver.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from typing import Any, Optional, TYPE_CHECKING - -from datasets import load_dataset -from squirrel.driver import RecordIteratorDriver -from squirrel.iterstream import IterableSource - -if TYPE_CHECKING: - from squirrel.catalog.catalog import Catalog - from squirrel.iterstream import Composable - - -class HuggingfaceDriver(RecordIteratorDriver): - name = "huggingface" - - def __init__( - self, name: str, subset: str = None, streaming: bool = True, catalog: Optional[Catalog] = None, **kwargs - ) -> None: - """Load huggingface dataset""" - super().__init__(catalog=catalog) - self._dataset = load_dataset(path=name, name=subset, streaming=streaming, **kwargs) - - def get_iter(self, split: str, **kwargs) -> Composable: - """Create iterstream based on dataset split.""" - return IterableSource(self._dataset[split]) - - def get_store(self, **kwargs) -> Any: - """Get the huggingface dataset object""" - return self._dataset diff --git a/src/squirrel_datasets_core/datasets/imagenet/__init__.py b/src/squirrel_datasets_core/datasets/imagenet/__init__.py index 82b3dd1..a0c8b5d 100644 --- a/src/squirrel_datasets_core/datasets/imagenet/__init__.py +++ b/src/squirrel_datasets_core/datasets/imagenet/__init__.py @@ -1,3 +1,4 @@ from squirrel_datasets_core.datasets.imagenet.driver import RawImageNetDriver -__all__ = ["RawImageNetDriver"] +__all__ = ["RawImageNetDriver", "DRIVERS"] +DRIVERS = [RawImageNetDriver] diff --git a/src/squirrel_datasets_core/datasets/imagenet/driver.py b/src/squirrel_datasets_core/datasets/imagenet/driver.py index b7188e8..3d9f406 100644 --- a/src/squirrel_datasets_core/datasets/imagenet/driver.py +++ b/src/squirrel_datasets_core/datasets/imagenet/driver.py @@ -1,15 +1,19 @@ +from __future__ import annotations + import os from pathlib import Path -from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Generator, Iterable, List, Optional, Set, TYPE_CHECKING, Tuple + +from squirrel.driver import FileDriver, IterDriver +from squirrel.iterstream import FilePathGenerator -from squirrel.driver import RecordIteratorDriver -from squirrel.driver.file.dataloader import FileLoader -from squirrel.iterstream import Composable, 
FilePathGenerator +from squirrel_datasets_core.io import load_image -from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver +if TYPE_CHECKING: + from squirrel.iterstream import Composable -class RawImageNetDriver(RecordIteratorDriver, TwoDImageFileDriver): +class RawImageNetDriver(IterDriver): name = "raw_imagenet" def __init__( @@ -116,7 +120,7 @@ def filter_val_samples_by_idx( def get_loc_mapper(self, url: str) -> Dict[str, List[Dict[str, Any]]]: """Get dict that maps from an imagenet filename to a list of bbox dicts.""" loc_map = dict() - with FileLoader(url).get_store(mode="rt") as f: + with FileDriver(url).open(mode="rt") as f: f.readline() # skip the header for row in f.readlines(): file_name, bboxes_raw = row.strip().split(",") @@ -131,13 +135,13 @@ def get_loc_mapper(self, url: str) -> Dict[str, List[Dict[str, Any]]]: def get_val_clsidx_list(self) -> List[int]: """Get list of class indices for validation samples.""" - with FileLoader(self.cls_val_mapping_path).get_store(mode="rt") as f: + with FileDriver(self.cls_val_mapping_path).open(mode="rt") as f: val_clsidx = [int(row.strip()) - 1 for row in f.readlines()] return val_clsidx def get_val_blacklist_indices(self) -> Set[int]: """Get a set of blacklisted validation sample indices.""" - with FileLoader(self.val_blacklist_path).get_store(mode="rt") as f: + with FileDriver(self.val_blacklist_path).open(mode="rt") as f: blacklist_idx = {int(row.strip()) - 1 for row in f.readlines()} return blacklist_idx @@ -146,7 +150,7 @@ def get_id_to_idx_and_name_mapper(self) -> Dict[str, Tuple[int, str]]: WARNING: Imagenet idx start from 1. Here we start from 0. """ gt_map = dict() - with FileLoader(self.cls_mapping_path).get_store(mode="rt") as f: + with FileDriver(self.cls_mapping_path).open(mode="rt") as f: for row in f.readlines(): row = row.strip() class_id, class_idx, class_name = row.split(" ") @@ -163,7 +167,7 @@ def get_idx_to_id_and_name_mapper(self) -> Dict[int, Tuple[str, str]]: @staticmethod def load_sample(sample: Dict[str, Any]) -> Dict[str, Any]: """Load sample from dict containing url to sample.""" - sample["image"] = RawImageNetDriver.load_image(sample["url"]) + sample["image"] = load_image(sample["url"]) return sample def get_iter( diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py index 278aef3..38001b6 100644 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py @@ -1,3 +1,4 @@ from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQualityDriver -__all__ = ["RawKaggleCastingQualityDriver"] +__all__ = ["RawKaggleCastingQualityDriver", "DRIVERS"] +DRIVERS = [RawKaggleCastingQualityDriver] diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py index 8fb716b..57fc9b2 100644 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/driver.py @@ -1,24 +1,35 @@ +from __future__ import annotations + import os from itertools import chain -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, TYPE_CHECKING + +from squirrel.driver import IterDriver +from squirrel.iterstream import FilePathGenerator, IterableSource -from squirrel.driver import RecordIteratorDriver 
-from squirrel.iterstream import Composable, FilePathGenerator, IterableSource +from squirrel_datasets_core.io import load_image -from squirrel_datasets_core.driver.fsspec import TwoDImageFileDriver +if TYPE_CHECKING: + from squirrel.iterstream import Composable -class RawKaggleCastingQualityDriver(RecordIteratorDriver, TwoDImageFileDriver): +class RawKaggleCastingQualityDriver(IterDriver): name = "raw_kaggle_casting_quality" def __init__(self, url: str, **kwargs) -> None: - """Init driver.""" - self.path = url + """Initializes the RawKaggleCastingQualityDriver. + + Args: + url (str): Path to the directory containing the dataset. + **kwargs: Other keyword arguments passes to super class initializer. + """ + super().__init__(**kwargs) + self.url = url @staticmethod def load_sample(sample: Dict) -> Dict: """Load sample from dict containing url to sample.""" - sample["image"] = RawKaggleCastingQualityDriver.load_image(sample["url"]) + sample["image"] = load_image(sample["url"]) return sample def get_iter(self, split: str, hooks: Optional[List[Callable]] = None, parse: bool = True, **kwargs) -> Composable: @@ -27,10 +38,10 @@ def get_iter(self, split: str, hooks: Optional[List[Callable]] = None, parse: bo if hooks is None: hooks = [] - samples_def = FilePathGenerator(os.path.join(self.path, split, "def_front")).map( + samples_def = FilePathGenerator(os.path.join(self.url, split, "def_front")).map( lambda x: {"url": x, "label": 0} ) - samples_ok = FilePathGenerator(os.path.join(self.path, split, "ok_front")).map(lambda x: {"url": x, "label": 1}) + samples_ok = FilePathGenerator(os.path.join(self.url, split, "ok_front")).map(lambda x: {"url": x, "label": 1}) it = IterableSource(chain(samples_ok, samples_def)).shuffle(size=1_000_000, initial=1_000_000) for h in hooks: it = it.to(h) diff --git a/src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py b/src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py index 6f4c6a8..31815dc 100644 --- a/src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py +++ b/src/squirrel_datasets_core/datasets/monthly_german_tweets/__init__.py @@ -1,3 +1,4 @@ from squirrel_datasets_core.datasets.monthly_german_tweets.driver import MonthlyGermanTweetsDriver -__all__ = ["MonthlyGermanTweetsDriver"] +__all__ = ["MonthlyGermanTweetsDriver", "DRIVERS"] +DRIVERS = [MonthlyGermanTweetsDriver] diff --git a/src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py b/src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py index c26db90..4c4fb0b 100644 --- a/src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py +++ b/src/squirrel_datasets_core/datasets/monthly_german_tweets/driver.py @@ -2,20 +2,24 @@ This driver can be used to parse the Zenodo Monthly German Tweet dataset obtainable at https://zenodo.org/record/3633935#.YcMujb1Khqt """ +from __future__ import annotations + import json -from typing import Callable, Generator, List, Optional +from typing import Iterable, Iterator, List, TYPE_CHECKING -from squirrel.driver import RecordIteratorDriver +from squirrel.driver import MapDriver from squirrel.fsspec.fs import get_fs_from_url -from squirrel.iterstream import Composable, FilePathGenerator +from squirrel.iterstream import FilePathGenerator + +if TYPE_CHECKING: + from squirrel.iterstream import Composable -class MonthlyGermanTweetsDriver(RecordIteratorDriver): +class MonthlyGermanTweetsDriver(MapDriver): name = "raw_monthly_german_tweets" def __init__(self, folder: str, **kwargs) -> None: - """ - 
Init the iterator. + """Init the MonthlyGermanTweets driver. Args: folder: Path to the unzipped data dump @@ -25,9 +29,7 @@ def __init__(self, folder: str, **kwargs) -> None: self.parse_error_count = 0 def parse_archive(self, url: str) -> List[str]: - """Parse a single archive. Note the custom parsing function due to the - special file layout. - """ + """Parse a single archive. Note the custom parsing function due to the special file layout.""" fs = get_fs_from_url(url) with fs.open(url, "rb", compression=self.compression) as f: @@ -37,8 +39,8 @@ def parse_archive(self, url: str) -> List[str]: samples = ["{" + elem for elem in dec.split(",{") if not elem.startswith("{")] return samples - def iterate_single_archive(self, url: str) -> Generator: - """Iterate over all samples in a single archive""" + def get(self, url: str) -> Iterator: + """Yields all samples in a single archive.""" samples = self.parse_archive(url) @@ -49,33 +51,17 @@ def iterate_single_archive(self, url: str) -> Generator: except json.JSONDecodeError: self.parse_error_count += 1 - def get_iter( - self, - key_hooks: Optional[List[Callable]] = None, - prefetch_buffer: int = 1_000, - max_workers: Optional[int] = None, - **kwargs, - ) -> Composable: - """ - Returns a composable that iterates over the raw data in a full unzipped shard of the - Monthly German Tweets dataset. + def get_iter(self, flatten: bool = True, **kwargs) -> Composable: + """Returns a composable that iterates over the raw data in a full unzipped shard of the Monthly German Tweets + dataset. Args: - key_hooks: List of Callables to modify the iteration over the archives in the shard - prefetch_buffer: buffer size of the samples buffer - max_workers: number of workers to use in the async_map loading. + flatten (bool): Whether to flatten the returned iterable. Defaults to False. + **kwargs: Other keyword arguments passed to :py:meth:`MapDriver.get_iter`. 
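A hypothetical usage sketch of the MapDriver-style API introduced here, reading a few records from the first archive via keys() and get(). The folder path is a placeholder and the tweet field names depend on the dump:

from itertools import islice

from squirrel_datasets_core.datasets.monthly_german_tweets import MonthlyGermanTweetsDriver

driver = MonthlyGermanTweetsDriver(folder="/data/monthly_german_tweets")

# keys() yields the archive paths, get() yields the decoded tweets of one archive.
first_archive = next(iter(driver.keys()))
for tweet in islice(driver.get(first_archive), 3):
    print(list(tweet)[:5])  # show a few field names only
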
""" - it = FilePathGenerator(self.folder) - - if key_hooks: - for hook in key_hooks: - it = it.to(hook) - - _map = ( - it.map(self.iterate_single_archive) - if max_workers == 0 - else it.async_map(self.iterate_single_archive, prefetch_buffer, max_workers) - ) + super().get_iter(flatten=flatten, **kwargs) - return _map.flatten() + def keys(self, **kwargs) -> Iterable: + """Returns the paths of the files in the root directory relative to root.""" + return FilePathGenerator(self.folder) diff --git a/src/squirrel_datasets_core/driver/__init__.py b/src/squirrel_datasets_core/driver/__init__.py index e69de29..272a51d 100644 --- a/src/squirrel_datasets_core/driver/__init__.py +++ b/src/squirrel_datasets_core/driver/__init__.py @@ -0,0 +1,5 @@ +from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver +from squirrel_datasets_core.driver.torchvision import TorchvisionDriver +from squirrel_datasets_core.driver.hub import HubDriver + +__all__ = ["HuggingfaceDriver", "TorchvisionDriver", "HubDriver"] diff --git a/src/squirrel_datasets_core/driver/fsspec.py b/src/squirrel_datasets_core/driver/fsspec.py deleted file mode 100644 index 255d3e5..0000000 --- a/src/squirrel_datasets_core/driver/fsspec.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Base driver class that utilizes fsspec to access files.""" -import os -from typing import Any, Dict, Optional - -import numpy as np -from PIL import Image -from squirrel.driver.driver import FileDriver -from squirrel.fsspec.fs import get_fs_from_url - - -class FsspecDriver(FileDriver): - """Generic driver that uses an fsspec filesystem object to load its samples.""" - - def __init__(self, url: str, **kwargs) -> None: - """Initialize FsspecDriver. - - Args: - url (str): Path to the directory containing the dataset. - **kwargs: Additional keyword arguments passed to the super class initializer. - """ - super().__init__(**kwargs) - self.url = url - - @staticmethod - def _get_store_helper(path: str, storage_options: Optional[Dict] = None, open_kwargs: Optional[Dict] = None) -> Any: - """Return a handler for a file.""" - if storage_options is None: - storage_options = {} - if open_kwargs is None: - open_kwargs = {} - - fs = get_fs_from_url(path, **storage_options) - return fs.open(path, **open_kwargs) - - def get_store(self, path: str, storage_options: Optional[Dict] = None, open_kwargs: Optional[Dict] = None) -> Any: - """Return a handler for a file in the dataset. - - Args: - path (str): File path relative to `self.url`. - storage_options (Dict, optional): Storage options passed to :py:meth:`squirrel.fsspec.fs.get_fs_from_url`. - open_kwargs (Dict, optional): Keyword arguments passed to `fs.open()`. By default, only mode="r" is set. - - Returns: - File handler corresponding to the given file. - """ - if open_kwargs is None: - open_kwargs = {} - open_kwargs["mode"] = open_kwargs.get("mode", "r") - - fp = os.path.join(self.url, path) - return self._get_store_helper(path=fp, storage_options=storage_options, open_kwargs=open_kwargs) - - -class TwoDImageFileDriver(FsspecDriver): - @staticmethod - def load_image(path: str, storage_options: Optional[Dict] = None, open_kwargs: Optional[Dict] = None) -> np.ndarray: - """Load a 2D RGB image. - - Args: - path (str): Image path. - storage_options (Dict, optional): Storage options passed to :py:meth:`squirrel.fsspec.fs.get_fs_from_url`. - Defaults to None. - open_kwargs (Dict, optional): Keyword arguments passed to `fs.open()`. By default, only mode="rb" is set. - - Returns: - np.ndarray: Image as a numpy array. 
- """ - if open_kwargs is None: - open_kwargs = {} - open_kwargs["mode"] = open_kwargs.get("mode", "rb") - - with FsspecDriver._get_store_helper(path, storage_options=storage_options, open_kwargs=open_kwargs) as fh: - return np.array(Image.open(fh)) diff --git a/src/squirrel_datasets_core/driver/hub.py b/src/squirrel_datasets_core/driver/hub.py new file mode 100644 index 0000000..5a8cd75 --- /dev/null +++ b/src/squirrel_datasets_core/driver/hub.py @@ -0,0 +1,47 @@ +"""Driver for hub datasets.""" +from typing import Optional + +import hub +from squirrel.driver import IterDriver +from squirrel.iterstream import Composable, IterableSource + + +class HubDriver(IterDriver): + """Driver to access hub datasets.""" + + name = "hub" + + def __init__(self, url: str, **kwargs) -> None: + """Initialize HubDriver. + + Args: + url (str): Path to the directory containing the dataset. + **kwargs: Additional keyword arguments passed to the super class initializer. + """ + super().__init__(**kwargs) + self.url = url + + def get_iter(self, subset: Optional[str] = None, **dataset_kwargs) -> Composable: + """Return an iterstream on top of a hub dataset. + + Args: + subset (str): Subset to open. If not provided, all subsets will be used. + **dataset_kwargs: Keyword arguments passed to :py:func:`hub.dataset`. + + Returns: + Composable + """ + with self.get_hub(**dataset_kwargs) as ds: + it = IterableSource(ds) if subset is None else IterableSource(ds[subset]) + return it.map(lambda x: x.tensors) + + def get_hub(self, **dataset_kwargs) -> hub.Dataset: + """Return a handler for the dataset. + + Args: + **dataset_kwargs: Keyword arguments passed to :py:func:`hub.dataset`. + + Returns: + Hub dataset. + """ + return hub.dataset(path=self.url, **dataset_kwargs) diff --git a/src/squirrel_datasets_core/driver/huggingface.py b/src/squirrel_datasets_core/driver/huggingface.py new file mode 100644 index 0000000..c820d09 --- /dev/null +++ b/src/squirrel_datasets_core/driver/huggingface.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +import datasets +from datasets import load_dataset +from squirrel.driver import IterDriver +from squirrel.iterstream import IterableSource + +if TYPE_CHECKING: + from squirrel.catalog.catalog import Catalog + +# Disable TQDM bars +datasets.logging.set_verbosity(datasets.logging.WARNING) + + +class HuggingfaceDriver(IterDriver): + name = "huggingface" + + def __init__( + self, + name: str, + subset: Optional[str] = None, + streaming: bool = True, + catalog: Optional[Catalog] = None, + **kwargs + ) -> None: + """Load huggingface dataset.""" + super().__init__(catalog=catalog) + self._dataset = load_dataset(path=name, name=subset, streaming=streaming, **kwargs) + + def get_iter(self, split: str) -> IterableSource: + """Create iterstream based on dataset split.""" + return IterableSource(self._dataset[split]) diff --git a/src/squirrel_datasets_core/driver/torchvision.py b/src/squirrel_datasets_core/driver/torchvision.py new file mode 100644 index 0000000..fd70bc9 --- /dev/null +++ b/src/squirrel_datasets_core/driver/torchvision.py @@ -0,0 +1,78 @@ +import inspect +import logging +from tempfile import gettempdir +from typing import Any, Optional, Type + +import torchvision +from squirrel.driver import IterDriver +from squirrel.iterstream import IterableSource + +logger = logging.getLogger(__name__) + +TORCH_DATASET_NAMES_AND_CLASSES = {x[0]: x[1] for x in inspect.getmembers(torchvision.datasets, inspect.isclass)} + +TORCH_DATASET_DOC = 
"https://pytorch.org/vision/stable/datasets.html" + +__all__ = ["TorchvisionDriver"] + + +class TorchvisionDriver(IterDriver): + name = "torchvision" + + def __init__(self, name: str, url: Optional[str] = None, download: Optional[bool] = None, **open_kwargs) -> None: + """Initialize TorchVisionLoader. + + Args: + url (str): Path to the directory containing the dataset (or the directory where the dataset will be + downloaded, which applies when `download` is set to be `True`). + If not given, will use a temp dir created from tempfile package instead. + name (str): The name of the `torch.utils.dataset.Dataset` class to be used to load the dataset. + Upper case and lower case ambiguity is allowed. + download (str): If not None, will pass to subclasses of `torchvision.datasets.DatasetFolder` which are + listed under https://pytorch.org/vision/stable/datasets.html. If True, will assert if self.url + already has the dataset, if not, will start downloading. If False, torch will assume the dataset + is already saved under `self.url`. The `None` option applies to torch classes that do not have the + download argument implemented (mostly because not all datasets can be downloaded by anonymous callers.) + **kwargs: Additional keyword arguments passed to the super class initializer. + """ + super().__init__(**open_kwargs) + self.url = ( + url if url is not None else f"{gettempdir()}/{name}" + ) # gettempdir returns a constant result throughout a session. Add a suffix to differentiate diff datasets. + self.name = name + self.download = download + + def get_iter(self, **open_kwargs) -> IterableSource: + """Load a torchvision dataset and pass them into iterstream. For a full list of arguments for each + torchvision.datasets class, see https://pytorch.org/vision/stable/datasets.html. + + Returns: + An instance of :py:class:`IterableSource`. Notice each raw sample from torchvision follows such + format: Tuple[img, Any] where the img is a PIL opened image object and Any could be class_id, annotations + or other types of metadata for a dataset. Please refer to + https://pytorch.org/vision/stable/datasets.html for a detailed description of what sample format you will + get. + """ + ds = self.get_torchvision_dataset(**open_kwargs) + return IterableSource(ds) + + def get_torchvision_dataset(self, **open_kwargs) -> Any: + """Return a handler for the dataset.""" + if open_kwargs is None: + open_kwargs = {} + + dataset = self._find_torchvision_dataset_class(self.name) + if self.download is not None: + return dataset(root=self.url, download=self.download, **open_kwargs) + else: + return dataset(root=self.url, **open_kwargs) + + @staticmethod + def _find_torchvision_dataset_class(name: str) -> Type[torchvision.datasets.VisionDataset]: + """Try to find the `torchvision.dataset` class given the input name. Allow upper and lower case ambiguity. + Raise an value error if the name could not be found in the implemented list of `torchvision.dataset` classes. + """ + for key in TORCH_DATASET_NAMES_AND_CLASSES: + if key.lower() == name.lower(): + return TORCH_DATASET_NAMES_AND_CLASSES[key] + raise ValueError(f"Dataset {name} could not be found in torchvision. 
See {TORCH_DATASET_DOC} for a full list.") diff --git a/src/squirrel_datasets_core/io/__init__.py b/src/squirrel_datasets_core/io/__init__.py new file mode 100644 index 0000000..2f803f8 --- /dev/null +++ b/src/squirrel_datasets_core/io/__init__.py @@ -0,0 +1,3 @@ +from squirrel_datasets_core.io.io import load_image + +__all__ = ["load_image"] diff --git a/src/squirrel_datasets_core/io/io.py b/src/squirrel_datasets_core/io/io.py new file mode 100644 index 0000000..f2e8446 --- /dev/null +++ b/src/squirrel_datasets_core/io/io.py @@ -0,0 +1,32 @@ +from typing import Dict, Optional + +import fsspec +import numpy as np +from PIL import Image + + +def load_image( + path: str, fs: Optional[fsspec.AbstractFileSystem] = None, open_kwargs: Optional[Dict] = None +) -> np.ndarray: + """Load an image. + + File is opened from an arbitrary path using fsspec and image is loaded using PIL and converted to a numpy array. + + Args: + path (str): Image path. + fs (fsspec.AbstractFileSystem, optional): Filesystem object to use for opening the file. If not provided, + :py:func:`fsspec.open` will be used. Defaults to None. + open_kwargs (Dict, optional): Keyword arguments passed to `fs.open()`. By default, only mode="rb" is set. + + Returns: + np.ndarray: Image as a numpy array. + """ + if open_kwargs is None: + open_kwargs = {} + open_kwargs["mode"] = open_kwargs.get("mode", "rb") + + if fs is None: + fs = fsspec + + with fs.open(path, **open_kwargs) as fh: + return np.array(Image.open(fh)) diff --git a/src/squirrel_datasets_core/preprocessing/save_shards.py b/src/squirrel_datasets_core/preprocessing/save_shards.py index 2e48b98..ebab1e4 100644 --- a/src/squirrel_datasets_core/preprocessing/save_shards.py +++ b/src/squirrel_datasets_core/preprocessing/save_shards.py @@ -1,12 +1,17 @@ +from __future__ import annotations + import os from dataclasses import dataclass from functools import partial -from typing import Callable, Dict, Iterable, List, Optional +from typing import Callable, Dict, Iterable, List, Optional, TYPE_CHECKING from pyspark.sql import SparkSession from squirrel.catalog import Catalog, Source -from squirrel.driver.messagepack.msgstore import MessagepackStore -from squirrel.driver.spec import Sample, Shard +from squirrel.serialization import MessagepackSerializer +from squirrel.store import SquirrelStore + +if TYPE_CHECKING: + from squirrel.constants import ShardType @dataclass @@ -19,13 +24,15 @@ class SaveShardsConfig: num_samples: Optional[int] = None # number of samples to take from the dataset, if None, all samples will be taken -def save_iterable_as_shard(pt: Iterable, output_url: str) -> None: +def save_iterable_as_shard(shard_it: Iterable[ShardType], output_url: str) -> None: """Helper to save a shard into a messagepack store using squirrel""" - pt = list(pt) - if len(pt) == 0: - return - store = MessagepackStore(output_url) - store.set(Shard(pt)) + # only initialize store if really needed + store = None + for shard_id, shard in enumerate(shard_it): + if store is None: + store = SquirrelStore(output_url, serializer=MessagepackSerializer()) + + store.set(key=f"shard_{shard_id}", value=shard) def save_shards( @@ -59,7 +66,6 @@ def save_shards( pipe = session.sparkContext.parallelize(src_it) for h in hooks: pipe = pipe.map(h) - pipe = pipe.map(Sample) pipe = pipe.coalesce(cfg.num_shards) _ = pipe.foreachPartition(partial(save_iterable_as_shard, output_url=out_url)) diff --git a/src/squirrel_datasets_core/squirrel_plugin.py b/src/squirrel_datasets_core/squirrel_plugin.py new file mode 100644 
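For reference, a minimal sketch of the shard-writing path now used by save_iterable_as_shard above: a SquirrelStore combined with a MessagepackSerializer and one set() call per shard key. The output url and the sample contents are placeholders:

from squirrel.serialization import MessagepackSerializer
from squirrel.store import SquirrelStore

# One store per output location; each shard is written under its own key.
store = SquirrelStore("/data/camvid_shards", serializer=MessagepackSerializer())
store.set(key="shard_0", value=[{"image_url": "/data/camvid/train/frame_0001.png", "split": "train"}])
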
index 0000000..6b6e6c4 --- /dev/null +++ b/src/squirrel_datasets_core/squirrel_plugin.py @@ -0,0 +1,26 @@ +import importlib +import pkgutil +from typing import List, Tuple, Type + +from squirrel.catalog import CatalogKey, Source +from squirrel.driver import Driver +from squirrel.framework.plugins.hookimpl import hookimpl + + +@hookimpl +def squirrel_drivers() -> List[Type[Driver]]: + """Custom drivers added by this package.""" + import squirrel_datasets_core.datasets as ds + from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver + from squirrel_datasets_core.driver.torchvision import TorchvisionDriver + from squirrel_datasets_core.driver.hub import HubDriver + + drivers = [TorchvisionDriver, HuggingfaceDriver, HubDriver] + for m in pkgutil.iter_modules(ds.__path__): + try: + d = importlib.import_module(f"{ds.__package__}.{m.name}").DRIVERS + drivers += d + except AttributeError: + pass + + return drivers diff --git a/src/test.ipynb b/src/test.ipynb new file mode 100644 index 0000000..6f791a2 --- /dev/null +++ b/src/test.ipynb @@ -0,0 +1,42 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQuality\n", + "iter = RawKaggleCastingQuality(\"archive/casting_data/casting_data\").get_iter(\"test\")\n", + "\n", + "for i in iter:\n", + " print(i[\"label\"])" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "338cfae0a1873dbc3e6a1266258c61ca663f4d905ce797b41b2835118552379a" + }, + "kernelspec": { + "display_name": "Python 3.8.12 64-bit ('sq-env': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 3d7c99c36506b756e70098dc2d26f9d8d87029eb Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 21 Feb 2022 12:17:23 +0100 Subject: [PATCH 10/19] using our own gitignore --- .gitignore | 154 +++++++++++++++++++++++++++++------------------------ 1 file changed, 84 insertions(+), 70 deletions(-) diff --git a/.gitignore b/.gitignore index 19252fc..06e110b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,22 +1,59 @@ +######## OSX ######## +*.DS_Store +.AppleDouble +.LSOverride +.xprocess/ +.Trash-0/ +.pytest_cache/ -# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks -# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks - -### JupyterNotebooks ### -# gitignore template for Jupyter Notebooks -# website: http://jupyter.org/ - -.ipynb_checkpoints -*/.ipynb_checkpoints/* - -# IPython -profile_default/ -ipython_config.py - -# Remove previous ipynb_checkpoints -# git rm -r .ipynb_checkpoints/ +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# Terraform +.terraform +terraform.tfstate +terraform.tfstate.backup +.terraform.lock.hcl +.terraform.tfstate.lock.info + +# experiment loggers +mlruns +wandb + +####### 
Misc ####### +# Vim +*.swp +*.swo + +nohup.out + +*.mp4 +*.avi +*.wmv +*.mov +*.pdf +*.log -### Python ### +######## Python ######## # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -27,6 +64,7 @@ __pycache__/ # Distribution / packaging .Python +env/ build/ develop-eggs/ dist/ @@ -35,11 +73,10 @@ eggs/ .eggs/ lib/ lib64/ +node_modules/ parts/ sdist/ var/ -wheels/ -share/python-wheels/ *.egg-info/ .installed.cfg *.egg @@ -58,17 +95,13 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ -.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml -*.cover -*.py,cover +*,cover .hypothesis/ -.pytest_cache/ -cover/ # Translations *.mo @@ -77,80 +110,61 @@ cover/ # Django stuff: *.log local_settings.py -db.sqlite3 -db.sqlite3-journal # Flask stuff: instance/ .webassets-cache -# Scrapy stuff: -.scrapy - # Sphinx documentation docs/_build/ # PyBuilder -.pybuilder/ target/ -# Jupyter Notebook - -# IPython +# IPython Notebook +.ipynb_checkpoints # pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock +.python-version -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff +# celery beat schedule file celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py -# Environments +# dotenv .env -.venv -env/ + +# virtualenv +.venv/ venv/ ENV/ -env.bak/ -venv.bak/ # Spyder project settings .spyderproject -.spyproject # Rope project settings .ropeproject -# mkdocs documentation -/site +######## PYCHARM ######## +.idea -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json +######## vscode ######## +.vscode +.devcontainer.json -# Pyre type checker -.pyre/ +####### LargeVis ####### +**/annoy_index_file -# pytype static type analyzer -.pytype/ +####### Web ####### +*.sass-cache +.ruby-version +/Pipfile* +/.conda-* +/yarn-error.log +bower_components/ +npm-debug.log -# Cython debug symbols -cython_debug/ +# devspace +.devspace -# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks \ No newline at end of file +# Hydra +outputs/ From 7f9dbe31b79e35cef07a5eee218d24bd5bc56241 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 21 Feb 2022 12:26:59 +0100 Subject: [PATCH 11/19] First version of setup.py. 
Require Python 3.8 --- .readthedocs.yaml | 2 +- setup.py | 107 ++++++++++++++++++ src/squirrel_datasets_core/squirrel_plugin.py | 3 +- 3 files changed, 109 insertions(+), 3 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index f579983..1d25c7b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -9,7 +9,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.9" + python: "3.8" # You can also specify other tool versions: # nodejs: "16" # rust: "1.55" diff --git a/setup.py b/setup.py index e69de29..9133040 100644 --- a/setup.py +++ b/setup.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import ast +import itertools +import os +import re +import sys + +from setuptools import find_packages, setup + +SOURCE_DIR = "squirrel_datasets_core" + +# Read package information from other files so that just one version has to be maintained. +_version_re = re.compile(r"__version__\s+=\s+(.*)") +with open("./%s/__init__.py" % SOURCE_DIR, "rb") as f: + init_contents = f.read().decode("utf-8") + + def get_var(var_name: str) -> str: + """Parsing of squirrel_datasets_core project infos defined in __init__.py""" + pattern = re.compile(r"%s\s+=\s+(.*)" % var_name) + match = pattern.search(init_contents).group(1) + return str(ast.literal_eval(match)) + + version = get_var("__version__") + +# add tag to version if provided +if "--version_tag" in sys.argv: + v_idx = sys.argv.index("--version_tag") + version = version + "." + sys.argv[v_idx + 1] + sys.argv.remove("--version_tag") + sys.argv.pop(v_idx) + +if os.path.exists("README.rst"): + with open("README.rst") as fh: + readme = fh.read() +else: + readme = "" +if os.path.exists("HISTORY.md"): + with open("HISTORY.md") as fh: + history = fh.read().replace(".. :changelog:", "") +else: + history = "" + + +def parse_req(spec: str) -> str: + """Parse package name==version out of requirements file.""" + if ";" in spec: + # remove restriction + spec, _ = [x.strip() for x in spec.split(";", 1)] + if "#" in spec: + # remove comment + spec = spec.strip().split("#")[0] + if "\\" in spec: + # remove line breaks + spec = spec.strip().split("\\")[0] + if "--hash=" in spec: + # remove line breaks + spec = spec.strip().split("--hash=")[0] + return spec + + +if os.path.exists("requirements.in"): + with open("requirements.in") as fh: + requirements = [parse_req(r) for r in fh.read().replace("\\\n", " ").split("\n") if parse_req(r) != ""] +else: + requirements = [] + +# generate extras based on requirements files +extras_require = dict() +for a_extra in ["dev", "preprocessing", "torchvision", "hub"]: + req_file = f"requirements.{a_extra}.in" + if os.path.exists(req_file): + with open(req_file) as fh: + extras_require[a_extra] = [r for r in fh.read().split("\n") if ";" not in r] + else: + extras_require[a_extra] = [] +extras_require["all"] = list(itertools.chain.from_iterable(extras_require.values())) + +if os.path.exists("scripts"): + SCRIPTS = [os.path.join("scripts", a) for a in os.listdir("scripts")] +else: + SCRIPTS = [] + +ENTRY_POINTS = {"squirrel": ["datasets_core = squirrel_datasets_core.squirrel_plugin"]} + +# Setup package using PIP +if __name__ == "__main__": + setup( + name=SOURCE_DIR, + version=version, + python_requires=">=3.8.0", + description="Squirrel public datasets collection", + long_description=f"{readme}\n\n{history}", + author="Merantix Labs GmbH", + license="", + # Needed to make jinja work and not get linting errors in the rendered file + package_dir=dict([(SOURCE_DIR, SOURCE_DIR)]), + 
packages=find_packages(where="./src/"), + scripts=SCRIPTS, + include_package_data=True, + install_requires=requirements, + tests_require=extras_require["dev"], + extras_require=extras_require, + entry_points=ENTRY_POINTS, + classifiers=["Public"], + ) diff --git a/src/squirrel_datasets_core/squirrel_plugin.py b/src/squirrel_datasets_core/squirrel_plugin.py index 6b6e6c4..8a6fcc5 100644 --- a/src/squirrel_datasets_core/squirrel_plugin.py +++ b/src/squirrel_datasets_core/squirrel_plugin.py @@ -1,8 +1,7 @@ import importlib import pkgutil -from typing import List, Tuple, Type +from typing import List, Type -from squirrel.catalog import CatalogKey, Source from squirrel.driver import Driver from squirrel.framework.plugins.hookimpl import hookimpl From 47f82bd3896ec98d2cd017a3c8b91934e24d2dd0 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 21 Feb 2022 12:30:11 +0100 Subject: [PATCH 12/19] Add requirements --- requirements.dev.in | 13 +++++++++++++ requirements.in | 10 +++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 requirements.dev.in diff --git a/requirements.dev.in b/requirements.dev.in new file mode 100644 index 0000000..112e5d8 --- /dev/null +++ b/requirements.dev.in @@ -0,0 +1,13 @@ +sphinx-rtd-theme +twine +wheel +pytest>=5.3.5 +pytest-timeout +pytest-cov +pytest-xdist +Sphinx<4.0.0 +sphinx_versions +pyspark>=3.2.0 +flaky # for retry flaky tests +nbmake +jupyter diff --git a/requirements.in b/requirements.in index 4170c03..7856c0d 100644 --- a/requirements.in +++ b/requirements.in @@ -1 +1,9 @@ -sphinx-rtd-theme \ No newline at end of file +datasets +mxlabs-squirrel[gcp, zarr]>=0.9,<0.10 +numpy +pillow +torchvision +scipy +pyspark==3.2.0 +hydra-core>=1.1.0 +hub From f8c6866d9fc02ad0f5df4d72677173dd30d79144 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 21 Feb 2022 12:40:51 +0100 Subject: [PATCH 13/19] Remove preprocessing logic. 
Will add separate PR --- .../datasets/camvid/preprocessing.py | 46 -------- .../datasets/ds_bowl_2018/preprocessing.py | 46 -------- .../datasets/imagenet/preprocessing.py | 39 ------- .../kaggle_casting_quality/preprocessing.py | 38 ------- .../preprocessing/__init__.py | 0 .../preprocessing/save_shards.py | 80 -------------- src/squirrel_datasets_core/spark/__init__.py | 3 - .../spark/setup_spark.py | 101 ------------------ 8 files changed, 353 deletions(-) delete mode 100644 src/squirrel_datasets_core/datasets/camvid/preprocessing.py delete mode 100644 src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py delete mode 100644 src/squirrel_datasets_core/datasets/imagenet/preprocessing.py delete mode 100644 src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py delete mode 100644 src/squirrel_datasets_core/preprocessing/__init__.py delete mode 100644 src/squirrel_datasets_core/preprocessing/save_shards.py delete mode 100644 src/squirrel_datasets_core/spark/__init__.py delete mode 100644 src/squirrel_datasets_core/spark/setup_spark.py diff --git a/src/squirrel_datasets_core/datasets/camvid/preprocessing.py b/src/squirrel_datasets_core/datasets/camvid/preprocessing.py deleted file mode 100644 index 2f39945..0000000 --- a/src/squirrel_datasets_core/datasets/camvid/preprocessing.py +++ /dev/null @@ -1,46 +0,0 @@ -from dataclasses import asdict, dataclass - -import hydra -from hydra.core.config_store import ConfigStore - -from squirrel_datasets_core.datasets.camvid.driver import CamvidDriver -from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards -from squirrel_datasets_core.spark.setup_spark import get_spark - - -@dataclass -class IterKwargs: - split: str = "train" - parse_image: bool = False - parse_label: bool = False - shuffle_size: int = 800 - shuffle_initial: int = 800 - - -@dataclass -class CamvidShardConfig(SaveShardsConfig): - identifier: str = "camvid" - version: int = 1 - num_shards: int = 5 - output_data_url: str = "/data/camvid_shards" - output_catalog_url: str = "/data/camvid_shards/shards.yaml" - iter_kwargs: IterKwargs = IterKwargs() - - -cs = ConfigStore.instance() -cs.store(name="camvid_config", node=CamvidShardConfig) - - -@hydra.main(config_path=None, config_name="camvid_config") -def main(cfg: CamvidShardConfig) -> None: - """When running via cli create spark session yourself""" - save_shards( - cfg=cfg, - session=get_spark("camvid-preprocessing"), - iter_kwargs=asdict(cfg.iter_kwargs), - hooks=[CamvidDriver.load_sample], - ) - - -if __name__ == "__main__": - main() diff --git a/src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py b/src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py deleted file mode 100644 index 73af3f1..0000000 --- a/src/squirrel_datasets_core/datasets/ds_bowl_2018/preprocessing.py +++ /dev/null @@ -1,46 +0,0 @@ -from dataclasses import asdict, dataclass - -import hydra -from hydra.core.config_store import ConfigStore - -from squirrel_datasets_core.datasets.ds_bowl_2018.driver import DataScienceBowl2018Driver -from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards -from squirrel_datasets_core.spark.setup_spark import get_spark - - -@dataclass -class IterKwargs: - split: str = "stage1_train" - parse_image: bool = False - parse_mask: bool = False - shuffle_size: int = 50 - shuffle_initial: int = 50 - - -@dataclass -class DsBowlShardConfig(SaveShardsConfig): - identifier: str = "ds_bowl_18" - version: int = 1 - num_shards: int = 
100 - output_data_url: str = "/data/ds_bowl_shards" - output_catalog_url: str = "/data/ds_bowl_shards/shards.yaml" - iter_kwargs: IterKwargs = IterKwargs() - - -cs = ConfigStore.instance() -cs.store(name="ds_bowl_18_config", node=DsBowlShardConfig) - - -@hydra.main(config_path=None, config_name="ds_bowl_18_config") -def main(cfg: DsBowlShardConfig) -> None: - """When running via cli create spark session yourself""" - save_shards( - cfg=cfg, - session=get_spark("ds-bowl-18-preprocessing"), - iter_kwargs=asdict(cfg.iter_kwargs), - hooks=[DataScienceBowl2018Driver.load_sample], - ) - - -if __name__ == "__main__": - main() diff --git a/src/squirrel_datasets_core/datasets/imagenet/preprocessing.py b/src/squirrel_datasets_core/datasets/imagenet/preprocessing.py deleted file mode 100644 index 17fa1e5..0000000 --- a/src/squirrel_datasets_core/datasets/imagenet/preprocessing.py +++ /dev/null @@ -1,39 +0,0 @@ -from dataclasses import asdict, dataclass - -import hydra -from hydra.core.config_store import ConfigStore - -from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards -from squirrel_datasets_core.spark.setup_spark import get_spark - - -@dataclass -class RawImagenetIterKwargs: - split: str = "train" - parse: bool = False - shuffle: bool = True - buffer_size: int = 100_000 - - -@dataclass -class RawImagenetShardConfig(SaveShardsConfig): - identifier: str = "imagenet" - version: int = 1 - num_shards: int = 100 - output_data_url: str = "imagenet/train" - output_catalog_url: str = "imagenet/train.yaml" - iter_kwargs: RawImagenetIterKwargs = RawImagenetIterKwargs() - - -cs = ConfigStore.instance() -cs.store(name="imagenet_config", node=RawImagenetShardConfig) - - -@hydra.main(config_path=None, config_name="kaggle_casting_config") -def main(cfg: RawImagenetShardConfig) -> None: - """When running via cli create spark session yourself""" - save_shards(cfg=cfg, session=get_spark("imagenet-preprocessing"), iter_kwargs=asdict(cfg.iter_kwargs)) - - -if __name__ == "__main__": - main() diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py deleted file mode 100644 index 1982766..0000000 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/preprocessing.py +++ /dev/null @@ -1,38 +0,0 @@ -from dataclasses import dataclass - -import hydra -from hydra.core.config_store import ConfigStore - -from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQualityDriver -from squirrel_datasets_core.preprocessing.save_shards import SaveShardsConfig, save_shards -from squirrel_datasets_core.spark.setup_spark import get_spark - - -@dataclass -class KaggleCastingShardConfig(SaveShardsConfig): - identifier: str = "kaggle_casting_quality" - version: int = 1 - num_shards: int = 100 - output_data_url: str = "kaggle_casting_quality/train" - output_catalog_url: str = "kaggle_casting_quality/train.yaml" - split: str = "train" - parse: bool = False - - -cs = ConfigStore.instance() -cs.store(name="kaggle_casting_config", node=KaggleCastingShardConfig) - - -@hydra.main(config_path=None, config_name="kaggle_casting_config") -def main(cfg: KaggleCastingShardConfig) -> None: - """When running via cli create spark session yourself""" - save_shards( - cfg=cfg, - session=get_spark("kagglecastingquality-preprocessing"), - iter_kwargs=dict(split=cfg.split, parse=cfg.parse), - hooks=[RawKaggleCastingQualityDriver.load_sample], - ) - - -if __name__ == 
"__main__": - main() diff --git a/src/squirrel_datasets_core/preprocessing/__init__.py b/src/squirrel_datasets_core/preprocessing/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/squirrel_datasets_core/preprocessing/save_shards.py b/src/squirrel_datasets_core/preprocessing/save_shards.py deleted file mode 100644 index ebab1e4..0000000 --- a/src/squirrel_datasets_core/preprocessing/save_shards.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -import os -from dataclasses import dataclass -from functools import partial -from typing import Callable, Dict, Iterable, List, Optional, TYPE_CHECKING - -from pyspark.sql import SparkSession -from squirrel.catalog import Catalog, Source -from squirrel.serialization import MessagepackSerializer -from squirrel.store import SquirrelStore - -if TYPE_CHECKING: - from squirrel.constants import ShardType - - -@dataclass -class SaveShardsConfig: - identifier: str # identifier used in the catalog for the dataset - version: int # dataset version to use - num_shards: int - output_data_url: str # path where the shard will be written - output_catalog_url: str # path where the catalog will be written - num_samples: Optional[int] = None # number of samples to take from the dataset, if None, all samples will be taken - - -def save_iterable_as_shard(shard_it: Iterable[ShardType], output_url: str) -> None: - """Helper to save a shard into a messagepack store using squirrel""" - # only initialize store if really needed - store = None - for shard_id, shard in enumerate(shard_it): - if store is None: - store = SquirrelStore(output_url, serializer=MessagepackSerializer()) - - store.set(key=f"shard_{shard_id}", value=shard) - - -def save_shards( - cfg: SaveShardsConfig, - session: SparkSession, - iter_kwargs: Dict, - hooks: Optional[List[Callable[[Dict], Dict]]] = None, -) -> None: - """Process a source and save it as messagepack. - - Args: - cfg (SaveShardsConfig): Config to use. - session (SparkSession): Spark session. - iter_kwargs (Dict): Keyword arguments passed to :py:meth:`Driver.get_iter`. - hooks (Optional[List[Callable[[Dict], Dict]]], optional): Methods to map to the samples before creating shards. 
- """ - if hooks is None: - hooks = [] - - # get raw data - cat = Catalog.from_plugins() - d = cat[cfg.identifier][cfg.version] - src_it = d.load.get_iter(**iter_kwargs) - if cfg.num_samples is not None: - src_it = src_it.take(cfg.num_samples) - - # resolve relative paths - out_url = os.path.abspath(cfg.output_data_url) - - # run preprocessing with spark - pipe = session.sparkContext.parallelize(src_it) - for h in hooks: - pipe = pipe.map(h) - pipe = pipe.coalesce(cfg.num_shards) - _ = pipe.foreachPartition(partial(save_iterable_as_shard, output_url=out_url)) - - ncat = Catalog() - # set version via DummyCatalogSource - cat_source = ncat[cfg.identifier] - cat_source[cfg.version + 1] = Source( - driver_name="messagepack", - metadata=d.metadata, - driver_kwargs={"url": out_url}, - ) - ncat.to_file(cfg.output_catalog_url) diff --git a/src/squirrel_datasets_core/spark/__init__.py b/src/squirrel_datasets_core/spark/__init__.py deleted file mode 100644 index a071e7d..0000000 --- a/src/squirrel_datasets_core/spark/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from squirrel_datasets_core.spark.setup_spark import get_spark, stop_spark - -__all__ = ["get_spark", "stop_spark"] diff --git a/src/squirrel_datasets_core/spark/setup_spark.py b/src/squirrel_datasets_core/spark/setup_spark.py deleted file mode 100644 index d9840ac..0000000 --- a/src/squirrel_datasets_core/spark/setup_spark.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -import typing as t - -from pyspark import SparkConf -from pyspark.sql import SparkSession - - -def get_spark( - app_name: str, - dataset_conf: t.Dict[str, str] = None, - add_gcs_connector: bool = True, - add_hadoop: bool = False, - mem_size: str = "4g", -) -> SparkSession: - """ - Get a spark session for your operation. - - Args: - app_name: The name of your spark session. - dataset_conf: A dictionary containing configs for your runtime spark configs. Must be configs that can be - altered after the creation of the spark session, otherwise will be complained by spark. A detailed list of - viable options can be found here: https://spark.apache.org/docs/latest/configuration.html. - add_gcs_connector: If true, will add shaded gcs-connector for hadoop3 and other related jar files to spark. - It must be True in order for cloudbuild to work. - add_hadoop: If true, will add hadoop native libs to spark lib path, such that spark can use native hadoop lib - tools to compress and decompress zst files. - mem_size: Memory size for the driver node. - - Returns: - SparkSession - """ - conf = SparkConf() - conf.set("spark.executor.memory", mem_size) - conf.set("spark.jars.repositories", "https://maven.google.com/") # add google maven repo as extra jars lookup link. - # add already installed jar files into path. - _SPARK_HOME = _get_home("SPARK_HOME") - _EXTRA_LIB_PATH = f"{_SPARK_HOME}/jars" - conf.set("spark.driver.extraLibraryPath", _EXTRA_LIB_PATH) - conf.set("spark.executor.extraLibraryPath", _EXTRA_LIB_PATH) - - if add_gcs_connector: - conf = _add_gcs_connector(conf) - - if add_hadoop: - conf = _add_native_hadoop(conf, extra_lib_path=_EXTRA_LIB_PATH) - - assert isinstance(app_name, str), ValueError("`app_name` accept string only.") - spark = SparkSession.builder.appName(app_name).config(conf=conf).getOrCreate() - - if dataset_conf: - for key, value in dataset_conf.items(): - spark.conf.set(key, str(value)) - - return spark - - -def _get_home(env_var: str) -> str: - """Get home directories for e.g. 
spark, hadoop or other java lib homes.""" - _HOME = os.environ.get(env_var) - assert _HOME is not None, EnvironmentError(f"Could not find '{env_var}' in current OS env.") - return _HOME - - -def _add_gcs_connector(conf: SparkConf) -> SparkConf: - """Add gcs-connector and related dependencies (jar files) to spark configuration.""" - _SPARK_HOME = _get_home("SPARK_HOME") - GCS_CONNECTOR_JARS = { - "spark.jars.packages": "com.google.cloud.bigdataoss:gcs-connector:hadoop3-2.2.2", - "spark.hadoop.fs.AbstractFileSystem.gs.impl": "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS", - "spark.hadoop.fs.gs.impl": "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem", - "spark.hadoop.google.cloud.auth.service.account.enable": "true", - "spark.jars": f"{_SPARK_HOME}/jars/gcs-connector-hadoop3-latest.jar", - } - for key, val in GCS_CONNECTOR_JARS.items(): - conf.set(key, val) - return conf - - -def _add_native_hadoop(conf: SparkConf, extra_lib_path: str = None) -> SparkConf: - """Link extra hadoop native libs to spark configuration, such that spark can use these libraries to do some specific - IO works. Must be added when one does I/O with zstandard compressed files in spark. (The actual hadoop native libs - are pre-installed in docker image `infrastructure/docker/Dockerfile.spark`. - """ - _HADOOP_HOME = _get_home("HADOOP_HOME") - if extra_lib_path is not None: - extra_lib_path = f"{extra_lib_path}:{_HADOOP_HOME}/lib/native" - else: - extra_lib_path = f"{_HADOOP_HOME}/lib/native" - conf.set("spark.driver.extraLibraryPath", extra_lib_path) - conf.set("spark.executor.extraLibraryPath", extra_lib_path) - return conf - - -def stop_spark(spark: SparkSession) -> None: - """Stop a spark session in a graceful way.""" - spark.sparkContext._gateway.shutdown_callback_server() - spark.stop() - - -if __name__ == "__main__": - spark = get_spark("test") From a10d35fd1c930dcd0d9a9ae7355f8e64dec83ced Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 21 Feb 2022 12:46:39 +0100 Subject: [PATCH 14/19] Add test for Hub driver --- test/__init__.py | 0 test/conftest.py | 17 ++++++++++++++++ test/test_datasets/test_hub.py | 36 ++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+) create mode 100644 test/__init__.py create mode 100644 test/conftest.py create mode 100644 test/test_datasets/test_hub.py diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..0a94d70 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,17 @@ +"""This module defines specific fixtures for unit tests. Shared fixtures are defined in shared_fixtures.py. + +################################### +Please do not import from this file. 
+###################################
+
+Not importing from conftest is a best practice described in the note here:
+https://pytest.org/en/6.2.x/writing_plugins.html#conftest-py-local-per-directory-plugins
+"""
+
+import pytest
+
+
+@pytest.fixture()
+def unit_test_fixture() -> str:
+    """Fixture that is only used in the unit test scope."""
+    return "unit test string"
diff --git a/test/test_datasets/test_hub.py b/test/test_datasets/test_hub.py
new file mode 100644
index 0000000..79a0536
--- /dev/null
+++ b/test/test_datasets/test_hub.py
@@ -0,0 +1,36 @@
+from pathlib import Path
+
+import numpy as np
+
+from squirrel_datasets_core.driver.hub import HubDriver
+
+
+def test_hub_driver(tmpdir: Path) -> None:
+    """Test writing and reading using hub driver."""
+    driver = HubDriver(str(tmpdir))
+
+    # create dummy hub dataset
+    SAMPLES = 10
+    SPLIT = "train"
+    with driver.get_hub(overwrite=True) as ds:
+        ds.create_group(SPLIT)
+        ds[SPLIT].create_tensor("image", htype="image", sample_compression="jpg")
+        for _ in range(SAMPLES):
+            ds[SPLIT].image.append(np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8))
+
+    # get iterstream from hub dataset
+    it = driver.get_iter(SPLIT)
+    i = 0
+    for s in it:
+        assert s["image"].numpy().shape == (32, 32, 3)
+        i += 1
+    assert i == SAMPLES
+
+
+def test_hub_driver_hosted() -> None:
+    """Test reading from a hosted hub dataset."""
+    driver = HubDriver("hub://activeloop/coco-train")
+    data = driver.get_iter().take(5).collect()
+    assert len(data) == 5
+    for k in ("images", "images_meta", "masks"):
+        assert k in data[0]

From 3ee958abf07edd5e1fb0fa4048e36596d3b3517f Mon Sep 17 00:00:00 2001
From: Winfried Loetzsch
Date: Mon, 21 Feb 2022 16:09:03 +0100
Subject: [PATCH 15/19] Extend documentation by Contribute section

---
 docs/source/add_dataset.rst | 57 +++++++++++++++++++
 docs/source/index.rst | 17 +++++-
 .../kaggle_casting_quality/README.rst | 8 ++-
 src/test.ipynb | 42 --------
 4 files changed, 76 insertions(+), 48 deletions(-)
 create mode 100644 docs/source/add_dataset.rst
 delete mode 100644 src/test.ipynb

diff --git a/docs/source/add_dataset.rst b/docs/source/add_dataset.rst
new file mode 100644
index 0000000..9e579dc
--- /dev/null
+++ b/docs/source/add_dataset.rst
@@ -0,0 +1,57 @@
+Contribute to Squirrel Datasets
+===============================
+
+Squirrel-datasets supports you with two tasks:
+
+* Preprocessing data and registering the preprocessed dataset in the data mesh
+* Loading data from the data mesh
+
+Preprocessing
+-------------
+For the first task, i.e. preprocessing, we recommend using `Apache Spark`_. A typical scenario is that you would
+like to work with data stored in Google Cloud Storage and run your batch processing job on a Kubernetes cluster. We use
+`PySpark`_ for defining the preprocessing logic in Python.
+
+Data Loading
+------------
+For the second task, i.e. data loading, we use the high-level API from :code:`squirrel`. The corresponding data loading logic
+is defined through a :py:class:`Driver` class.
+
+Add a New Dataset
+------------------
+Once you understand these two main tasks and how we handle them, adding a new dataset to
+:code:`squirrel-datasets` comes down to three steps: define your preprocessing logic, define your loading logic, and
+register the dataset in a catalog plugin.
+
+#. Define your preprocessing logic.
+
+   - Create a new directory under :code:`squirrel_datasets_core/datasets` named after your dataset, e.g. "example_dataset".
+     Write your preprocessing scripts under a new ``preprocessing.py`` file in it.
+
+#. Define your loading logic.
+
+   - After the preprocessing step, you want to make sure your preprocessed dataset is valid and readable. For that,
+     you need to define the loading logic. The driver defines how the dataset is read from the serialized file into memory.
+
+   - In :code:`squirrel` there are already many built-in drivers for reading all kinds of datasets. There are
+     :py:class:`CSVDataloader`, :py:class:`JSONLoader`, :py:class:`MessagePackDataLoader`, :py:class:`RecordLoader`
+     and many others. For details, please refer to `squirrel.driver`_.
+
+   - Select a suitable driver if one of them is applicable to your dataset's format and compression method.
+
+   - If there is no driver suitable for your dataset, you need to define a custom driver. The custom driver should
+     have the same interface as :py:class:`squirrel.driver.IterDriver`. We recommend that you subclass
+     this class, then add the loading logic inside. This class should be saved under
+     :code:`squirrel_datasets_core/datasets/example_dataset/driver.py`.
+
+   .. note::
+
+      It is not always the case that data loading occurs after the preprocessing steps. For image datasets, Spark is
+      not always the right tool for the job. In that case, you may want to load and process the data without it, and you
+      need to define the loading logic for your raw data; you may then swap the above steps or use them more
+      flexibly. See `squirrel_datasets_core.datasets.imagenet`_ for an example.
+
+.. _Apache Spark: https://spark.apache.org/docs/latest/
+.. _PySpark: https://spark.apache.org/docs/latest/api/python/
+.. _squirrel.driver: https://squirrel.readthedocs.io/
+.. _squirrel_datasets_core.datasets.imagenet: https://squirrel.readthedocs.io/
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f601daa..596dfca 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,12 +1,23 @@
 Squirrel Datasets Documentation
 ===============================
+`Squirrel Datasets `_ is an extension of the Squirrel platform for data transformation, access, and discovery.
+It includes common drivers for public datasets, which are ready to use along with data preprocessing logic.
+Visit the official `Squirrel Documentation `_ for more information on how to use Squirrel.
+Find out more about Merantix Labs on our `Website `_.
 
-Datasets
---------
+Available Datasets
+------------------
 
 .. toctree::
    :maxdepth: 2
 
    dataset_links/kaggle_casting_quality
-   dataset_links/imagenet
\ No newline at end of file
+   dataset_links/imagenet
+
+Contribute
+------------------
+
+.. toctree::
+
+   add_dataset
diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst
index 0aa7be9..21a27bf 100644
--- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst
+++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/README.rst
@@ -60,9 +60,11 @@ A sample from the training set is provided below:
 
 .. code-block::
 
-    {'url': '{PATH_TO_DATA}/casting_data/casting_data/test/ok_front/cast_ok_0_9996.jpeg',
-     'label': 1,
-     'image': array(...)}
+    {
+        'url': '{PATH_TO_DATA}/casting_data/casting_data/test/ok_front/cast_ok_0_9996.jpeg',
+        'label': 1,
+        'image': array(...)
+ } Dataset Schema ************** diff --git a/src/test.ipynb b/src/test.ipynb deleted file mode 100644 index 6f791a2..0000000 --- a/src/test.ipynb +++ /dev/null @@ -1,42 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQuality\n", - "iter = RawKaggleCastingQuality(\"archive/casting_data/casting_data\").get_iter(\"test\")\n", - "\n", - "for i in iter:\n", - " print(i[\"label\"])" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "338cfae0a1873dbc3e6a1266258c61ca663f4d905ce797b41b2835118552379a" - }, - "kernelspec": { - "display_name": "Python 3.8.12 64-bit ('sq-env': conda)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From d1e40ddf74d529319f046121a74c09b5bc19a59f Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Mon, 21 Feb 2022 16:11:02 +0100 Subject: [PATCH 16/19] Parse attributes in dataset cards + Lint --- docs/source/conf.py | 32 ++++---- requirements.in | 1 + setup.py | 13 +-- src/squirrel_datasets_core/__init__.py | 2 +- .../datasets/imagenet/__init__.py | 4 +- .../kaggle_casting_quality/__init__.py | 4 +- .../parse_dataset_cards.py | 81 +++++++++++++++++++ 7 files changed, 114 insertions(+), 23 deletions(-) create mode 100644 src/squirrel_datasets_core/parse_dataset_cards.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 1e1928c..d757caf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3,30 +3,34 @@ # -- Project information -project = 'Squirrel Datasets' -copyright = f'{datetime.datetime.now().year}, Merantix Labs GmbH' -author = 'Merantix Labs GmbH' +project = "Squirrel Datasets" +copyright = f"{datetime.datetime.now().year}, Merantix Labs GmbH" +author = "Merantix Labs GmbH" # -- General configuration extensions = [ - 'sphinx.ext.duration', - 'sphinx.ext.doctest', - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.intersphinx', + "sphinx.ext.duration", + "sphinx.ext.doctest", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "sphinx.ext.autosectionlabel", ] intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), + "python": ("https://docs.python.org/3/", None), + "sphinx": ("https://www.sphinx-doc.org/en/master/", None), } -intersphinx_disabled_domains = ['std'] +intersphinx_disabled_domains = ["std"] -templates_path = ['_templates'] +templates_path = ["_templates"] # -- Options for HTML output -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # -- Options for EPUB output -epub_show_urls = 'footnote' \ No newline at end of file +epub_show_urls = "footnote" + +autoclass_content = "both" diff --git a/requirements.in b/requirements.in index 7856c0d..c79bd5a 100644 --- a/requirements.in +++ b/requirements.in @@ -7,3 +7,4 @@ scipy pyspark==3.2.0 hydra-core>=1.1.0 hub +docutils diff --git a/setup.py b/setup.py index 9133040..ac5f8d0 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,9 @@ import re import sys -from setuptools import find_packages, setup +from setuptools import setup -SOURCE_DIR 
= "squirrel_datasets_core" +SOURCE_DIR = "src/squirrel_datasets_core" # Read package information from other files so that just one version has to be maintained. _version_re = re.compile(r"__version__\s+=\s+(.*)") @@ -82,12 +82,12 @@ def parse_req(spec: str) -> str: else: SCRIPTS = [] -ENTRY_POINTS = {"squirrel": ["datasets_core = squirrel_datasets_core.squirrel_plugin"]} +ENTRY_POINTS = {"squirrel": ["squirrel_datasets_core = squirrel_datasets_core.squirrel_plugin"]} # Setup package using PIP if __name__ == "__main__": setup( - name=SOURCE_DIR, + name="squirrel_datasets_core", version=version, python_requires=">=3.8.0", description="Squirrel public datasets collection", @@ -95,8 +95,8 @@ def parse_req(spec: str) -> str: author="Merantix Labs GmbH", license="", # Needed to make jinja work and not get linting errors in the rendered file - package_dir=dict([(SOURCE_DIR, SOURCE_DIR)]), - packages=find_packages(where="./src/"), + package_dir={"": "src"}, + packages=["squirrel_datasets_core"], scripts=SCRIPTS, include_package_data=True, install_requires=requirements, @@ -104,4 +104,5 @@ def parse_req(spec: str) -> str: extras_require=extras_require, entry_points=ENTRY_POINTS, classifiers=["Public"], + package_data={"": ["*.rst"]}, ) diff --git a/src/squirrel_datasets_core/__init__.py b/src/squirrel_datasets_core/__init__.py index b3c06d4..f102a9c 100644 --- a/src/squirrel_datasets_core/__init__.py +++ b/src/squirrel_datasets_core/__init__.py @@ -1 +1 @@ -__version__ = "0.0.1" \ No newline at end of file +__version__ = "0.0.1" diff --git a/src/squirrel_datasets_core/datasets/imagenet/__init__.py b/src/squirrel_datasets_core/datasets/imagenet/__init__.py index a0c8b5d..873708e 100644 --- a/src/squirrel_datasets_core/datasets/imagenet/__init__.py +++ b/src/squirrel_datasets_core/datasets/imagenet/__init__.py @@ -1,4 +1,6 @@ from squirrel_datasets_core.datasets.imagenet.driver import RawImageNetDriver +from squirrel_datasets_core.parse_dataset_cards import parse_readme -__all__ = ["RawImageNetDriver", "DRIVERS"] +__all__ = ["RawImageNetDriver", "DRIVERS", "DATASET_ATTRIBUTES"] DRIVERS = [RawImageNetDriver] +DATASET_ATTRIBUTES = parse_readme(__name__) diff --git a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py index 38001b6..91edd10 100644 --- a/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py +++ b/src/squirrel_datasets_core/datasets/kaggle_casting_quality/__init__.py @@ -1,4 +1,6 @@ from squirrel_datasets_core.datasets.kaggle_casting_quality.driver import RawKaggleCastingQualityDriver +from squirrel_datasets_core.parse_dataset_cards import parse_readme -__all__ = ["RawKaggleCastingQualityDriver", "DRIVERS"] +__all__ = ["RawKaggleCastingQualityDriver", "DRIVERS", "DATASET_ATTRIBUTES"] DRIVERS = [RawKaggleCastingQualityDriver] +DATASET_ATTRIBUTES = parse_readme(__name__) diff --git a/src/squirrel_datasets_core/parse_dataset_cards.py b/src/squirrel_datasets_core/parse_dataset_cards.py new file mode 100644 index 0000000..dd9c827 --- /dev/null +++ b/src/squirrel_datasets_core/parse_dataset_cards.py @@ -0,0 +1,81 @@ +import docutils.nodes +import docutils.parsers.rst +import docutils.utils +import docutils.frontend +import pkg_resources + +from typing import Dict + + +def parse_rst(text: str) -> docutils.nodes.document: + """Parse RST file. + + Args: + text (str): RST file as string. + + Returns: + docutils.nodes.document: Parsed document. 
+ """ + parser = docutils.parsers.rst.Parser() + components = (docutils.parsers.rst.Parser,) + settings = docutils.frontend.OptionParser(components=components).get_default_values() + document = docutils.utils.new_document("", settings=settings) + parser.parse(text, document) + return document + + +class MachineAttributesVisitor(docutils.nodes.NodeVisitor): + def __init__(self, doc: docutils.nodes.document) -> None: + """Init Reader to parse RST files for machine readable attributes. + + Args: + doc (docutils.nodes.document): Document to parse through. + """ + super().__init__(doc) + self.tables = 0 + self._machine_readable_attributes = {} + + def visit_table(self, node: docutils.nodes.table) -> None: + """Called for table nodes.""" + self.tables += 1 + + def visit_row(self, node: docutils.nodes.tbody) -> None: + """Called for row nodes.""" + if ( + self.tables > 1 + or len(node.children) < 2 + or len(node.children[0].children) < 1 + or len(node.children[1].children) < 1 + ): + return + + self._machine_readable_attributes[str(node.children[0][0][0])] = str(node.children[1][0][0]) + + def unknown_visit(self, node: docutils.nodes.Node) -> None: + """Called for other nodes.""" + pass + + def machine_readable_attributes(self) -> Dict[str, str]: + """Machine readable summary of attributes. + + Returns: + Dict[str, str]: Parsed summary. + """ + return self._machine_readable_attributes + + +def parse_readme(name: str) -> Dict[str, str]: + """Parse readme in the current directory and extract attributes. + + Args: + name (str): path to dataset directory. + + Returns: + _type_: Machine readable summary of attributes. + """ + stream = pkg_resources.resource_stream(name, "README.rst") + doc = parse_rst(stream.read().decode("utf-8")) + visitor = MachineAttributesVisitor(doc) + doc.walk(visitor) + + return visitor.machine_readable_attributes() From 066f460b54e9e353721a168c1eea0415abb6fd52 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 24 Feb 2022 10:04:01 +0100 Subject: [PATCH 17/19] Change constant to upper case --- src/squirrel_datasets_core/datasets/camvid/driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/squirrel_datasets_core/datasets/camvid/driver.py b/src/squirrel_datasets_core/datasets/camvid/driver.py index 8c34bab..9ac3da7 100644 --- a/src/squirrel_datasets_core/datasets/camvid/driver.py +++ b/src/squirrel_datasets_core/datasets/camvid/driver.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from squirrel.iterstream import Composable -_classes = [ +_CLASSES = [ "sky", "building", "pole", @@ -27,7 +27,7 @@ "bicyclist", "unlabelled", ] -CAMVID_NAME_TO_LABEL_ID = dict(zip(_classes, range(12))) +CAMVID_NAME_TO_LABEL_ID = dict(zip(_CLASSES, range(12))) CAMVID_LABEL_ID_TO_NAME = {idx: cls for cls, idx in CAMVID_NAME_TO_LABEL_ID.items()} From 62183bdf385753a6b70290865125a117e81d489d Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 24 Feb 2022 11:13:05 +0100 Subject: [PATCH 18/19] Update Readme.rst --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f0a755f..e5614c9 100644 --- a/README.rst +++ b/README.rst @@ -7,4 +7,4 @@ your own datasets with the tools we provide here. For preprocessing, we currently support Spark as the main tool to carry out the task. -Please see our [documentation](https://squirrel-datasets-core.readthedocs.io) for further details. \ No newline at end of file +Please see our `documentation `_ for further details. 
\ No newline at end of file From cdba8bb64a9f9a482709e1d9637865bff4e9001e Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Thu, 24 Feb 2022 18:07:54 +0100 Subject: [PATCH 19/19] Add intersphinx reference to squirrel docs --- docs/source/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index d757caf..eb24832 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,6 +21,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "sphinx": ("https://www.sphinx-doc.org/en/master/", None), + "squirrel": ("https://squirrel.readthedocs.io/", None), } intersphinx_disabled_domains = ["std"]
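
For readers following the driver and plugin changes in this series, the sketch below illustrates the custom-driver pattern that ``docs/source/add_dataset.rst`` and the dataset ``__init__.py`` diffs describe. It is an illustration only and not part of any commit above: the dataset name ``example_dataset``, the class ``ExampleDatasetDriver``, its ``name`` attribute, and the exact ``IterDriver``/``IterableSource`` signatures are assumptions about the squirrel API, not details confirmed by these patches::

    # Hypothetical sketch (not from any commit in this series).
    # Assumed location: src/squirrel_datasets_core/datasets/example_dataset/driver.py
    from pathlib import Path
    from typing import Dict, Iterator

    from squirrel.driver import IterDriver  # interface referenced in add_dataset.rst
    from squirrel.iterstream import Composable, IterableSource  # import path assumed


    class ExampleDatasetDriver(IterDriver):
        """Minimal custom driver: enumerate raw files and map them to samples."""

        name = "example_dataset"  # assumed: the catalog refers to drivers by name

        def __init__(self, url: str, **kwargs) -> None:
            super().__init__(**kwargs)
            self.url = url

        @staticmethod
        def load_sample(sample: Dict) -> Dict:
            # Same role as CamvidDriver.load_sample in the removed preprocessing
            # scripts: turn a lightweight record into a fully loaded sample.
            sample["image_bytes"] = Path(sample["url"]).read_bytes()
            return sample

        def _records(self, split: str) -> Iterator[Dict]:
            # Derive the label from the parent folder name, similar to the
            # kaggle_casting_quality layout shown in its README.
            for path in sorted(Path(self.url, split).glob("*/*.jpeg")):
                yield {"url": str(path), "label": path.parent.name}

        def get_iter(self, split: str = "train", parse: bool = True, **kwargs) -> Composable:
            it = IterableSource(self._records(split))
            return it.map(ExampleDatasetDriver.load_sample) if parse else it

Registered through a ``DRIVERS = [ExampleDatasetDriver]`` list in the dataset's ``__init__.py`` (the pattern PATCH 16 adds for imagenet and kaggle_casting_quality) and exposed via the ``squirrel`` entry point from ``setup.py``, such a driver would then be reachable through ``Catalog.from_plugins()`` and usable as ``driver.get_iter("train").take(5).collect()``, as in ``test_hub.py``.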