Skip to content

Commit

Permalink
Fix Driver imports (#74)
Browse files Browse the repository at this point in the history
* Fix driver imports

* Lint

* Version bump
  • Loading branch information
axkoenig committed Dec 19, 2022
1 parent 7dd7337 commit 875274b
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pip install "squirrel-datasets-core[all]"

A great feature of squirrel-datasets-core is that you can easily load data from common databases such as Huggingface, Activeloop Deeplake, Hub and Torchvision with one line of code. And you get to enjoy all of Squirrel’s benefits for free! Check out the [documentation](https://squirrel-datasets-core.readthedocs.io/en/latest/driver_integration.html) on how to interface with these libraries.
```python
from squirrel_datasets_core.driver import HuggingfaceDriver
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver

it = HuggingfaceDriver("cifar100").get_iter("train").filter(custom_filter).map(custom_augmentation)

Expand Down
26 changes: 12 additions & 14 deletions docs/code/all_drivers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
from squirrel_datasets_core.driver import (
HuggingfaceDriver,
DeeplakeDriver,
HubDriver,
TorchvisionDriver,
)

HuggingfaceDriver("cifar100").get_iter("train").take(1).map(print).join()
# prints
# {
# "img": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x1424D6310>,
# "fine_label": 19,
# "coarse_label": 11,
# }
from squirrel_datasets_core.driver.deeplake import DeeplakeDriver
from squirrel_datasets_core.driver.hub import HubDriver
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
from squirrel_datasets_core.driver.torchvision import TorchvisionDriver

DeeplakeDriver("hub://activeloop/cifar100-train").get_iter().take(1).map(print).join()
# prints
Expand All @@ -28,6 +18,14 @@
# "coarse_labels": Tensor(key="coarse_labels", index=Index([0])),
# }

HuggingfaceDriver("cifar100").get_iter("train").take(1).map(print).join()
# prints
# {
# "img": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x1424D6310>,
# "fine_label": 19,
# "coarse_label": 11,
# }

TorchvisionDriver("cifar100", download=True).get_iter().take(1).map(print).join()
# prints
# (
Expand Down
6 changes: 3 additions & 3 deletions docs/code/real_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch
import torchvision.transforms.functional as F
from squirrel_datasets_core.driver import HuggingfaceDriver
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate as torch_default_collate
from torch.utils.data._utils.collate import default_collate

BATCH_SIZE = 16

Expand All @@ -31,7 +31,7 @@ def hf_pre_proc(item: t.Dict[str, t.Any]) -> t.Dict[str, torch.Tensor]:
.get_iter("train")
.split_by_worker_pytorch()
.map(hf_pre_proc)
.batched(BATCH_SIZE, torch_default_collate)
.batched(BATCH_SIZE, default_collate)
.to_torch_iterable()
)
train_loader = DataLoader(train_driver, batch_size=None, num_workers=2)
Expand Down
2 changes: 1 addition & 1 deletion src/squirrel_datasets_core/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.13"
__version__ = "0.2.0"
6 changes: 0 additions & 6 deletions src/squirrel_datasets_core/driver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
from squirrel_datasets_core.driver.torchvision import TorchvisionDriver
from squirrel_datasets_core.driver.hub import HubDriver
from squirrel_datasets_core.driver.deeplake import DeeplakeDriver

__all__ = ["HuggingfaceDriver", "TorchvisionDriver", "HubDriver", "DeeplakeDriver"]
14 changes: 10 additions & 4 deletions src/squirrel_datasets_core/squirrel_plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import importlib
import pkgutil
from typing import List, Tuple, Type
Expand All @@ -14,6 +15,13 @@ def get_hub_driver() -> Driver:
return HubDriver


def get_deeplake_driver() -> Driver:
"""Imports and returns the deeplake driver class"""
from squirrel_datasets_core.driver.deeplake import DeeplakeDriver

return DeeplakeDriver


def get_huggingface_driver() -> Driver:
"""Imports and returns the huggingface driver class"""
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
Expand All @@ -36,6 +44,7 @@ def squirrel_drivers() -> List[Type[Driver]]:
drivers = []
add_drivers = {
"hub": get_hub_driver,
"deeplake": get_deeplake_driver,
"huggingface": get_huggingface_driver,
"torchvision": get_torchvision_driver,
}
Expand Down Expand Up @@ -70,10 +79,7 @@ def squirrel_sources() -> List[Tuple[CatalogKey, Source]]:
import squirrel_datasets_core.datasets as ds

for m in pkgutil.iter_modules(ds.__path__):
try:
with contextlib.suppress(AttributeError):
d = importlib.import_module(f"{ds.__package__}.{m.name}").SOURCES
datasets += d
except AttributeError:
pass

return datasets

0 comments on commit 875274b

Please sign in to comment.