From 741ddab997ca0ad2860ba29884eaf568fccc3c37 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Thu, 20 Jun 2024 21:14:27 +0300 Subject: [PATCH] TRTLLM: Being more generic with type handling to solve #113 (#114) --- lmformatenforcer/integrations/trtllm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lmformatenforcer/integrations/trtllm.py b/lmformatenforcer/integrations/trtllm.py index 27be595..3847b44 100644 --- a/lmformatenforcer/integrations/trtllm.py +++ b/lmformatenforcer/integrations/trtllm.py @@ -36,11 +36,16 @@ def __call__(self, step: int, batch_input_ids: List[List[int]], logits: torch.Te def _build_regular_tokens_list(tokenizer) -> List[Tuple[int, str, bool]]: + # There are many classes that can be passed here, this logic should work on all of them. + if hasattr(tokenizer, 'get_tokenizer'): + tokenizer = tokenizer.get_tokenizer() + if hasattr(tokenizer, 'tokenizer'): + tokenizer = tokenizer.tokenizer token_0 = [tokenizer.encode("0")[-1]] regular_tokens = [] - vocab_size = tokenizer.tokenizer.vocab_size + vocab_size = tokenizer.vocab_size for token_idx in range(vocab_size): - if token_idx in tokenizer.tokenizer.all_special_ids: + if token_idx in tokenizer.all_special_ids: continue # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word. tensor_after_0 = torch.tensor(token_0 + [token_idx], dtype=torch.long)