
Fix Driver imports #74

Merged
3 commits merged on Dec 19, 2022

Changes from all commits
2 changes: 1 addition & 1 deletion README.md
@@ -51,7 +51,7 @@ pip install "squirrel-datasets-core[all]"

A great feature of squirrel-datasets-core is that you can easily load data from common databases such as Huggingface, Activeloop Deeplake, Hub and Torchvision with one line of code. And you get to enjoy all of Squirrel’s benefits for free! Check out the [documentation](https://squirrel-datasets-core.readthedocs.io/en/latest/driver_integration.html) on how to interface with these libraries.
```python
from squirrel_datasets_core.driver import HuggingfaceDriver
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver

it = HuggingfaceDriver("cifar100").get_iter("train").filter(custom_filter).map(custom_augmentation)
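
# Hedged sketch (not part of this PR): `custom_filter` and `custom_augmentation`
# above are hypothetical user-defined callables and would need to be defined
# before the `get_iter` line. Using the "img"/"coarse_label" keys that the
# Huggingface cifar100 samples carry, they could look roughly like this:
def custom_filter(item: dict) -> bool:
    return item["coarse_label"] < 10  # keep an arbitrary subset of coarse classes

def custom_augmentation(item: dict) -> dict:
    item["img"] = item["img"].rotate(90)  # any PIL / tensor transform would do
    return item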

26 changes: 12 additions & 14 deletions docs/code/all_drivers.py
@@ -1,17 +1,7 @@
from squirrel_datasets_core.driver import (
HuggingfaceDriver,
DeeplakeDriver,
HubDriver,
TorchvisionDriver,
)

HuggingfaceDriver("cifar100").get_iter("train").take(1).map(print).join()
# prints
# {
# "img": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x1424D6310>,
# "fine_label": 19,
# "coarse_label": 11,
# }
from squirrel_datasets_core.driver.deeplake import DeeplakeDriver
from squirrel_datasets_core.driver.hub import HubDriver
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
from squirrel_datasets_core.driver.torchvision import TorchvisionDriver

DeeplakeDriver("hub://activeloop/cifar100-train").get_iter().take(1).map(print).join()
# prints
@@ -28,6 +18,14 @@
# "coarse_labels": Tensor(key="coarse_labels", index=Index([0])),
# }

HuggingfaceDriver("cifar100").get_iter("train").take(1).map(print).join()
# prints
# {
# "img": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x1424D6310>,
# "fine_label": 19,
# "coarse_label": 11,
# }

TorchvisionDriver("cifar100", download=True).get_iter().take(1).map(print).join()
# prints
# (
6 changes: 3 additions & 3 deletions docs/code/real_scenario.py
@@ -2,9 +2,9 @@

import torch
import torchvision.transforms.functional as F
from squirrel_datasets_core.driver import HuggingfaceDriver
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate as torch_default_collate
from torch.utils.data._utils.collate import default_collate

BATCH_SIZE = 16

@@ -31,7 +31,7 @@ def hf_pre_proc(item: t.Dict[str, t.Any]) -> t.Dict[str, torch.Tensor]:
.get_iter("train")
.split_by_worker_pytorch()
.map(hf_pre_proc)
.batched(BATCH_SIZE, torch_default_collate)
.batched(BATCH_SIZE, default_collate)
.to_torch_iterable()
)
train_loader = DataLoader(train_driver, batch_size=None, num_workers=2)
2 changes: 1 addition & 1 deletion src/squirrel_datasets_core/__init__.py
@@ -1 +1 @@
__version__ = "0.1.13"
__version__ = "0.2.0"
6 changes: 0 additions & 6 deletions src/squirrel_datasets_core/driver/__init__.py
@@ -1,6 +0,0 @@
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
from squirrel_datasets_core.driver.torchvision import TorchvisionDriver
from squirrel_datasets_core.driver.hub import HubDriver
from squirrel_datasets_core.driver.deeplake import DeeplakeDriver

__all__ = ["HuggingfaceDriver", "TorchvisionDriver", "HubDriver", "DeeplakeDriver"]
14 changes: 10 additions & 4 deletions src/squirrel_datasets_core/squirrel_plugin.py
@@ -1,3 +1,4 @@
import contextlib
import importlib
import pkgutil
from typing import List, Tuple, Type
@@ -14,6 +15,13 @@ def get_hub_driver() -> Driver:
return HubDriver


def get_deeplake_driver() -> Driver:
"""Imports and returns the deeplake driver class"""
from squirrel_datasets_core.driver.deeplake import DeeplakeDriver

return DeeplakeDriver


def get_huggingface_driver() -> Driver:
"""Imports and returns the huggingface driver class"""
from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
@@ -36,6 +44,7 @@ def squirrel_drivers() -> List[Type[Driver]]:
drivers = []
add_drivers = {
Contributor:

Can we also wrap this in a try/except block like below and catch any import errors? This would prevent squirrel from crashing when datasets is installed and some requirements are missing.

Contributor Author (@axkoenig), Dec 15, 2022:

Are you referring to line 42?

Contributor:

Ah sorry, I got confused, please ignore this comment :)

"hub": get_hub_driver,
"deeplake": get_deeplake_driver,
"huggingface": get_huggingface_driver,
"torchvision": get_torchvision_driver,
}
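
A rough sketch of the guard suggested in the review comment above (illustrative only, not code from this PR; the function name `collect_drivers` is made up, and the mapping mirrors the `add_drivers` dict visible in the diff):

```python
from typing import Callable, Dict, List

def collect_drivers(add_drivers: Dict[str, Callable[[], type]]) -> List[type]:
    """Call each lazy driver getter; skip drivers whose optional deps are missing."""
    drivers: List[type] = []
    for get_driver in add_drivers.values():
        try:
            drivers.append(get_driver())
        except ImportError:
            # this driver's extra dependencies are not installed; skip it
            pass
    return drivers
```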
@@ -70,10 +79,7 @@ def squirrel_sources() -> List[Tuple[CatalogKey, Source]]:
import squirrel_datasets_core.datasets as ds

for m in pkgutil.iter_modules(ds.__path__):
try:
with contextlib.suppress(AttributeError):
Contributor Author (@axkoenig):

My IDE suggested this change; I think it's much cleaner that way. https://docs.python.org/3/library/contextlib.html

Contributor:

agreed

d = importlib.import_module(f"{ds.__package__}.{m.name}").SOURCES
datasets += d
except AttributeError:
pass

return datasets
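
For reference, the equivalence the author points to above looks roughly like this (minimal standalone sketch, not code from this PR; `_Module` is a made-up stand-in for an imported dataset module without a `SOURCES` attribute):

```python
import contextlib

class _Module:
    pass  # no SOURCES attribute, so the lookup below raises AttributeError

# Verbose form (what the PR removes):
try:
    sources = _Module().SOURCES
except AttributeError:
    pass

# Equivalent, flatter form using contextlib.suppress (what the PR adds):
with contextlib.suppress(AttributeError):
    sources = _Module().SOURCES
```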