Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Added DataCollator for dynamic operations for each batch. (#5221)
Browse files Browse the repository at this point in the history
* ADD: add from_pretrained method for vocab

* MOD: test format

* MOD: format file

* MOD: update changelog

* MOD: fix bug

* MOD: fix bug

* MOD: fix typo

* MOD: make the method in class

* MOD: fix bug

* MOD: change to instance method

* MOD: fix typo

* MOD: fix bug

* MOD: change oov to avoid bug

* Update allennlp/data/vocabulary.py

* Update allennlp/data/vocabulary.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* Update allennlp/data/vocabulary.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* Update allennlp/data/vocabulary.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* MOD: fix format

* MOD: add test case

* Update CHANGELOG.md

* MOD: fix worker info bug

* ADD: update changelog

* MOD: fix format

* Update allennlp/data/data_loaders/multitask_data_loader.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* Update CHANGELOG.md

Co-authored-by: Evan Pete Walsh <[email protected]>

* MOD: add demo code

* MOD: align code

* MOD: fix bug

* MOD: fix bug

* MOD: fix bug

* MOD: format code

* Update allennlp/data/data_loaders/data_collator.py

Co-authored-by: Pete <[email protected]>

* fix error

* MOD: add test code

* mod: change tokenizer

* mod: fix tokenizer

* MOD: fix bug

* MOD: fix bug

* MOD: fix bug

* Update allennlp/data/data_loaders/data_collator.py

Co-authored-by: Dirk Groeneveld <[email protected]>

* MOD: update changelog

* MOD: update change log

* Update allennlp/data/data_loaders/data_collator.py

We should be using underscores for everything.

* Formatting

Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Dirk Groeneveld <[email protected]>
Co-authored-by: Dirk Groeneveld <[email protected]>
  • Loading branch information
4 people authored May 27, 2021
1 parent d97ed40 commit babc450
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased


### Changed

- Use `dist_reduce_sum` in distributed metrics.
Expand All @@ -33,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.
- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.
- Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling.
- Added `DataCollator` for dynamic operations for each batch.

### Fixed

Expand Down
3 changes: 2 additions & 1 deletion allennlp/data/data_loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader, WorkerError
from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader
from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader
from allennlp.data.data_loaders.data_collator import allennlp_collate
71 changes: 71 additions & 0 deletions allennlp/data/data_loaders/data_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import List, Optional

from transformers.data.data_collator import DataCollatorForLanguageModeling

from allennlp.common import Registrable
from allennlp.data.batch import Batch
from allennlp.data.data_loaders.data_loader import TensorDict
from allennlp.data.instance import Instance


def allennlp_collate(instances: List[Instance]) -> TensorDict:
    """
    Default collation: turn a list of `Instance`s into a single
    `TensorDict` batch by batching them and tensorizing the result.
    """
    return Batch(instances).as_tensor_dict()


class DataCollator(Registrable):
    """
    Abstraction for turning a `List[Instance]` into a `TensorDict` batch,
    analogous to `DataCollator` in [Transformers]
    (https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py).

    Because a collator is invoked for every batch, subclasses can perform
    dynamic, per-batch operations (e.g. token masking) during collation.
    """

    # Registered name of the implementation used when none is specified.
    default_implementation = "allennlp"

    def __call__(self, instances: List[Instance]) -> TensorDict:
        # Subclasses must implement the actual collation.
        raise NotImplementedError


@DataCollator.register("allennlp")
class DefaultDataCollator(DataCollator):
    """
    The default `DataCollator`: plain batching with no dynamic per-batch
    processing (equivalent to calling `allennlp_collate`).
    """

    def __call__(self, instances: List[Instance]) -> TensorDict:
        batch = Batch(instances)
        return batch.as_tensor_dict()


@DataCollator.register("language_model")
class LanguageModelingDataCollator(DataCollator):
    """
    A `DataCollator` registered under the name `"language_model"`, used for
    language modeling. Delegates per-batch token masking to HuggingFace's
    `DataCollatorForLanguageModeling`.

    # Parameters

    model_name : `str`
        Name of the pretrained transformer whose tokenizer supplies the
        special-token ids used for masking.
    mlm : `bool`, optional (default = `True`)
        If `True`, use masked language modeling; otherwise causal LM.
    mlm_probability : `float`, optional (default = `0.15`)
        Probability with which each token is selected for masking.
    filed_name : `str`, optional (default = `"source"`)
        Deprecated misspelling of `field_name`, kept for backward
        compatibility. Ignored when `field_name` is provided.
    namespace : `str`, optional (default = `"tokens"`)
        Vocabulary namespace of the text field whose tokens are masked.
    field_name : `Optional[str]`, optional (default = `None`)
        Name of the field in each batch holding the tokens to mask.
        When given, overrides the deprecated `filed_name`.
    """

    def __init__(
        self,
        model_name: str,
        mlm: bool = True,
        mlm_probability: float = 0.15,
        filed_name: str = "source",
        namespace: str = "tokens",
        field_name: Optional[str] = None,
    ):
        # `filed_name` is a historical typo in the public interface; prefer the
        # correctly spelled `field_name` when the caller supplies it.
        self._field_name = field_name if field_name is not None else filed_name
        self._namespace = namespace
        # Local import to avoid a circular dependency at module import time.
        from allennlp.common import cached_transformers

        tokenizer = cached_transformers.get_tokenizer(model_name)
        self._collator = DataCollatorForLanguageModeling(tokenizer, mlm, mlm_probability)

    def __call__(self, instances: List[Instance]) -> TensorDict:
        tensor_dicts = allennlp_collate(instances)
        tensor_dicts = self.process_tokens(tensor_dicts)
        return tensor_dicts

    def process_tokens(self, tensor_dicts: TensorDict) -> TensorDict:
        """
        Mask the batch's `token_ids` in place and attach the MLM `labels`
        tensor produced by the HuggingFace collator.
        """
        inputs = tensor_dicts[self._field_name][self._namespace]["token_ids"]
        # NOTE(review): `mask_tokens` was renamed `torch_mask_tokens` in newer
        # versions of `transformers` — confirm against the pinned version.
        inputs, labels = self._collator.mask_tokens(inputs)
        tensor_dicts[self._field_name][self._namespace]["token_ids"] = inputs
        tensor_dicts[self._field_name][self._namespace]["labels"] = labels
        return tensor_dicts
12 changes: 1 addition & 11 deletions allennlp/data/data_loaders/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import List, Dict, Union, Iterator
from typing import Dict, Union, Iterator

import torch

from allennlp.common.registrable import Registrable
from allennlp.data.instance import Instance
from allennlp.data.batch import Batch
from allennlp.data.vocabulary import Vocabulary


Expand All @@ -14,15 +13,6 @@
"""


def allennlp_collate(instances: List[Instance]) -> TensorDict:
"""
This is the default function used to turn a list of `Instance`s into a `TensorDict`
batch.
"""
batch = Batch(instances)
return batch.as_tensor_dict()


class DataLoader(Registrable):
"""
A `DataLoader` is responsible for generating batches of instances from a
Expand Down
8 changes: 6 additions & 2 deletions allennlp/data/data_loaders/multiprocess_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from allennlp.common.util import lazy_groups_of, shuffle_iterable
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.data.data_loaders.data_collator import DataCollator, DefaultDataCollator
from allennlp.data.dataset_readers import DatasetReader, WorkerInfo, DatasetReaderInput
from allennlp.data.fields import TextField
from allennlp.data.samplers import BatchSampler
Expand Down Expand Up @@ -124,6 +125,8 @@ class MultiProcessDataLoader(DataLoader):
quiet : `bool`, optional (default = `False`)
If `True`, tqdm progress bars will be disabled.
collate_fn : `DataCollator`, optional (default = `DefaultDataCollator`)
# Best practices
- **Large datasets**
Expand Down Expand Up @@ -207,6 +210,7 @@ def __init__(
start_method: str = "fork",
cuda_device: Optional[Union[int, str, torch.device]] = None,
quiet: bool = False,
collate_fn: DataCollator = DefaultDataCollator(),
) -> None:
# Do some parameter validation.
if num_workers is not None and num_workers < 0:
Expand Down Expand Up @@ -244,7 +248,7 @@ def __init__(
self.batch_sampler = batch_sampler
self.batches_per_epoch = batches_per_epoch
self.num_workers = num_workers
self.collate_fn = allennlp_collate
self.collate_fn = collate_fn
self.max_instances_in_memory = max_instances_in_memory
self.start_method = start_method
self.quiet = quiet
Expand Down
6 changes: 4 additions & 2 deletions allennlp/data/data_loaders/simple_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from allennlp.common.util import lazy_groups_of
from allennlp.common.tqdm import Tqdm
from allennlp.data.data_loaders.data_loader import DataLoader, allennlp_collate, TensorDict
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.data.data_loaders.data_collator import DefaultDataCollator
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.vocabulary import Vocabulary
Expand Down Expand Up @@ -36,6 +37,7 @@ def __init__(
self.vocab = vocab
self.cuda_device: Optional[torch.device] = None
self._batch_generator: Optional[Iterator[TensorDict]] = None
self.collate_fn = DefaultDataCollator()

def __len__(self) -> int:
if self.batches_per_epoch is not None:
Expand All @@ -60,7 +62,7 @@ def _iter_batches(self) -> Iterator[TensorDict]:
if self.shuffle:
random.shuffle(self.instances)
for batch in lazy_groups_of(self.iter_instances(), self.batch_size):
tensor_dict = allennlp_collate(batch)
tensor_dict = self.collate_fn(batch)
if self.cuda_device is not None:
tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device)
yield tensor_dict
Expand Down
27 changes: 27 additions & 0 deletions tests/data/data_loaders/multiprocess_data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.data_loaders.data_collator import LanguageModelingDataCollator


class MockDatasetReader(DatasetReader):
Expand Down Expand Up @@ -166,6 +167,32 @@ def test_drop_last():
assert len(batches) == 6


def test_language_model_data_collator():
    """
    Ensure `LanguageModelingDataCollator` masks tokens consistently: replacing
    masked positions with their labels recovers the original, unmasked inputs.
    """
    # Baseline loader with the default collator: no masking applied.
    norm_loader = MultiProcessDataLoader(MockDatasetReader(), "some path", batch_size=16)
    vocab = Vocabulary.from_instances(norm_loader.iter_instances())
    norm_loader.index_with(vocab)
    batch0 = list(norm_loader)[0]

    # Loader with the language-modeling collator: masks tokens and adds labels.
    model_name = "epwalsh/bert-xsmall-dummy"
    data_collate = LanguageModelingDataCollator(model_name)
    mlm_loader = MultiProcessDataLoader(
        MockDatasetReader(), "some path", batch_size=16, collate_fn=data_collate
    )
    vocab = Vocabulary.from_instances(mlm_loader.iter_instances())
    mlm_loader.index_with(vocab)
    batch1 = list(mlm_loader)[0]

    norm_inputs = batch0["source"]["tokens"]["token_ids"]
    mlm_inputs = batch1["source"]["tokens"]["token_ids"]
    mlm_labels = batch1["source"]["tokens"]["labels"]

    # if we replace the mlm inputs with their labels, should be same as origin inputs
    # (-100 marks unmasked positions, per the HuggingFace labels convention)
    assert torch.where(mlm_labels != -100, mlm_labels, mlm_inputs).tolist() == norm_inputs.tolist()


def test_batches_per_epoch():
loader = MultiProcessDataLoader(
MockDatasetReader(), "some path", batch_size=4, batches_per_epoch=10
Expand Down

0 comments on commit babc450

Please sign in to comment.