Skip to content

Commit

Permalink
Update torchvision import in dataset wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Hananel-Hazan committed Mar 29, 2024
1 parent ff705ff commit 4f22df2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bindsnet/datasets/torchvision_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Optional

import torch
import torchvision
from torchvision import datasets as torchDB

from bindsnet.encoding import Encoder, NullEncoder

Expand All @@ -13,7 +13,7 @@ def create_torchvision_dataset_wrapper(ds_type):
``__getitem__``. This applies to all of the datasets inside of ``torchvision``.
"""
if type(ds_type) == str:
ds_type = getattr(torchvision.datasets, ds_type)
ds_type = getattr(torchDB, ds_type)

class TorchvisionDatasetWrapper(ds_type):
__doc__ = (
Expand Down

0 comments on commit 4f22df2

Please sign in to comment.