Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor dataloading #955

Merged
merged 4 commits into from
Feb 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 41 additions & 67 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
import warnings
from abc import ABC

import torch.distributed as dist
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
# loading for pyTorch 1.3
from torch.utils.data import IterableDataset
except ImportError:
# loading for pyTorch 1.1
import torch
warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
' please upgrade to 1.2+' % torch.__version__, ImportWarning)
EXIST_ITER_DATASET = False
else:
EXIST_ITER_DATASET = True
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
from apex import amp
Expand Down Expand Up @@ -90,36 +78,19 @@ def call_prepare_data(self, model):
model.prepare_data()

def auto_add_sampler(self, dataloader, train):
# do nothing when user gives a sampler
dl_args = {
'dataset': dataloader.dataset,
'batch_size': dataloader.batch_size,
'shuffle': False,
'num_workers': dataloader.num_workers,
'collate_fn': dataloader.collate_fn,
'pin_memory': dataloader.pin_memory,
'drop_last': dataloader.drop_last,
'timeout': dataloader.timeout,
'worker_init_fn': dataloader.worker_init_fn
}

if train:
if self.use_ddp or self.use_ddp2:
sampler = DistributedSampler(dataloader.dataset)
dl_args['shuffle'] = False
if self.use_ddp or self.use_ddp2 or self.use_tpu:
dl_args = {
'dataset': dataloader.dataset,
'batch_size': dataloader.batch_size,
'shuffle': False,
'num_workers': dataloader.num_workers,
'collate_fn': dataloader.collate_fn,
'pin_memory': dataloader.pin_memory,
'drop_last': dataloader.drop_last,
'timeout': dataloader.timeout,
'worker_init_fn': dataloader.worker_init_fn
}

elif self.use_tpu:
sampler = DistributedSampler(
dataloader.dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal()
)
dl_args['shuffle'] = False
else:
sampler = RandomSampler(dataloader.dataset)

# on not train
else:
if self.use_tpu:
sampler = DistributedSampler(
dataloader.dataset,
Expand All @@ -128,12 +99,16 @@ def auto_add_sampler(self, dataloader, train):
)
dl_args['shuffle'] = False
else:
sampler = SequentialSampler(dataloader.dataset)
if train:
sampler = DistributedSampler(dataloader.dataset)
dl_args['shuffle'] = False
else:
sampler = SequentialSampler(dataloader.dataset)

dl_args['sampler'] = sampler
dl_args['sampler'] = sampler

new_dataloader = DataLoader(**dl_args)
return new_dataloader
dataloader = DataLoader(**dl_args)
return dataloader

def reset_train_dataloader(self, model):
"""
Expand All @@ -148,12 +123,12 @@ def reset_train_dataloader(self, model):
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

# determine number of training batches
if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset):
self._percent_range_check('train_percent_check')

if self.is_infinite_dataloader(self.train_dataloader):
self.num_training_batches = float('inf')
else:
self._percent_range_check('train_percent_check')

# try getting the length
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.train_percent_check)

Expand All @@ -168,27 +143,26 @@ def reset_train_dataloader(self, model):
f"to the number of the training batches ({self.num_training_batches}). "
f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
else:
if self.is_infinite_dataloader(self.train_dataloader):
m = '''
When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
must be an int. An int k specifies checking validation every k training batches.
'''
raise MisconfigurationException(m)

self._percent_range_check('val_check_interval')

self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

# support IterableDataset for train data
self.is_iterable_train_dataloader = (
EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset)
)
if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int):
m = '''
When using an iterableDataset for `train_dataloader`,
`Trainer(val_check_interval)` must be an int.
An int k specifies checking validation every k training batches
'''
raise MisconfigurationException(m)

def is_iterable_dataloader(self, dataloader):
return (
EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset)
)
def is_infinite_dataloader(self, dataloader):
    """Check whether a dataloader must be treated as infinite.

    A DataLoader wrapping an ``IterableDataset`` (or any custom loader
    that does not implement ``__len__``) raises ``TypeError`` from
    ``len()``; such loaders have no known length, so the trainer treats
    them as infinite (no percent-based limits, open-ended progress bar).

    :param dataloader: the dataloader (or any iterable) to inspect
    :return: ``True`` if ``len(dataloader)`` is unavailable, else ``False``
    """
    try:
        # len() raises TypeError when __len__ is not implemented
        len(dataloader)
        return False
    except TypeError:
        return True

def reset_val_dataloader(self, model):
"""
Expand Down
9 changes: 2 additions & 7 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,19 +1114,14 @@ def run_pretrain_routine(self, model: LightningModule):
self.run_evaluation(test_mode=True)
return

# load the dataloaders
self.reset_train_dataloader(ref_model)
self.reset_val_dataloader(ref_model)

# check if we should run validation during training
self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
self.disable_validation = self.disable_validation and not self.fast_dev_run
self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run

# run tiny validation (if validation defined)
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
ref_model.on_train_start()
if not self.disable_validation and self.num_sanity_val_steps > 0:
self.reset_val_dataloader(ref_model)
# init progress bars for validation sanity check
pbar = tqdm(desc='Validation sanity check',
total=self.num_sanity_val_steps * len(self.val_dataloaders),
Expand Down
26 changes: 17 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def is_function_implemented(self, m):
pass

@abstractmethod
def is_iterable_dataloader(self, dataloader):
def is_infinite_dataloader(self, dataloader):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be considered infinite? Normally, IterableDatasets are finite, we just don't know how long they are the first epoch. Another way of doing this would be to set it to infinite (or -1, or whatever placeholder value works best) and keep counting how many steps we did the first epoch. Once we start the second epoch, we can print a bar and all since we know the length will not change

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thinking about it that should totally just be has_len or similar. No problem with the idea of keeping a record of the number of steps - although I would probably opt for that to be done in a separate PR (like when we add IterableDataset support for val and test) - but I'm happy to add that now if desired

# this is just empty shell for code from other class
pass

Expand Down Expand Up @@ -325,6 +325,11 @@ def reset_train_dataloader(self, model):
# this is just empty shell for code from other class
pass

@abstractmethod
def reset_val_dataloader(self, model):
    # Declared here so the training loop can call it; the real
    # implementation lives in the data-loading mixin.
    # this is just empty shell for code from other class
    pass

@abstractmethod
def has_arg(self, f_name, arg_name):
# this is just empty shell for code from other class
Expand All @@ -334,11 +339,17 @@ def train(self):
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)

# get model
model = self.get_model()

# load data
self.reset_train_dataloader(model)
self.reset_val_dataloader(model)

# Train begin callbacks
model.on_train_start()
self.on_train_start()

# get model
model = self.get_model()
try:
# run all epochs
for epoch in range(self.current_epoch, self.max_epochs):
Expand All @@ -347,9 +358,6 @@ def train(self):
and hasattr(self.train_dataloader.sampler, 'set_epoch'):
self.train_dataloader.sampler.set_epoch(epoch)

# get model
model = self.get_model()

# update training progress in trainer and model
model.current_epoch = epoch
self.current_epoch = epoch
Expand All @@ -370,8 +378,8 @@ def train(self):
if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_iterable_dataloader(self.train_dataloader):
# for iterable train loader, the progress bar never ends
elif self.is_infinite_dataloader(self.train_dataloader):
# for infinite train loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches
Expand All @@ -380,7 +388,7 @@ def train(self):
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch + 1}' if not self.is_iterable_dataloader(self.train_dataloader) else ''
desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
Expand Down
1 change: 1 addition & 0 deletions tests/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def _dataloader(self, train):
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True
)

return loader
Expand Down
54 changes: 54 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,60 @@ def test_model_freeze_unfreeze():
model.unfreeze()


def test_inf_train_dataloader(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModel):
        def train_dataloader(self):
            dataloader = self._dataloader(train=True)

            class CustomInfDataLoader:
                # Wraps a finite dataloader, restarting it when it runs
                # out, and caps each epoch at 5 batches. It implements no
                # __len__, so the trainer must treat it as infinite.
                def __init__(self, dataloader):
                    self.dataloader = dataloader
                    self.iter = iter(dataloader)
                    self.count = 0

                def __iter__(self):
                    self.count = 0
                    return self

                def __next__(self):
                    if self.count >= 5:
                        raise StopIteration
                    self.count = self.count + 1
                    try:
                        return next(self.iter)
                    except StopIteration:
                        # underlying loader exhausted: restart it
                        self.iter = iter(self.dataloader)
                        return next(self.iter)

            return CustomInfDataLoader(dataloader)

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    # A float val_check_interval is a fraction of len(train_dataloader),
    # which an infinite loader cannot provide, so fitting must raise.
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_check_interval=0.5
        )
        trainer.fit(model)

    # logger file to get meta
    # An int val_check_interval (check every k batches) is valid even
    # without a known dataloader length.
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=50,
    )
    result = trainer.fit(model)

    # verify training completed
    assert result == 1


def test_multiple_val_dataloader(tmpdir):
"""Verify multiple val_dataloader."""
tutils.reset_seed()
Expand Down