This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Added DataCollator for dynamic operations for each batch.
#5221
Merged
Changes from 48 commits
d691c46 ADD: add from_pretrained method for vocab (wlhgtc)
582815c MOD: test format (wlhgtc)
88ebacc MOD: format file (wlhgtc)
5ea07ec MOD: update changelog (wlhgtc)
3f4b9b8 MOD: fix bug (wlhgtc)
ea9cbf2 MOD: fix bug (wlhgtc)
76bba4d MOD: fix typo (wlhgtc)
fbe1def MOD: make the mothod in class (wlhgtc)
d24a54a MOD: fix bug (wlhgtc)
9dcfc31 MOD: change to instance method (wlhgtc)
1a9d257 MOD: fix typo (wlhgtc)
3d4be33 MOD: fix bug (wlhgtc)
1c897c5 MOD: change oov to avoid bug (wlhgtc)
04d9f97 Merge branch 'main' into main (epwalsh)
0c153ef Update allennlp/data/vocabulary.py (epwalsh)
334e53c Update allennlp/data/vocabulary.py (wlhgtc)
21cbf38 Update allennlp/data/vocabulary.py (wlhgtc)
7a9f2f3 Update allennlp/data/vocabulary.py (wlhgtc)
6b35ca9 MOD: fix formate (wlhgtc)
c7fec2e MOD: add test case (wlhgtc)
6e35025 Update CHANGELOG.md (epwalsh)
0f3cfa2 MOD: align to upstream (wlhgtc)
2ee0019 MOD : fix conflict (wlhgtc)
8cee9ee MOD: fix worker info bug (wlhgtc)
7b6e70f ADD: update changelog (wlhgtc)
8f17508 MOD: fix format (wlhgtc)
a476490 Update allennlp/data/data_loaders/multitask_data_loader.py (wlhgtc)
617564d Update CHANGELOG.md (wlhgtc)
0afcc4b Merge branch 'main' into main (epwalsh)
127443e Merge branch 'main' into main (dirkgr)
a3d47cd MOD: fix conflict (wlhgtc)
ffd2308 MOD: fix conflcit (wlhgtc)
18c35b6 MOD: add demo code (wlhgtc)
cd4bf9f MOD: align code (wlhgtc)
66659c8 MOD: fix bug (wlhgtc)
6325327 MOD: fix bug (wlhgtc)
86a8e24 MOD: fix bug (wlhgtc)
127b80d MOD: formate code (wlhgtc)
251546d Merge branch 'main' into main (dirkgr)
d4b17f6 Update allennlp/data/data_loaders/data_collator.py (wlhgtc)
781d0b2 fix error (wlhgtc)
68f25f0 MOD: add test code (wlhgtc)
412fc06 mod: change tokenizer (wlhgtc)
d6cde24 mod: fix tokenizer (wlhgtc)
c0c5fc7 MOD: fix bug (wlhgtc)
7fe1435 MOD: fix bug (wlhgtc)
155d8dc MOD: fix bug (wlhgtc)
c126742 Merge branch 'main' into main (dirkgr)
5641ee5 Merge branch 'main' into main (dirkgr)
56bdcdb Update allennlp/data/data_loaders/data_collator.py (wlhgtc)
119ba13 MOD: update changelog (wlhgtc)
5089111 MOD: update change log (wlhgtc)
d83e5c3 Merge branch 'main' into main (dirkgr)
36522c0 Merge branch 'main' into main (dirkgr)
dbd9e08 Update allennlp/data/data_loaders/data_collator.py (dirkgr)
b73bc61 Formatting (dirkgr)
@@ -1,4 +1,5 @@
-from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate
+from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
 from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader, WorkerError
 from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader
 from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader
+from allennlp.data.data_loaders.data_collator import allennlp_collate
@@ -0,0 +1,68 @@
+from typing import List
+
+from transformers.data.data_collator import DataCollatorForLanguageModeling
+from allennlp.common import Registrable
+from allennlp.data.batch import Batch
+from allennlp.data.data_loaders.data_loader import TensorDict
+from allennlp.data.instance import Instance
+
+
+def allennlp_collate(instances: List[Instance]) -> TensorDict:
+    """
+    This is the default function used to turn a list of `Instance`s into a `TensorDict`
+    batch.
+    """
+    batch = Batch(instances)
+    return batch.as_tensor_dict()
+
+
+class DataCollator(Registrable):
+    """
+    This class is similar to `DataCollator` in [Transformers]
+    (https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py).
+    It allows dynamic operations on the tensors of each batch, because it is called
+    on every batch to convert a `List[Instance]` into a `TensorDict`.
+    """
+
+    def __call__(self, instances: List[Instance]) -> TensorDict:
+        raise NotImplementedError
+
+
+class DefaultDataCollator(DataCollator):
+    def __call__(self, instances: List[Instance]) -> TensorDict:
+        return allennlp_collate(instances)
+
+
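The `DataCollator` interface in this diff is a callable from a list of instances to a batch, so a per-batch operation can be added by subclassing it and overriding `__call__`. The following is a minimal, dependency-free sketch of that pattern only. `ToyCollator`, `PadToLongest`, and the list-of-token-ids "instances" are hypothetical stand-ins invented for illustration; they are not AllenNLP classes, and a real collator would return a `TensorDict` of tensors.

```python
from typing import Dict, List

# Hypothetical stand-ins: an "instance" is a list of token ids and a "batch"
# is a dict of nested lists (AllenNLP itself uses Instance and TensorDict).
ToyInstance = List[int]
ToyBatch = Dict[str, List[List[int]]]


class ToyCollator:
    """Mirrors the shape of the PR's DataCollator: a callable on a list of instances."""

    def __call__(self, instances: List[ToyInstance]) -> ToyBatch:
        raise NotImplementedError


class PadToLongest(ToyCollator):
    """Pads every sequence to the longest sequence in the current batch only."""

    def __init__(self, pad_id: int = 0) -> None:
        self._pad_id = pad_id

    def __call__(self, instances: List[ToyInstance]) -> ToyBatch:
        # The target width is recomputed per call, so it can differ between batches.
        width = max((len(ids) for ids in instances), default=0)
        token_ids = [ids + [self._pad_id] * (width - len(ids)) for ids in instances]
        mask = [[1] * len(ids) + [0] * (width - len(ids)) for ids in instances]
        return {"token_ids": token_ids, "mask": mask}


collate = PadToLongest(pad_id=0)
short_batch = collate([[5, 6], [7]])        # padded to width 2
long_batch = collate([[1, 2, 3, 4], [9]])   # padded to width 4
```

Because the work happens inside `__call__`, the same collator object produces differently shaped batches depending on what it is handed, which is the kind of per-batch behaviour the `DataCollator` docstring describes.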
+@DataCollator.register("language-model")
+class LanguageModelingDataCollator(DataCollator):
+    """
+    Registered as a `DataCollator` with name "language-model".
+    Used for language modeling.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        mlm: bool = True,
+        mlm_probability: float = 0.15,
+        filed_name: str = "source",
+        namespace: str = "tokens",
+    ):
+        self._field_name = filed_name
+        self._namespace = namespace
+        from allennlp.common import cached_transformers
+
+        tokenizer = cached_transformers.get_tokenizer(model_name)
+        self._collator = DataCollatorForLanguageModeling(tokenizer, mlm, mlm_probability)
+
+    def __call__(self, instances: List[Instance]) -> TensorDict:
+        tensor_dicts = allennlp_collate(instances)
+        tensor_dicts = self.process_tokens(tensor_dicts)
+        return tensor_dicts
+
+    def process_tokens(self, tensor_dicts: TensorDict) -> TensorDict:
+        inputs = tensor_dicts[self._field_name][self._namespace]["token_ids"]
+        inputs, labels = self._collator.mask_tokens(inputs)
+        tensor_dicts[self._field_name][self._namespace]["token_ids"] = inputs
+        tensor_dicts[self._field_name][self._namespace]["labels"] = labels
+        return tensor_dicts
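`process_tokens` hands the actual masking to the `mask_tokens` method of Hugging Face's `DataCollatorForLanguageModeling`. As a rough illustration of the BERT-style scheme that collator applies, here is a pure-Python sketch; it is not the library code. It assumes the usual rules: each non-special position is selected with probability `mlm_probability`; labels hold the original id at selected positions and -100 elsewhere; a selected input becomes the mask token about 80% of the time, a random token about 10%, and stays unchanged otherwise. The function name `toy_mask_tokens`, the token ids, and the vocabulary size are hypothetical.

```python
import random
from typing import List, Optional, Sequence, Tuple

IGNORE_INDEX = -100  # label for positions that were not selected for prediction


def toy_mask_tokens(
    token_ids: Sequence[int],
    special_ids: Sequence[int],
    mask_id: int,
    vocab_size: int,
    mlm_probability: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Return (masked_inputs, labels) for one sequence of token ids."""
    rng = rng or random.Random()
    inputs = list(token_ids)
    labels = [IGNORE_INDEX] * len(inputs)
    for i, tok in enumerate(token_ids):
        # Special tokens are never selected; other tokens are selected at random.
        if tok in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = tok  # the model is asked to predict the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id                    # replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # replace with a random token
        # otherwise the input token is left unchanged
    return inputs, labels


# Hypothetical ids; 101 and 102 play the role of special tokens, 103 the mask token.
ids = [101, 7592, 2088, 2003, 2307, 102]
rng = random.Random(0)
first_inputs, first_labels = toy_mask_tokens(
    ids, special_ids=(101, 102), mask_id=103, vocab_size=30522, mlm_probability=0.5, rng=rng
)
second_inputs, second_labels = toy_mask_tokens(
    ids, special_ids=(101, 102), mask_id=103, vocab_size=30522, mlm_probability=0.5, rng=rng
)
```

Each call draws fresh random positions, so running the masking inside the collator means the same instance can be masked differently every time it is batched, rather than being fixed once when the dataset is read.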
There's already an "Added" section below.
@wlhgtc, can you move this to the "Added" section below?