diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index fd7f3c1cf7901f..1a7dddba4e7071 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -104,6 +104,7 @@
         "DataCollatorForTokenClassification",
         "DataCollatorForWholeWordMask",
+        "DataCollatorWithFlattening",
         "DataCollatorWithPadding",
         "DefaultDataCollator",
         "default_data_collator",
     ],
@@ -4692,6 +4693,7 @@
         DataCollatorForTokenClassification,
         DataCollatorForWholeWordMask,
+        DataCollatorWithFlattening,
         DataCollatorWithPadding,
         DefaultDataCollator,
         default_data_collator,
     )
diff --git a/src/transformers/data/__init__.py b/src/transformers/data/__init__.py
index 1a8ef35ff439e4..10488a0a53f6b3 100644
--- a/src/transformers/data/__init__.py
+++ b/src/transformers/data/__init__.py
@@ -20,6 +20,7 @@
     DataCollatorForTokenClassification,
     DataCollatorForWholeWordMask,
+    DataCollatorWithFlattening,
     DataCollatorWithPadding,
     DefaultDataCollator,
     default_data_collator,
 )
diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 1961c387c2dc58..822bfb8f71b10b 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -1613,7 +1613,7 @@ def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
         return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
 
 @dataclass
-class DataCollatorBatchFlattening(DefaultDataCollator):
+class DataCollatorWithFlattening(DefaultDataCollator):
     return_tensors: str = "pt"
     def __init__(self, return_position_ids=True):
         self.return_position_ids = return_position_ids
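
For context, a minimal usage sketch of the renamed collator. The token ids and the expected outputs in the comments are illustrative assumptions based on the padding-free flattening behavior of `DataCollatorWithFlattening` (its `__call__` is not part of this hunk), not something asserted by the diff itself:

```python
from transformers import DataCollatorWithFlattening

# Two pre-tokenized samples of different lengths; the ids are made up.
features = [
    {"input_ids": [1, 2, 3, 4, 5]},
    {"input_ids": [6, 7, 8]},
]

collator = DataCollatorWithFlattening(return_position_ids=True)
batch = collator(features)

# Instead of padding to a common length, the samples are concatenated into
# a single row, and `position_ids` restart at 0 at each sample boundary so
# a padding-free attention implementation can tell the samples apart.
print(batch["input_ids"])     # expected: tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
print(batch["position_ids"])  # expected: tensor([[0, 1, 2, 3, 4, 0, 1, 2]])
```

Note that no `attention_mask` is produced in this scheme; downstream models are expected to recover the sequence boundaries from `position_ids` instead.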