Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SentencePieceTokenizer and LlamaTokenizer #1206

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions merlin/models/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from merlin.models.tokenizers.sentencepiece import SentencePieceTokenizer # noqa: F401
from merlin.models.tokenizers.tokenizer import Tokenizer # noqa: F401
56 changes: 56 additions & 0 deletions merlin/models/tokenizers/sentencepiece.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import List

from merlin.models.tokenizers.tokenizer import Tokenizer


class SentencePieceTokenizer(Tokenizer):
"""Tokenizer using SentencePiece [1].

References
----------
[1] https://github.com/google/sentencepiece
"""

def __init__(self, *, processor: "SentencePieceTrainer") -> None: # noqa: F821
require_sentencepiece()

self.processor = processor
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
self.pad_id = self.processor.pad_id()

def encode(
self,
string: str,
bos: bool = False,
eos: bool = False,
max_length: int = -1,
pad: bool = False,
) -> List[int]:
tokens = self.processor.encode(string)
if bos:
tokens = [self.bos_id] + tokens
if eos:
tokens = tokens + [self.eos_id]
if max_length > 0:
tokens = tokens[:max_length]
if pad and len(tokens) < max_length:
tokens += [self.pad_id] * (max_length - len(tokens))

return tokens

def decode(self, tokens: List[int]) -> str:
return self.processor.decode(tokens)

@property
def vocab_size(self) -> int:
return self.processor.vocab_size()


def require_sentencepiece() -> None:
try:
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer # noqa: F401
except ImportError:
raise ImportError(
"This requires `sentencepiece`. Install it with `pip install sentencepiece`."
)
19 changes: 19 additions & 0 deletions merlin/models/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from typing import List


class Tokenizer(ABC):
"""
Base class for all tokenizers.
"""

def __call__(self, string: str):
return self.encode(string)

@abstractmethod
def decode(self, tokens: List[int]):
...

@abstractmethod
def encode(self, string: str):
...
65 changes: 65 additions & 0 deletions merlin/models/torch/blocks/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from pathlib import Path
from typing import Optional

import torch
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer


class SentencePieceTokenizer:
"""Tokenizer for LLaMA.

Example usage
-------------
>> tokenizer_path = Path("llama/tokenizer.model")
>> tokenizer = SentencePieceTokenizer(tokenizer_path)
>> tokenizer.encode("Hello, my name is", bos=True, eos=False)
tensor([ 1, 15043, 29892, 590, 1024, 338], dtype=torch.int32)
"""

def __init__(self, model_path: Path) -> None:
try:
import sentencepiece # noqa: F401
except ImportError:
raise ImportError(
"`sentencepiece` is required to use this feature. "
"Install it with `pip install sentencepiece`."
)

self.processor = SentencePieceProcessor(model_file=str(model_path))
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
self.pad_id = self.processor.pad_id()

@property
def vocab_size(self) -> int:
return self.processor.vocab_size()

def encode(
self,
string: str,
bos: bool = True,
eos: bool = False,
max_length: int = -1,
pad: bool = False,
device: Optional[torch.device] = None,
) -> torch.Tensor:
tokens = self.processor.encode(string)
if bos:
tokens = [self.bos_id] + tokens
if eos:
tokens = tokens + [self.eos_id]
if max_length > 0:
tokens = tokens[:max_length]
if pad and len(tokens) < max_length:
tokens += [self.pad_id] * (max_length - len(tokens))

return torch.tensor(tokens, dtype=torch.int, device=device)

def decode(self, tokens: torch.Tensor) -> str:
return self.processor.decode(tokens.tolist())

@staticmethod
def train(input: str, destination: str, vocab_size=32000) -> None:
model_prefix = os.path.join(destination, "tokenizer")
SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
Empty file.
40 changes: 40 additions & 0 deletions merlin/models/torch/tokenizers/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path
from typing import Optional, Union

import torch

from merlin.models.tokenizers.sentencepiece import SentencePieceTokenizer, require_sentencepiece


class LlamaTokenizer(SentencePieceTokenizer):
def __init__(self, path: Union[str, Path]) -> None:
require_sentencepiece()

from sentencepiece import SentencePieceProcessor

if isinstance(path, Path):
path = str(path)
processor = SentencePieceProcessor(model_file=str(path))

super().__init__(processor=processor)

def endode(
self,
string: str,
bos: bool = True,
eos: bool = False,
max_length: int = -1,
pad: bool = False,
device: Optional[torch.device] = None,
) -> torch.Tensor:
tokens = super().encode(
string=string,
bos=bos,
eos=eos,
max_length=max_length,
pad=pad,
)
return torch.tensor(tokens, dtype=torch.int, device=device)

def decode(self, tokens: torch.Tensor) -> str:
return self.processor.decode(tokens.tolist())
Loading