Skip to content

Commit

Permalink
Add SentencePieceTokenier and LlamaTokenier
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv committed Aug 1, 2023
1 parent a1d0be2 commit 8c3e544
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 0 deletions.
2 changes: 2 additions & 0 deletions merlin/models/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from merlin.models.tokenizers.tokenizer import Tokenizer # noqa: F401
from merlin.models.tokenizers.sentencepiece import SentencePieceTokenizer # noqa: F401
56 changes: 56 additions & 0 deletions merlin/models/tokenizers/sentencepiece.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import List

from merlin.models.tokenizers.tokenizer import Tokenizer


class SentencePieceTokenizer(Tokenizer):
"""Tokenizer using SentencePiece [1].
References
----------
[1] https://github.com/google/sentencepiece
"""

def __init__(self, *, processor: "SentencePieceTrainer") -> None:
require_sentencepiece()

self.processor = processor
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
self.pad_id = self.processor.pad_id()

def encode(
self,
string: str,
bos: bool = False,
eos: bool = False,
max_length: int = -1,
pad: bool = False,
) -> List[int]:
tokens = self.processor.encode(string)
if bos:
tokens = [self.bos_id] + tokens
if eos:
tokens = tokens + [self.eos_id]
if max_length > 0:
tokens = tokens[:max_length]
if pad and len(tokens) < max_length:
tokens += [self.pad_id] * (max_length - len(tokens))

return tokens

def decode(self, tokens: List[int]) -> str:
return self.processor.decode(tokens)

@property
def vocab_size(self) -> int:
return self.processor.vocab_size()


def require_sentencepiece() -> None:
try:
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer # noqa: F401
except ImportError:
raise ImportError(
"This requires `sentencepiece`. Install it with `pip install sentencepiece`."
)
19 changes: 19 additions & 0 deletions merlin/models/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from typing import List


class Tokenizer(ABC):
"""
Base class for all tokenizers.
"""

def __call__(self, string: str):
return self.encode(string)

@abstractmethod
def decode(self, tokens: List[int]):
...

@abstractmethod
def encode(self, string: str):
...
65 changes: 65 additions & 0 deletions merlin/models/torch/blocks/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from pathlib import Path
from typing import Optional

import torch
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer


class SentencePieceTokenizer:
"""Tokenizer for LLaMA.
Example usage
-------------
>> tokenizer_path = Path("llama/tokenizer.model")
>> tokenizer = SentencePieceTokenizer(tokenizer_path)
>> tokenizer.encode("Hello, my name is", bos=True, eos=False)
tensor([ 1, 15043, 29892, 590, 1024, 338], dtype=torch.int32)
"""

def __init__(self, model_path: Path) -> None:
try:
import sentencepiece # noqa: F401
except ImportError:
raise ImportError(
"`sentencepiece` is required to use this feature. "
"Install it with `pip install sentencepiece`."
)

self.processor = SentencePieceProcessor(model_file=str(model_path))
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
self.pad_id = self.processor.pad_id()

@property
def vocab_size(self) -> int:
return self.processor.vocab_size()

def encode(
self,
string: str,
bos: bool = True,
eos: bool = False,
max_length: int = -1,
pad: bool = False,
device: Optional[torch.device] = None,
) -> torch.Tensor:
tokens = self.processor.encode(string)
if bos:
tokens = [self.bos_id] + tokens
if eos:
tokens = tokens + [self.eos_id]
if max_length > 0:
tokens = tokens[:max_length]
if pad and len(tokens) < max_length:
tokens += [self.pad_id] * (max_length - len(tokens))

return torch.tensor(tokens, dtype=torch.int, device=device)

def decode(self, tokens: torch.Tensor) -> str:
return self.processor.decode(tokens.tolist())

@staticmethod
def train(input: str, destination: str, vocab_size=32000) -> None:
model_prefix = os.path.join(destination, "tokenizer")
SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
Empty file.
40 changes: 40 additions & 0 deletions merlin/models/torch/tokenizers/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path
from typing import Optional, Union

import torch

from merlin.models.tokenizers.sentencepiece import SentencePieceTokenizer, require_sentencepiece


class LlamaTokenizer(SentencePieceTokenizer):
def __init__(self, path: Union[str, Path]) -> None:
require_sentencepiece()

from sentencepiece import SentencePieceProcessor

if isinstance(path, Path):
path = str(path)
processor = SentencePieceProcessor(model_file=str(path))

super().__init__(processor=processor)

def endode(
self,
string: str,
bos: bool = True,
eos: bool = False,
max_length: int = -1,
pad: bool = False,
device: Optional[torch.device] = None,
) -> torch.Tensor:
tokens = super().encode(
string=string,
bos=bos,
eos=eos,
max_length=max_length,
pad=pad,
)
return torch.tensor(tokens, dtype=torch.int, device=device)

def decode(self, tokens: torch.Tensor) -> str:
return self.processor.decode(tokens.tolist())

0 comments on commit 8c3e544

Please sign in to comment.