Added DataCollator for dynamic operations for each batch. #5221

Merged: 56 commits, May 27, 2021

Commits
d691c46
ADD: add from_pretrained method for vocab
wlhgtc Feb 3, 2021
582815c
MOD: test format
wlhgtc Feb 3, 2021
88ebacc
MOD: format file
wlhgtc Feb 3, 2021
5ea07ec
MOD: update changelog
wlhgtc Feb 3, 2021
3f4b9b8
MOD: fix bug
wlhgtc Feb 4, 2021
ea9cbf2
MOD: fix bug
wlhgtc Feb 4, 2021
76bba4d
MOD: fix typo
wlhgtc Feb 4, 2021
fbe1def
MOD: make the mothod in class
wlhgtc Feb 6, 2021
d24a54a
MOD: fix bug
wlhgtc Feb 6, 2021
9dcfc31
MOD: change to instance method
wlhgtc Feb 7, 2021
1a9d257
MOD: fix typo
wlhgtc Feb 7, 2021
3d4be33
MOD: fix bug
wlhgtc Feb 7, 2021
1c897c5
MOD: change oov to avoid bug
wlhgtc Feb 8, 2021
04d9f97
Merge branch 'main' into main
epwalsh Feb 8, 2021
0c153ef
Update allennlp/data/vocabulary.py
epwalsh Feb 8, 2021
334e53c
Update allennlp/data/vocabulary.py
wlhgtc Feb 8, 2021
21cbf38
Update allennlp/data/vocabulary.py
wlhgtc Feb 8, 2021
7a9f2f3
Update allennlp/data/vocabulary.py
wlhgtc Feb 8, 2021
6b35ca9
MOD: fix formate
wlhgtc Feb 8, 2021
c7fec2e
MOD: add test case
wlhgtc Feb 9, 2021
6e35025
Update CHANGELOG.md
epwalsh Feb 9, 2021
0f3cfa2
MOD: align to upstream
wlhgtc Feb 23, 2021
2ee0019
MOD : fix conflict
wlhgtc Feb 23, 2021
8cee9ee
MOD: fix worker info bug
wlhgtc Feb 23, 2021
7b6e70f
ADD: update changelog
wlhgtc Feb 23, 2021
8f17508
MOD: fix format
wlhgtc Feb 23, 2021
a476490
Update allennlp/data/data_loaders/multitask_data_loader.py
wlhgtc Feb 24, 2021
617564d
Update CHANGELOG.md
wlhgtc Feb 24, 2021
0afcc4b
Merge branch 'main' into main
epwalsh Feb 24, 2021
127443e
Merge branch 'main' into main
dirkgr Feb 24, 2021
a3d47cd
MOD: fix conflict
wlhgtc May 25, 2021
ffd2308
MOD: fix conflcit
wlhgtc May 25, 2021
18c35b6
MOD: add demo code
wlhgtc May 25, 2021
cd4bf9f
MOD: align code
wlhgtc May 25, 2021
66659c8
MOD: fix bug
wlhgtc May 25, 2021
6325327
MOD: fix bug
wlhgtc May 25, 2021
86a8e24
MOD: fix bug
wlhgtc May 25, 2021
127b80d
MOD: formate code
wlhgtc May 25, 2021
251546d
Merge branch 'main' into main
dirkgr May 26, 2021
d4b17f6
Update allennlp/data/data_loaders/data_collator.py
wlhgtc May 26, 2021
781d0b2
fix error
wlhgtc May 26, 2021
68f25f0
MOD: add test code
wlhgtc May 26, 2021
412fc06
mod: change tokenizer
wlhgtc May 26, 2021
d6cde24
mod: fix tokenizer
wlhgtc May 26, 2021
c0c5fc7
MOD: fix bug
wlhgtc May 26, 2021
7fe1435
MOD: fix bug
wlhgtc May 26, 2021
155d8dc
MOD: fix bug
wlhgtc May 26, 2021
c126742
Merge branch 'main' into main
dirkgr May 26, 2021
5641ee5
Merge branch 'main' into main
dirkgr May 26, 2021
56bdcdb
Update allennlp/data/data_loaders/data_collator.py
wlhgtc May 27, 2021
119ba13
MOD: update changelog
wlhgtc May 27, 2021
5089111
MOD: update change log
wlhgtc May 27, 2021
d83e5c3
Merge branch 'main' into main
dirkgr May 27, 2021
36522c0
Merge branch 'main' into main
dirkgr May 27, 2021
dbd9e08
Update allennlp/data/data_loaders/data_collator.py
dirkgr May 27, 2021
b73bc61
Formatting
dirkgr May 27, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased


### Changed

- Use `dist_reduce_sum` in distributed metrics.
@@ -33,6 +34,7 @@
- Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.
- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.
- Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling.
+- Added `DataCollator` for dynamic operations for each batch.

### Fixed

3 changes: 2 additions & 1 deletion allennlp/data/data_loaders/__init__.py
@@ -1,4 +1,5 @@
-from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate
+from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader, WorkerError
from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader
from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader
+from allennlp.data.data_loaders.data_collator import allennlp_collate
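
Since `allennlp_collate` is re-exported here from its new home in `data_collator.py`, existing imports from the package root keep working; a quick check of that assumption:

# Both names should refer to the same function after this change; the second
# import path is the function's new defining module.
from allennlp.data.data_loaders import allennlp_collate
from allennlp.data.data_loaders.data_collator import allennlp_collate as _impl

assert allennlp_collate is _impl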
71 changes: 71 additions & 0 deletions allennlp/data/data_loaders/data_collator.py
@@ -0,0 +1,71 @@
from typing import List

from transformers.data.data_collator import DataCollatorForLanguageModeling
from allennlp.common import Registrable
from allennlp.data.batch import Batch
from allennlp.data.data_loaders.data_loader import TensorDict
from allennlp.data.instance import Instance


def allennlp_collate(instances: List[Instance]) -> TensorDict:
    """
    This is the default function used to turn a list of `Instance`s into a `TensorDict`
    batch.
    """
    batch = Batch(instances)
    return batch.as_tensor_dict()


class DataCollator(Registrable):
    """
    This class is similar to `DataCollator` in [Transformers]
    (https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py).
    It allows dynamic operations on each batch's tensors. Because it runs every time a
    `List[Instance]` is converted to a `TensorDict`, those operations can differ from
    epoch to epoch.
    """

    default_implementation = "allennlp"

    def __call__(self, instances: List[Instance]) -> TensorDict:
        raise NotImplementedError


@DataCollator.register("allennlp")
class DefaultDataCollator(DataCollator):
def __call__(self, instances: List[Instance]) -> TensorDict:
return allennlp_collate(instances)


@DataCollator.register("language_model")
class LanguageModelingDataCollator(DataCollator):
"""
Register as an `DataCollator` with name `LanguageModelingDataCollator`
Used for language modeling.
"""

def __init__(
self,
model_name: str,
mlm: bool = True,
mlm_probability: float = 0.15,
filed_name: str = "source",
namespace: str = "tokens",
):
self._field_name = filed_name
self._namespace = namespace
from allennlp.common import cached_transformers

tokenizer = cached_transformers.get_tokenizer(model_name)
self._collator = DataCollatorForLanguageModeling(tokenizer, mlm, mlm_probability)

def __call__(self, instances: List[Instance]) -> TensorDict:
tensor_dicts = allennlp_collate(instances)
tensor_dicts = self.process_tokens(tensor_dicts)
return tensor_dicts

def process_tokens(self, tensor_dicts: TensorDict) -> TensorDict:
inputs = tensor_dicts[self._field_name][self._namespace]["token_ids"]
inputs, labels = self._collator.mask_tokens(inputs)
tensor_dicts[self._field_name][self._namespace]["token_ids"] = inputs
tensor_dicts[self._field_name][self._namespace]["labels"] = labels
return tensor_dicts
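
To show the extension point this file introduces, here is a minimal sketch of a custom collator. `WordDropoutCollator`, its registered name "word_dropout", and its parameters are hypothetical, not part of this PR; it only assumes the `DataCollator` API above and the usual `{field: {namespace: {"token_ids": ...}}}` tensor layout.

from typing import List

import torch

from allennlp.data.data_loaders.data_collator import DataCollator, allennlp_collate
from allennlp.data.data_loaders.data_loader import TensorDict
from allennlp.data.instance import Instance


@DataCollator.register("word_dropout")  # hypothetical collator, for illustration only
class WordDropoutCollator(DataCollator):
    """Zeroes out random token ids in each batch; a fresh mask is drawn every call."""

    def __init__(self, probability: float = 0.1, field_name: str = "source") -> None:
        self._probability = probability
        self._field_name = field_name

    def __call__(self, instances: List[Instance]) -> TensorDict:
        tensor_dict = allennlp_collate(instances)
        token_ids = tensor_dict[self._field_name]["tokens"]["token_ids"]
        # Because collators run on every batch, the same instances get a
        # different dropout mask each epoch.
        drop = torch.rand(token_ids.shape) < self._probability
        tensor_dict[self._field_name]["tokens"]["token_ids"] = token_ids.masked_fill(drop, 0)
        return tensor_dict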
12 changes: 1 addition & 11 deletions allennlp/data/data_loaders/data_loader.py
@@ -1,10 +1,9 @@
-from typing import List, Dict, Union, Iterator
+from typing import Dict, Union, Iterator

import torch

from allennlp.common.registrable import Registrable
from allennlp.data.instance import Instance
-from allennlp.data.batch import Batch
from allennlp.data.vocabulary import Vocabulary


@@ -14,15 +13,6 @@
"""


-def allennlp_collate(instances: List[Instance]) -> TensorDict:
-    """
-    This is the default function used to turn a list of `Instance`s into a `TensorDict`
-    batch.
-    """
-    batch = Batch(instances)
-    return batch.as_tensor_dict()
-
-
class DataLoader(Registrable):
    """
    A `DataLoader` is responsible for generating batches of instances from a
8 changes: 6 additions & 2 deletions allennlp/data/data_loaders/multiprocess_data_loader.py
@@ -12,7 +12,8 @@
from allennlp.common.util import lazy_groups_of, shuffle_iterable
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
-from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate
+from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
+from allennlp.data.data_loaders.data_collator import DataCollator, DefaultDataCollator
from allennlp.data.dataset_readers import DatasetReader, WorkerInfo, DatasetReaderInput
from allennlp.data.fields import TextField
from allennlp.data.samplers import BatchSampler
@@ -124,6 +125,8 @@ class MultiProcessDataLoader(DataLoader):
    quiet : `bool`, optional (default = `False`)
        If `True`, tqdm progress bars will be disabled.

+    collate_fn : `DataCollator`, optional (default = `DefaultDataCollator`)
+
    # Best practices

    - **Large datasets**
@@ -207,6 +210,7 @@ def __init__(
start_method: str = "fork",
cuda_device: Optional[Union[int, str, torch.device]] = None,
quiet: bool = False,
collate_fn: DataCollator = DefaultDataCollator(),
) -> None:
# Do some parameter validation.
if num_workers is not None and num_workers < 0:
@@ -244,7 +248,7 @@ def __init__(
        self.batch_sampler = batch_sampler
        self.batches_per_epoch = batches_per_epoch
        self.num_workers = num_workers
-        self.collate_fn = allennlp_collate
+        self.collate_fn = collate_fn
        self.max_instances_in_memory = max_instances_in_memory
        self.start_method = start_method
        self.quiet = quiet
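
A sketch of how the new argument might be used end to end; the reader and data path are placeholders, and the config-style usage relies on `collate_fn` being a `Registrable` constructor argument:

from allennlp.data.data_loaders import MultiProcessDataLoader
from allennlp.data.data_loaders.data_collator import LanguageModelingDataCollator
from allennlp.data.vocabulary import Vocabulary

# `my_reader` and "/path/to/data" are placeholders for a real DatasetReader and input.
collator = LanguageModelingDataCollator("bert-base-uncased", mlm=True, mlm_probability=0.15)
loader = MultiProcessDataLoader(my_reader, "/path/to/data", batch_size=16, collate_fn=collator)
loader.index_with(Vocabulary.from_instances(loader.iter_instances()))

for batch in loader:
    # Each batch now carries freshly masked "token_ids" and the matching "labels".
    pass

# In a config file, the same collator should be reachable through its registered
# name, e.g. "collate_fn": {"type": "language_model", "model_name": "bert-base-uncased"}.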
6 changes: 4 additions & 2 deletions allennlp/data/data_loaders/simple_data_loader.py
@@ -7,7 +7,8 @@

from allennlp.common.util import lazy_groups_of
from allennlp.common.tqdm import Tqdm
-from allennlp.data.data_loaders.data_loader import DataLoader, allennlp_collate, TensorDict
+from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
+from allennlp.data.data_loaders.data_collator import DefaultDataCollator
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.vocabulary import Vocabulary
@@ -36,6 +37,7 @@ def __init__(
        self.vocab = vocab
        self.cuda_device: Optional[torch.device] = None
        self._batch_generator: Optional[Iterator[TensorDict]] = None
+        self.collate_fn = DefaultDataCollator()

    def __len__(self) -> int:
        if self.batches_per_epoch is not None:
@@ -60,7 +62,7 @@ def _iter_batches(self) -> Iterator[TensorDict]:
        if self.shuffle:
            random.shuffle(self.instances)
        for batch in lazy_groups_of(self.iter_instances(), self.batch_size):
-            tensor_dict = allennlp_collate(batch)
+            tensor_dict = self.collate_fn(batch)
            if self.cuda_device is not None:
                tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device)
            yield tensor_dict
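
`SimpleDataLoader` gets no constructor argument here, but since the collator is stored as a plain attribute, it can be swapped after construction; a minimal sketch, assuming `instances` was built elsewhere:

from allennlp.data.data_loaders import SimpleDataLoader
from allennlp.data.data_loaders.data_collator import LanguageModelingDataCollator

loader = SimpleDataLoader(instances, batch_size=16)  # `instances` assumed built elsewhere
loader.collate_fn = LanguageModelingDataCollator("bert-base-uncased")  # replaces DefaultDataCollator()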
27 changes: 27 additions & 0 deletions tests/data/data_loaders/multiprocess_data_loader_test.py
@@ -10,6 +10,7 @@
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.vocabulary import Vocabulary
+from allennlp.data.data_loaders.data_collator import LanguageModelingDataCollator


class MockDatasetReader(DatasetReader):
@@ -166,6 +167,32 @@ def test_drop_last():
    assert len(batches) == 6


+def test_language_model_data_collator():
+    """
+    Ensure `LanguageModelingDataCollator` works.
+    """
+    norm_loader = MultiProcessDataLoader(MockDatasetReader(), "some path", batch_size=16)
+    vocab = Vocabulary.from_instances(norm_loader.iter_instances())
+    norm_loader.index_with(vocab)
+    batch0 = list(norm_loader)[0]
+
+    model_name = "epwalsh/bert-xsmall-dummy"
+    data_collate = LanguageModelingDataCollator(model_name)
+    mlm_loader = MultiProcessDataLoader(
+        MockDatasetReader(), "some path", batch_size=16, collate_fn=data_collate
+    )
+    vocab = Vocabulary.from_instances(mlm_loader.iter_instances())
+    mlm_loader.index_with(vocab)
+    batch1 = list(mlm_loader)[0]
+
+    norm_inputs = batch0["source"]["tokens"]["token_ids"]
+    mlm_inputs = batch1["source"]["tokens"]["token_ids"]
+    mlm_labels = batch1["source"]["tokens"]["labels"]
+
+    # If we replace the masked mlm inputs with their labels, we should recover the original inputs.
+    assert torch.where(mlm_labels != -100, mlm_labels, mlm_inputs).tolist() == norm_inputs.tolist()


def test_batches_per_epoch():
    loader = MultiProcessDataLoader(
        MockDatasetReader(), "some path", batch_size=4, batches_per_epoch=10
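
The assertion in `test_language_model_data_collator` leans on the Transformers convention that non-masked positions are labelled `-100`; a self-contained illustration with made-up ids:

import torch

# Positions 1 and 3 were "masked": their labels keep the original ids (103 is
# [MASK] in standard BERT vocabularies), every other label is -100.
original   = torch.tensor([101, 7592, 2088, 2003, 102])
mlm_inputs = torch.tensor([101,  103, 2088,  103, 102])
mlm_labels = torch.tensor([-100, 7592, -100, 2003, -100])

# Substituting labels back at masked positions reconstructs the original batch.
restored = torch.where(mlm_labels != -100, mlm_labels, mlm_inputs)
assert restored.tolist() == original.tolist()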