From 875274b841db9128e65d205f48b177490fdd496d Mon Sep 17 00:00:00 2001 From: Alexander Koenig <50868646+axkoenig@users.noreply.github.com> Date: Mon, 19 Dec 2022 14:33:18 +0100 Subject: [PATCH] Fix Driver imports (#74) * Fix driver imports * Lint * Version bump --- README.md | 2 +- docs/code/all_drivers.py | 26 +++++++++---------- docs/code/real_scenario.py | 6 ++--- src/squirrel_datasets_core/__init__.py | 2 +- src/squirrel_datasets_core/driver/__init__.py | 6 ----- src/squirrel_datasets_core/squirrel_plugin.py | 14 +++++++--- 6 files changed, 27 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 9fd26ec..2dd596a 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ pip install "squirrel-datasets-core[all]" A great feature of squirrel-datasets-core is that you can easily load data from common databases such as Huggingface, Activeloop Deeplake, Hub and Torchvision with one line of code. And you get to enjoy all of Squirrel’s benefits for free! Check out the [documentation](https://squirrel-datasets-core.readthedocs.io/en/latest/driver_integration.html) on how to interface with these libraries. ```python -from squirrel_datasets_core.driver import HuggingfaceDriver +from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver it = HuggingfaceDriver("cifar100").get_iter("train").filter(custom_filter).map(custom_augmentation) diff --git a/docs/code/all_drivers.py b/docs/code/all_drivers.py index 0ea881a..3d42a0a 100644 --- a/docs/code/all_drivers.py +++ b/docs/code/all_drivers.py @@ -1,17 +1,7 @@ -from squirrel_datasets_core.driver import ( - HuggingfaceDriver, - DeeplakeDriver, - HubDriver, - TorchvisionDriver, -) - -HuggingfaceDriver("cifar100").get_iter("train").take(1).map(print).join() -# prints -# { -# "img": , -# "fine_label": 19, -# "coarse_label": 11, -# } +from squirrel_datasets_core.driver.deeplake import DeeplakeDriver +from squirrel_datasets_core.driver.hub import HubDriver +from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver +from squirrel_datasets_core.driver.torchvision import TorchvisionDriver DeeplakeDriver("hub://activeloop/cifar100-train").get_iter().take(1).map(print).join() # prints @@ -28,6 +18,14 @@ # "coarse_labels": Tensor(key="coarse_labels", index=Index([0])), # } +HuggingfaceDriver("cifar100").get_iter("train").take(1).map(print).join() +# prints +# { +# "img": , +# "fine_label": 19, +# "coarse_label": 11, +# } + TorchvisionDriver("cifar100", download=True).get_iter().take(1).map(print).join() # prints # ( diff --git a/docs/code/real_scenario.py b/docs/code/real_scenario.py index 169f35b..e6152ff 100644 --- a/docs/code/real_scenario.py +++ b/docs/code/real_scenario.py @@ -2,9 +2,9 @@ import torch import torchvision.transforms.functional as F -from squirrel_datasets_core.driver import HuggingfaceDriver +from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate as torch_default_collate +from torch.utils.data._utils.collate import default_collate BATCH_SIZE = 16 @@ -31,7 +31,7 @@ def hf_pre_proc(item: t.Dict[str, t.Any]) -> t.Dict[str, torch.Tensor]: .get_iter("train") .split_by_worker_pytorch() .map(hf_pre_proc) - .batched(BATCH_SIZE, torch_default_collate) + .batched(BATCH_SIZE, default_collate) .to_torch_iterable() ) train_loader = DataLoader(train_driver, batch_size=None, num_workers=2) diff --git a/src/squirrel_datasets_core/__init__.py b/src/squirrel_datasets_core/__init__.py index 3cb7d95..d3ec452 100644 --- a/src/squirrel_datasets_core/__init__.py +++ b/src/squirrel_datasets_core/__init__.py @@ -1 +1 @@ -__version__ = "0.1.13" +__version__ = "0.2.0" diff --git a/src/squirrel_datasets_core/driver/__init__.py b/src/squirrel_datasets_core/driver/__init__.py index bc045cf..e69de29 100644 --- a/src/squirrel_datasets_core/driver/__init__.py +++ b/src/squirrel_datasets_core/driver/__init__.py @@ -1,6 +0,0 @@ -from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver -from squirrel_datasets_core.driver.torchvision import TorchvisionDriver -from squirrel_datasets_core.driver.hub import HubDriver -from squirrel_datasets_core.driver.deeplake import DeeplakeDriver - -__all__ = ["HuggingfaceDriver", "TorchvisionDriver", "HubDriver", "DeeplakeDriver"] diff --git a/src/squirrel_datasets_core/squirrel_plugin.py b/src/squirrel_datasets_core/squirrel_plugin.py index 4bca4c4..0b46dc7 100644 --- a/src/squirrel_datasets_core/squirrel_plugin.py +++ b/src/squirrel_datasets_core/squirrel_plugin.py @@ -1,3 +1,4 @@ +import contextlib import importlib import pkgutil from typing import List, Tuple, Type @@ -14,6 +15,13 @@ def get_hub_driver() -> Driver: return HubDriver +def get_deeplake_driver() -> Driver: + """Imports and returns the deeplake driver class""" + from squirrel_datasets_core.driver.deeplake import DeeplakeDriver + + return DeeplakeDriver + + def get_huggingface_driver() -> Driver: """Imports and returns the huggingface driver class""" from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver @@ -36,6 +44,7 @@ def squirrel_drivers() -> List[Type[Driver]]: drivers = [] add_drivers = { "hub": get_hub_driver, + "deeplake": get_deeplake_driver, "huggingface": get_huggingface_driver, "torchvision": get_torchvision_driver, } @@ -70,10 +79,7 @@ def squirrel_sources() -> List[Tuple[CatalogKey, Source]]: import squirrel_datasets_core.datasets as ds for m in pkgutil.iter_modules(ds.__path__): - try: + with contextlib.suppress(AttributeError): d = importlib.import_module(f"{ds.__package__}.{m.name}").SOURCES datasets += d - except AttributeError: - pass - return datasets