
Using a streaming dataloader with an unbalanced dataset yields unexpected batch sizes. #199

Open
esivonxay-cognitiv opened this issue Jun 29, 2024 · 7 comments
Labels
bug (Something isn't working) · help wanted (Extra attention is needed)

Comments

@esivonxay-cognitiv
Contributor

🐛 Bug

I have two datasets which are unbalanced, where one dataset is 1000x larger than the other. I would like to sample from the two datasets such that the ratio of samples from each is 1:100. When doing so, batches of irregular size are returned during iteration.

I think there are 2 issues which this test surfaces:

  1. The first batch returned by each worker is not properly sized.
  2. drop_last does not appear to work as intended, since the last batch is not a full-sized batch.

I don't think this is related to #179, but it's possible.

I've been attempting to fix this, but I'm not sure what the root of the issue is. I would be very appreciative if you could fix this or point me in the right direction.

Thanks!

To Reproduce

import os
import sys

import pytest

# Imports assumed from litdata's module layout at the time of this issue.
from litdata.streaming.cache import Cache
from litdata.streaming.combined import CombinedStreamingDataset
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.resolver import Dir


@pytest.mark.skipif(sys.platform == "win32", reason="too slow in CI")
def test_unbalanced_combined_dataset_with_dataloader(tmpdir):
    data_dir_1 = os.path.join(tmpdir, "data_1")
    data_dir_2 = os.path.join(tmpdir, "data_2")
    cache_dir_1 = os.path.join(tmpdir, "cache_dir_1")
    cache_dir_2 = os.path.join(tmpdir, "cache_dir_2")

    os.makedirs(data_dir_1)
    os.makedirs(data_dir_2)
    os.makedirs(cache_dir_1)
    os.makedirs(cache_dir_2)

    cache = Cache(input_dir=str(data_dir_1), chunk_size=2)

    for i in range(10):
        cache[i] = i

    cache.done()
    cache.merge()

    cache = Cache(input_dir=str(data_dir_2), chunk_size=2)

    for i in range(10000):
        cache[i] = i + 10

    cache.done()
    cache.merge()

    dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True)
    dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True)
    dataset = CombinedStreamingDataset(
        datasets=[dataset1, dataset2], weights=[0.01, 0.99], iterate_over_all=False, seed=12345
    )
    dataloader = StreamingDataLoader(
        dataset,
        num_workers=3,
        batch_size=100,
        drop_last=True,
        persistent_workers=True,
        shuffle=True,
        prefetch_factor=2,
    )

    assert dataset1.current_epoch == 1
    assert dataset2.current_epoch == 1

    batches_1 = []
    batch_sizes_1 = []
    for batch in dataloader:
        batch_sizes_1.append(batch.size(0))
        batches_1.append(batch)

    assert batch_sizes_1[2] == 91
    assert batch_sizes_1[-1] == 40
    # This assertion fails because the third and last batches are not of size 100 (the two assertions above pass).
    assert batch_sizes_1 == [100 for _ in batches_1]

Expected behavior

All batch sizes should be the same.

Additional context

This issue is independent of whether drop_last, shuffle, and persistent_workers are set to True or False.
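For what it's worth, a hedged sketch of how that claim could be checked by parametrizing the repro over those flags; `_build_combined_dataloader` is a hypothetical helper wrapping the cache/dataset setup shown above, not an existing function:

import pytest

# Hypothetical parametrization of the repro above over the flags mentioned.
# `_build_combined_dataloader` is an assumed helper, not part of litdata.
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("persistent_workers", [True, False])
def test_batch_sizes_are_uniform(tmpdir, drop_last, shuffle, persistent_workers):
    dataloader = _build_combined_dataloader(
        tmpdir, drop_last=drop_last, shuffle=shuffle, persistent_workers=persistent_workers
    )
    batch_sizes = [batch.size(0) for batch in dataloader]
    # Only the final batch may legitimately be short when drop_last=False.
    checked = batch_sizes if drop_last else batch_sizes[:-1]
    assert all(size == 100 for size in checked)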

@esivonxay-cognitiv added the bug and help wanted labels on Jun 29, 2024
@tchaton
Collaborator

tchaton commented Jun 30, 2024

Hey @esivonxay-cognitiv, thanks for the reproducible script. I will look into it.

@esivonxay-cognitiv
Contributor Author

esivonxay-cognitiv commented Jun 30, 2024

Thanks Thomas!

@tchaton
Copy link
Collaborator

tchaton commented Jun 30, 2024

Hey @esivonxay-cognitiv, I am curious: what's your interest in and usage of LitData?

@esivonxay-cognitiv
Contributor Author

Yeah, I'm interested in LitData primarily for the ability to sample from multiple streams. I've got two datasets which are quite imbalanced (one is 100,000x larger than the other), and I'm trying to downsample one dataset to reduce the imbalance by a couple of orders of magnitude.

Naively, I could do this when constructing the dataset by throwing out datapoints. However, doing so would mean throwing out 90% or 99% of the data (to decrease the imbalance by 10x or 100x, respectively), and important samples might be thrown out in the process.

My thought was to do this downsampling/rebalancing during dataloading so the model at least has a chance to see each sample, just at a lower rate.
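For concreteness, a minimal sketch of that weighting idea, using the same CombinedStreamingDataset API as the repro above; the helper name and the 1:100 ratio are illustrative only:

# Illustrative helper (not litdata API): turn a desired per-sample ratio
# between two streams into CombinedStreamingDataset weights.
from litdata.streaming.combined import CombinedStreamingDataset


def combine_with_ratio(small_dataset, large_dataset, ratio=(1, 100), seed=12345):
    total = sum(ratio)
    weights = [r / total for r in ratio]  # e.g. (1, 100) -> [~0.0099, ~0.9901]
    return CombinedStreamingDataset(
        datasets=[small_dataset, large_dataset],
        weights=weights,
        iterate_over_all=False,
        seed=seed,
    )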

@jackcyc

jackcyc commented Jul 10, 2024

I recently encountered a similar issue while training a model with a batch normalization layer. Since batch normalization requires a batch size greater than 1 during training, the training process fails if a batch size of 1 is produced.

There may be a potential solution discussed here, where using drop_last in the DataLoader causes PyTorch to automatically skip incomplete batches.

However, drop_last is not exposed by StreamingDataLoader, and it's not clear whether this omission is intentional:

super().__init__(
    dataset,
    *args,
    batch_size=batch_size,
    num_workers=num_workers,
    prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor,
    collate_fn=collate_fn,
    **kwargs,
)  # type: ignore
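Until that's sorted out, one consumer-side workaround (a plain-Python sketch, not litdata functionality) is to skip undersized batches when iterating, so batch-norm layers never see a batch of size 1:

# Sketch of a consumer-side workaround: drop any batch smaller than the
# configured batch size before it reaches the model.
def full_batches_only(dataloader, batch_size):
    for batch in dataloader:
        if batch.size(0) < batch_size:
            continue  # skip the undersized batches produced by the combined dataset
        yield batch


# Usage with the dataloader from the repro above:
# for batch in full_batches_only(dataloader, batch_size=100):
#     train_step(batch)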

@tchaton
Collaborator

tchaton commented Jul 11, 2024

Hey @jackcyc @esivonxay-cognitiv,

Would any of you be willing to attempt a fix? The CombinedDataset isn't well thought out, IMO, and needs to be improved. It was designed for large-scale training runs where only a few epochs are made, so your use case is kind of an edge case.

I think we should re-write it using PyTorch Lightning for inspiration: https://github.com/Lightning-AI/pytorch-lightning/blob/50af052b3129164e28efa8b9321d733311b7b459/src/lightning/pytorch/utilities/combined_loader.py#L222
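For anyone picking this up, here is a rough sketch (plain Python, not the litdata or Lightning API; all names are illustrative) of the item-level idea: interleave individual samples according to the weights so the downstream DataLoader can always assemble full batches.

import random
from typing import Iterator, List, Sequence


# Illustrative sketch only: yield individual samples drawn from several
# iterators according to `weights`, stopping when any iterator is exhausted
# (mirroring iterate_over_all=False). Because items rather than batches are
# interleaved, a downstream DataLoader can always build full-size batches.
def weighted_item_iterator(iterators: List[Iterator], weights: Sequence[float], seed: int = 12345):
    rng = random.Random(seed)
    indices = list(range(len(iterators)))
    while True:
        (choice,) = rng.choices(indices, weights=weights, k=1)
        try:
            yield next(iterators[choice])
        except StopIteration:
            return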

@esivonxay-cognitiv
Contributor Author

Hey Thomas, thanks for the follow-up.

I haven't looked at the PyTorch Lightning implementation exhaustively, but thanks for bringing it to my attention. I don't have the bandwidth for this right now, but I'll put it on my todo list and revisit fixing or rewriting this.
