Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified dataset to load from a local dir #5

Open
robmarkcole opened this issue May 25, 2021 · 0 comments
Open

Modified dataset to load from a local dir #5

robmarkcole opened this issue May 25, 2021 · 0 comments

Comments

@robmarkcole
Copy link

Hi @isaaccorley
I modified the dataset to load from a local directory (in my case, a mounted Google Drive). This might be worth adding to the wiki, if not to the codebase.

class BaseDataset(torch.utils.data.Dataset):
    """Base Super Resolution Dataset Class.

    Loads images from a local directory and returns ``(lr, hr)`` pairs
    produced by applying ``lr_transform`` / ``hr_transform`` to the same
    source image.

    NOTE(review): this class does not set ``file_names``, ``image_size`` or
    ``scale_factor`` itself. ``__getitem__``/``__len__`` read ``file_names``,
    and ``get_lr_transforms``/``get_hr_transforms`` read ``image_size`` and
    ``scale_factor``. Presumably a subclass or constructor sets them —
    confirm against the caller.
    """
    # PIL mode every image is converted to when loaded.
    color_space: str = "RGB"
    # Optional transforms. When falsy, __getitem__ returns the PIL image as-is.
    lr_transform: T.Compose = None
    hr_transform: T.Compose = None

    def get_lr_transforms(self):
        """Return the HR -> LR image transformation pipeline.

        The image is bicubically resized to a square of side
        ``image_size // scale_factor`` and then converted to a tensor.
        """
        return Compose([
            Resize(size=(
                    self.image_size//self.scale_factor,
                    self.image_size//self.scale_factor
                ),
                interpolation=Image.BICUBIC
            ),
            ToTensor(),
        ])

    def get_hr_transforms(self):
        """Return the HR image transformation pipeline.

        The image is bicubically resized to an ``image_size`` square and
        then converted to a tensor.
        """
        return Compose([
            Resize((self.image_size, self.image_size), Image.BICUBIC),
            ToTensor(),
        ])

    def get_files(self, data_dir: str) -> List[str]:
        """Return a sorted list of ``.jpg`` image files in a directory.

        Parameters
        ----------
        data_dir : str
            Path to the directory of images. A trailing path separator is
            optional.

        Returns
        -------
        List[str]
            Sorted paths of the ``*.jpg`` files directly inside `data_dir`.
        """
        import glob
        import os

        # os.path.join supplies the separator that plain concatenation
        # dropped when `data_dir` had no trailing slash ("dir*.jpg").
        # glob.escape stops characters such as "[" in the directory name
        # from being treated as wildcards.
        pattern = os.path.join(glob.escape(data_dir), "*.jpg")
        # glob returns files in arbitrary filesystem order. Sorting gives a
        # deterministic index -> file mapping across runs and machines.
        return sorted(glob.glob(pattern))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a tuple of LR and HR samples for one image.

        Both samples are derived from the same loaded image.

        Parameters
        ----------
        idx : int
            Index into ``self.file_names``.

        Returns
        -------
        lr: torch.Tensor
            Low-resolution transformed image (the untransformed PIL image if
            ``lr_transform`` is not set).
        hr: torch.Tensor
            High-resolution transformed image (the untransformed PIL image if
            ``hr_transform`` is not set).
        """
        lr = self.load_img(self.file_names[idx])
        # Independent copy so the two transforms never share a PIL object.
        hr = lr.copy()
        if self.lr_transform:
            lr = self.lr_transform(lr)
        if self.hr_transform:
            hr = self.hr_transform(hr)

        return lr, hr

    def __len__(self) -> int:
        """Return the number of images in the dataset.

        Returns
        -------
        int
            Length of the ``file_names`` list.
        """
        return len(self.file_names)

    def load_img(self, file_path: str) -> Image.Image:
        """Return a PIL Image of the image located at `file_path`.

        The file is closed before returning.

        Parameters
        ----------
        file_path : str
            Path to the image file to load.

        Returns
        -------
        PIL.Image.Image
            Loaded image, converted to ``self.color_space``.
        """
        # Image.open is lazy and keeps the file open until the image is
        # closed. The context manager releases the handle, and convert()
        # returns a new, fully loaded image that stays valid afterwards.
        with Image.open(file_path) as img:
            return img.convert(self.color_space)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant