diff --git a/elm/base.py b/elm/base.py
index a5f1e11..928f579 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -44,6 +44,11 @@ class ApiBase(ABC):
     MODEL_ROLE = "You are a research assistant that answers questions."
     """High level model role"""
 
+    # Optional mappings for weird azure names to tiktoken/openai names
+    tokenizer_aliases = {'gpt-35-turbo': 'gpt-3.5-turbo',
+                         'gpt-4-32k': 'gpt-4-32k-0314'
+                         }
+
     def __init__(self, model=None):
         """
         Parameters
@@ -338,8 +343,8 @@ def get_embedding(cls, text):
 
         return embedding
 
-    @staticmethod
-    def count_tokens(text, model):
+    @classmethod
+    def count_tokens(cls, text, model):
         """Return the number of tokens in a string.
 
         Parameters
@@ -355,12 +360,7 @@ def count_tokens(text, model):
             Number of tokens in text
         """
 
-        # Optional mappings for weird azure names to tiktoken/openai names
-        tokenizer_aliases = {'gpt-35-turbo': 'gpt-3.5-turbo',
-                             'gpt-4-32k': 'gpt-4-32k-0314',
-                             'ewiz-gpt-4': 'gpt-4'}
-
-        token_model = tokenizer_aliases.get(model, model)
+        token_model = cls.tokenizer_aliases.get(model, model)
         encoding = tiktoken.encoding_for_model(token_model)
 
         return len(encoding.encode(text))
diff --git a/elm/wizard.py b/elm/wizard.py
index 963ed21..b325e56 100644
--- a/elm/wizard.py
+++ b/elm/wizard.py
@@ -399,6 +399,11 @@ class EnergyWizardPostgres(EnergyWizardBase):
     """
 
     EMBEDDING_MODEL = 'amazon.titan-embed-text-v1'
+    # Optional mappings for weird azure names to tiktoken/openai names
+    tokenizer_aliases = {**EnergyWizardBase.tokenizer_aliases,
+                         'ewiz-gpt-4': 'gpt-4'
+                         }
+
     def __init__(self, db_host, db_port, db_name, db_schema, db_table,
                  cursor=None, boto_client=None, model=None,
                  token_budget=3500):
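
For reference, below is a minimal standalone sketch of the pattern this diff introduces: count_tokens becomes a classmethod so the alias lookup goes through cls.tokenizer_aliases, which lets a subclass extend the mapping (for example with the ewiz-gpt-4 deployment name) without re-implementing the method. The class names ApiBaseSketch and WizardSketch are hypothetical stand-ins, not the actual ApiBase and EnergyWizardPostgres classes, and the snippet assumes the tiktoken package is installed.

# Minimal sketch of the class-level alias pattern; not the actual ELM classes.
# Assumes tiktoken is installed; class names here are hypothetical stand-ins.
import tiktoken


class ApiBaseSketch:
    """Illustrative stand-in for elm.base.ApiBase."""

    # Optional mappings for weird azure names to tiktoken/openai names
    tokenizer_aliases = {'gpt-35-turbo': 'gpt-3.5-turbo',
                         'gpt-4-32k': 'gpt-4-32k-0314'}

    @classmethod
    def count_tokens(cls, text, model):
        """Return the token count for text, resolving azure deployment
        names through the class-level alias map."""
        token_model = cls.tokenizer_aliases.get(model, model)
        encoding = tiktoken.encoding_for_model(token_model)
        return len(encoding.encode(text))


class WizardSketch(ApiBaseSketch):
    """Illustrative stand-in for elm.wizard.EnergyWizardPostgres."""

    # Extend the inherited aliases instead of re-implementing count_tokens
    tokenizer_aliases = {**ApiBaseSketch.tokenizer_aliases,
                         'ewiz-gpt-4': 'gpt-4'}


if __name__ == '__main__':
    text = "How many tokens is this?"
    # The base class resolves the generic azure names...
    print(ApiBaseSketch.count_tokens(text, 'gpt-35-turbo'))
    # ...and the subclass also resolves its custom deployment alias.
    print(WizardSketch.count_tokens(text, 'ewiz-gpt-4'))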