This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added
DataCollator
for dynamic operations for each batch. (#5221)
* ADD: add from_pretrained method for vocab * MOD: test format * MOD: format file * MOD: update changelog * MOD: fix bug * MOD: fix bug * MOD: fix typo * MOD: make the mothod in class * MOD: fix bug * MOD: change to instance method * MOD: fix typo * MOD: fix bug * MOD: change oov to avoid bug * Update allennlp/data/vocabulary.py * Update allennlp/data/vocabulary.py Co-authored-by: Evan Pete Walsh <[email protected]> * Update allennlp/data/vocabulary.py Co-authored-by: Evan Pete Walsh <[email protected]> * Update allennlp/data/vocabulary.py Co-authored-by: Evan Pete Walsh <[email protected]> * MOD: fix formate * MOD: add test case * Update CHANGELOG.md * MOD: fix worker info bug * ADD: update changelog * MOD: fix format * Update allennlp/data/data_loaders/multitask_data_loader.py Co-authored-by: Evan Pete Walsh <[email protected]> * Update CHANGELOG.md Co-authored-by: Evan Pete Walsh <[email protected]> * MOD: add demo code * MOD: align code * MOD: fix bug * MOD: fix bug * MOD: fix bug * MOD: formate code * Update allennlp/data/data_loaders/data_collator.py Co-authored-by: Pete <[email protected]> * fix error * MOD: add test code * mod: change tokenizer * mod: fix tokenizer * MOD: fix bug * MOD: fix bug * MOD: fix bug * Update allennlp/data/data_loaders/data_collator.py Co-authored-by: Dirk Groeneveld <[email protected]> * MOD: update changelog * MOD: update change log * Update allennlp/data/data_loaders/data_collator.py We should be using underscores for everything. * Formatting Co-authored-by: Evan Pete Walsh <[email protected]> Co-authored-by: Dirk Groeneveld <[email protected]> Co-authored-by: Dirk Groeneveld <[email protected]>
- Loading branch information
1 parent
d97ed40
commit babc450
Showing
7 changed files
with
113 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate | ||
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict | ||
from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader, WorkerError | ||
from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader | ||
from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader | ||
from allennlp.data.data_loaders.data_collator import allennlp_collate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from typing import List | ||
|
||
from transformers.data.data_collator import DataCollatorForLanguageModeling | ||
from allennlp.common import Registrable | ||
from allennlp.data.batch import Batch | ||
from allennlp.data.data_loaders.data_loader import TensorDict | ||
from allennlp.data.instance import Instance | ||
|
||
|
||
def allennlp_collate(instances: List[Instance]) -> TensorDict: | ||
""" | ||
This is the default function used to turn a list of `Instance`s into a `TensorDict` | ||
batch. | ||
""" | ||
batch = Batch(instances) | ||
return batch.as_tensor_dict() | ||
|
||
|
||
class DataCollator(Registrable): | ||
""" | ||
This class is similar with `DataCollator` in [Transformers] | ||
(https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py) | ||
Allow to do some dynamic operations for tensor in different batches | ||
Cause this method run before each epoch to convert `List[Instance]` to `TensorDict` | ||
""" | ||
|
||
default_implementation = "allennlp" | ||
|
||
def __call__(self, instances: List[Instance]) -> TensorDict: | ||
raise NotImplementedError | ||
|
||
|
||
@DataCollator.register("allennlp") | ||
class DefaultDataCollator(DataCollator): | ||
def __call__(self, instances: List[Instance]) -> TensorDict: | ||
return allennlp_collate(instances) | ||
|
||
|
||
@DataCollator.register("language_model") | ||
class LanguageModelingDataCollator(DataCollator): | ||
""" | ||
Register as an `DataCollator` with name `LanguageModelingDataCollator` | ||
Used for language modeling. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
mlm: bool = True, | ||
mlm_probability: float = 0.15, | ||
filed_name: str = "source", | ||
namespace: str = "tokens", | ||
): | ||
self._field_name = filed_name | ||
self._namespace = namespace | ||
from allennlp.common import cached_transformers | ||
|
||
tokenizer = cached_transformers.get_tokenizer(model_name) | ||
self._collator = DataCollatorForLanguageModeling(tokenizer, mlm, mlm_probability) | ||
|
||
def __call__(self, instances: List[Instance]) -> TensorDict: | ||
tensor_dicts = allennlp_collate(instances) | ||
tensor_dicts = self.process_tokens(tensor_dicts) | ||
return tensor_dicts | ||
|
||
def process_tokens(self, tensor_dicts: TensorDict) -> TensorDict: | ||
inputs = tensor_dicts[self._field_name][self._namespace]["token_ids"] | ||
inputs, labels = self._collator.mask_tokens(inputs) | ||
tensor_dicts[self._field_name][self._namespace]["token_ids"] = inputs | ||
tensor_dicts[self._field_name][self._namespace]["labels"] = labels | ||
return tensor_dicts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters