Skip to content

Commit

Permalink
[Tokenier] Enable padding_side as call time kwargs (#9161)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvdongyi authored Sep 24, 2024
1 parent cd4e816 commit c5e6db5
Show file tree
Hide file tree
Showing 15 changed files with 186 additions and 57 deletions.
2 changes: 2 additions & 0 deletions paddlenlp/transformers/artist/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __call__(
return_offsets_mapping=False,
add_special_tokens=True,
pad_to_multiple_of=None,
padding_side=None,
return_tensors=None,
verbose: bool = True,
**kwargs
Expand All @@ -247,6 +248,7 @@ def __call__(
return_offsets_mapping,
add_special_tokens,
pad_to_multiple_of,
padding_side,
return_tensors,
verbose,
**kwargs,
Expand Down
10 changes: 7 additions & 3 deletions paddlenlp/transformers/bloom/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import shutil
from functools import lru_cache
from typing import Dict, Optional, Union
from typing import Dict, Literal, Optional, Union

import numpy as np
from paddle.utils import try_import
Expand Down Expand Up @@ -360,6 +360,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -375,13 +376,16 @@ def _pad(
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
The tokenizer padding sides are defined in `padding_side` argument:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -394,7 +398,7 @@ def _pad(

required_input = encoded_inputs[self.model_input_names[0]]
encoded_inputs = super()._pad(
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask
)
if attention_mask is not None and len(np.shape(attention_mask)) > 2:
encoded_inputs["attention_mask"] = attention_mask
Expand Down
6 changes: 4 additions & 2 deletions paddlenlp/transformers/chatglm/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Tokenization classes for ChatGLM."""
import os
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Union

import numpy as np
import sentencepiece as spm
Expand Down Expand Up @@ -218,13 +218,15 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy=PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
# Load from model defaults
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names or "attention_mask" in encoded_inputs

assert self.padding_side == "left"
padding_side = padding_side if padding_side is not None else self.padding_side
assert padding_side == "left"
required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)

Expand Down
13 changes: 9 additions & 4 deletions paddlenlp/transformers/chatglm_v2/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import os
import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import numpy as np
from sentencepiece import SentencePieceProcessor
Expand Down Expand Up @@ -244,6 +244,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -259,18 +260,22 @@ def _pad(
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
The tokenizer padding sides are defined in `padding_side` argument:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
assert self.padding_side == "left"
padding_side = padding_side if padding_side is not None else self.padding_side
assert padding_side == "left"

required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)
Expand Down
2 changes: 2 additions & 0 deletions paddlenlp/transformers/dallebart/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def __call__(
return_offsets_mapping=False,
add_special_tokens=True,
pad_to_multiple_of=None,
padding_side=None,
return_tensors=None,
verbose: bool = True,
**kwargs
Expand Down Expand Up @@ -497,6 +498,7 @@ def __call__(
return_offsets_mapping,
add_special_tokens,
pad_to_multiple_of,
padding_side,
return_tensors,
verbose,
**kwargs,
Expand Down
8 changes: 6 additions & 2 deletions paddlenlp/transformers/gemma/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import sentencepiece as spm
Expand Down Expand Up @@ -323,6 +323,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -345,6 +346,9 @@ def _pad(
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -359,7 +363,7 @@ def _pad(

required_input = encoded_inputs[self.model_input_names[0]]
encoded_inputs = super()._pad(
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask
)
if attention_mask is not None and len(np.shape(attention_mask)) > 2:
encoded_inputs["attention_mask"] = attention_mask
Expand Down
10 changes: 7 additions & 3 deletions paddlenlp/transformers/gpt/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import shutil
from functools import lru_cache
from typing import Dict, Optional, Union
from typing import Dict, Literal, Optional, Union

import jieba
import numpy as np
Expand Down Expand Up @@ -584,6 +584,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -599,13 +600,16 @@ def _pad(
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
The tokenizer padding sides are defined in `padding_side` argument:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -620,7 +624,7 @@ def _pad(

required_input = encoded_inputs[self.model_input_names[0]]
encoded_inputs = super()._pad(
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask
)
if attention_mask is not None and len(np.shape(attention_mask)) > 2:
encoded_inputs["attention_mask"] = attention_mask
Expand Down
18 changes: 13 additions & 5 deletions paddlenlp/transformers/llama/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import os
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import sentencepiece as spm
Expand Down Expand Up @@ -232,6 +232,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -247,13 +248,16 @@ def _pad(
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
The tokenizer padding sides are defined in `padding_side` argument:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -268,7 +272,7 @@ def _pad(

required_input = encoded_inputs[self.model_input_names[0]]
encoded_inputs = super()._pad(
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask
)
if attention_mask is not None and len(np.shape(attention_mask)) > 2:
encoded_inputs["attention_mask"] = attention_mask
Expand Down Expand Up @@ -521,6 +525,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -536,13 +541,16 @@ def _pad(
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
The tokenizer padding sides are defined in `padding_side` argument:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -557,7 +565,7 @@ def _pad(

required_input = encoded_inputs[self.model_input_names[0]]
encoded_inputs = super()._pad(
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask
)
if attention_mask is not None and len(np.shape(attention_mask)) > 2:
encoded_inputs["attention_mask"] = attention_mask
Expand Down
10 changes: 7 additions & 3 deletions paddlenlp/transformers/mamba/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import shutil
from functools import lru_cache
from typing import Dict, Optional, Union
from typing import Dict, Literal, Optional, Union

import numpy as np
from paddle.utils import try_import
Expand Down Expand Up @@ -302,6 +302,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -317,13 +318,16 @@ def _pad(
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
The tokenizer padding sides are defined in `padding_side` argument:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -338,7 +342,7 @@ def _pad(

required_input = encoded_inputs[self.model_input_names[0]]
encoded_inputs = super()._pad(
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask
)
if attention_mask is not None and len(np.shape(attention_mask)) > 2:
encoded_inputs["attention_mask"] = attention_mask
Expand Down
10 changes: 7 additions & 3 deletions paddlenlp/transformers/qwen/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import base64
import os
import unicodedata
from typing import Collection, Dict, List, Optional, Set, Tuple, Union
from typing import Collection, Dict, List, Literal, Optional, Set, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -255,6 +255,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[Literal["right", "left"]] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -270,13 +271,16 @@ def _pad(
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
The tokenizer padding sides are defined in `padding_side` argument:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
padding_side: (optional) The side on which the model should have padding applied.
Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -291,7 +295,7 @@ def _pad(

required_input = encoded_inputs[self.model_input_names[0]]
encoded_inputs = super()._pad(
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask
)
if attention_mask is not None and len(np.shape(attention_mask)) > 2:
encoded_inputs["attention_mask"] = attention_mask
Expand Down
Loading

0 comments on commit c5e6db5

Please sign in to comment.