Skip to content

Commit

Permalink
added shuffle disable option in BucketBatchSampler (allenai#5212)
Browse files Browse the repository at this point in the history
* added shuffle disable option in BucketBatchSampler

* Update allennlp/data/samplers/bucket_batch_sampler.py

Co-authored-by: Pete <[email protected]>

Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Pete <[email protected]>
  • Loading branch information
3 people authored and Abhishek-P committed Aug 11, 2021
1 parent 9de5b4e commit 5660670
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths.
- Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.
- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.
- Added a `shuffle` argument to `BucketBatchSampler`, which allows shuffling to be disabled.

### Fixed

Expand Down
11 changes: 10 additions & 1 deletion allennlp/data/samplers/bucket_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class BucketBatchSampler(BatchSampler):
If `True`, the sampler will drop the last batch if
its size would be less than `batch_size`.
shuffle : `bool`, (default = `True`)
If `False`, the sampler won't shuffle the batches. `padding_noise` will be ignored and set
to `0.0`.
"""

def __init__(
Expand All @@ -65,11 +69,15 @@ def __init__(
sorting_keys: List[str] = None,
padding_noise: float = 0.1,
drop_last: bool = False,
shuffle: bool = True,
):
self.sorting_keys = sorting_keys
self.padding_noise = padding_noise
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle = shuffle
if not shuffle:
self.padding_noise = 0.0

def _argsort_by_padding(
self, instances: Iterable[Instance]
Expand Down Expand Up @@ -113,7 +121,8 @@ def get_batch_indices(self, instances: Sequence[Instance]) -> Iterable[List[int]
if self.drop_last and len(batch_indices) < self.batch_size:
continue
batches.append(batch_indices)
random.shuffle(batches)
if self.shuffle:
random.shuffle(batches)
for batch in batches:
yield batch

Expand Down
14 changes: 14 additions & 0 deletions tests/data/samplers/bucket_batch_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ def test_create_batches_groups_correctly(self):
expected_groups.remove(group)
assert expected_groups == []

def test_disable_shuffle(self):
    """Check that `shuffle=False` yields batches in a fixed order.

    Builds a `BucketBatchSampler` with `batch_size=2`, `sorting_keys=["text"]`
    and shuffling disabled, then verifies that the batches produced for
    `self.instances` match the expected groups exactly, in both content
    and order.
    """
    sampler = BucketBatchSampler(batch_size=2, sorting_keys=["text"], shuffle=False)

    # Resolve each batch of indices back to the instances it refers to.
    grouped_instances = [
        [self.instances[idx] for idx in indices]
        for indices in sampler.get_batch_indices(self.instances)
    ]
    expected_groups = [
        [self.instances[4], self.instances[2]],
        [self.instances[0], self.instances[1]],
        [self.instances[3]],
    ]
    # Compare the full lists rather than element-by-element over the produced
    # batches: a pairwise loop would pass vacuously if the sampler yielded
    # fewer batches than expected (including none at all).
    assert grouped_instances == expected_groups

def test_guess_sorting_key_picks_the_longest_key(self):
sampler = BucketBatchSampler(batch_size=2, padding_noise=0)
instances = []
Expand Down

0 comments on commit 5660670

Please sign in to comment.