diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index b1a2340450f..e79a4bb077b 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -513,7 +513,7 @@ def advance(self, next_token_id, fsm_grammar_state): def _advance(next_token_id, fsm_grammar_state, fsm): if fsm_grammar_state == -1: return fsm_grammar_state - return fsm.next_state(fsm_grammar_state, next_token_id) + return fsm.get_next_state(fsm_grammar_state, next_token_id) # TODO: move grammar compilation into the router @staticmethod @@ -588,7 +588,7 @@ def __call__( fsm = self.fsms[i] if fsm_grammar_states[i] == -1 or fsm is None: continue - allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) + allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens mask[i, allowed_tokens] = 0 logits[i] += mask[i] return logits