diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 56511cd0e..5a2533635 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -8,6 +8,7 @@
     'gpt-4': 8000,
     'gpt-4-0613': 8000,
     'gpt-4-32k': 32000,
+    'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
     'claude-instant-1': 100000,
     'claude-2': 100000,
     'command-nightly': 4096,
diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index 63bd02eb8..269093cb7 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -7,11 +7,11 @@
 
 from github import RateLimitExceededException
 
-from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.file_filter import filter_ignored
 from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
+from pr_agent.algo.utils import get_max_tokens
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
 from pr_agent.log import get_logger
@@ -64,7 +64,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
                                                                 pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)
 
     # if we are under the limit, return the full diff
-    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
+    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
         return "\n".join(patches_extended)
 
     # if we are over the limit, start pruning
@@ -179,12 +179,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
         new_patch_tokens = token_handler.count_tokens(patch)
 
         # Hard Stop, no more tokens
-        if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
+        if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
             get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
             continue
 
         # If the patch is too large, just show the file name
-        if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+        if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
             # Current logic is to skip the patch if it's too large
             # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
             #  until we meet the requirements
@@ -403,7 +403,7 @@ def get_pr_multi_diffs(git_provider: GitProvider,
             patch = convert_to_hunks_with_lines_numbers(patch, file)
             new_patch_tokens = token_handler.count_tokens(patch)
 
-        if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
+        if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
             final_diff = "\n".join(patches)
             final_diff_list.append(final_diff)
             patches = []
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 440b6615a..ecaf4aa89 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -9,6 +9,8 @@
 import yaml
 from starlette_context import context
 
+
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.config_loader import get_settings, global_settings
 from pr_agent.log import get_logger
 
@@ -341,3 +343,12 @@ def get_user_labels(current_labels):
     if user_labels:
         get_logger().info(f"Keeping user labels: {user_labels}")
     return user_labels
+
+
+def get_max_tokens(model):
+    settings = get_settings()
+    max_tokens_model = MAX_TOKENS[model]
+    if settings.config.max_model_tokens:
+        max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
+        # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
+    return max_tokens_model
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 8445f01db..b8020eeb9 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -1,5 +1,5 @@
 [config]
-model="gpt-4"
+model="gpt-4" # "gpt-4-1106-preview"
 fallback_models=["gpt-3.5-turbo-16k"]
 git_provider="github"
 publish_output=true
@@ -10,6 +10,7 @@ use_repo_settings_file=true
 ai_timeout=180
 max_description_tokens = 500
 max_commits_tokens = 500
+max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
 patch_extra_lines = 3
 secret_provider="google_cloud_storage"
 cli_mode=false
diff --git a/pr_agent/tools/pr_similar_issue.py b/pr_agent/tools/pr_similar_issue.py
index c3a0793bc..c717b59fc 100644
--- a/pr_agent/tools/pr_similar_issue.py
+++ b/pr_agent/tools/pr_similar_issue.py
@@ -8,8 +8,8 @@
 from pinecone_datasets import Dataset, DatasetMetadata
 from pydantic import BaseModel, Field
 
-from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.token_handler import TokenHandler
+from pr_agent.algo.utils import get_max_tokens
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import get_git_provider
 from pr_agent.log import get_logger
@@ -197,7 +197,7 @@ def _update_index_with_issues(self, issues_list, repo_name_for_index, upsert=Fal
                 username = issue.user.login
                 created_at = str(issue.created_at)
                 if len(issue_str) < 8000 or \
-                        self.token_handler.count_tokens(issue_str) < MAX_TOKENS[MODEL]: # fast reject first
+                        self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
                     issue_record = Record(
                         id=issue_key + "." + "issue",
                         text=issue_str,
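
The heart of this change is the new `get_max_tokens` helper in `pr_agent/algo/utils.py`: it looks up the model's default context window in `MAX_TOKENS` and, when `config.max_model_tokens` is set, clamps the budget to the smaller of the two. Below is a minimal standalone sketch of that clamping behavior; the `MODEL_DEFAULTS` dict and the `max_model_tokens` parameter are illustrative stand-ins for pr_agent's `MAX_TOKENS` table and its settings object, not the library's actual API.

```python
# Standalone sketch of get_max_tokens' clamping logic (illustrative only).
MODEL_DEFAULTS = {  # stand-in for pr_agent.algo.MAX_TOKENS
    'gpt-4': 8000,
    'gpt-4-32k': 32000,
    'gpt-4-1106-preview': 128000,
}

def get_max_tokens_sketch(model: str, max_model_tokens: int = 32000) -> int:
    """Return the effective token budget: the model's default, capped by config."""
    max_tokens_model = MODEL_DEFAULTS[model]
    if max_model_tokens:  # a falsy value (0/None) leaves the model default untouched
        max_tokens_model = min(max_model_tokens, max_tokens_model)
    return max_tokens_model

assert get_max_tokens_sketch('gpt-4') == 8000                    # already under the cap
assert get_max_tokens_sketch('gpt-4-1106-preview') == 32000      # 128K clamped to 32K
assert get_max_tokens_sketch('gpt-4-1106-preview', 0) == 128000  # cap disabled
```

With the shipped default of `max_model_tokens = 32000`, even `gpt-4-1106-preview` (128K context) is budgeted at 32K until the setting is raised or unset.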
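
For context on how that budget is consumed: in `pr_generate_compressed_diff`, a file's patch is kept only while the running total stays under a soft output buffer, and files are skipped outright once the budget is essentially exhausted. Here is a schematic of those two checks; the threshold values and the flat `files` input are simplifications (the real loop also sorts files by language and records skipped and deleted files).

```python
# Schematic of the soft/hard threshold checks in pr_generate_compressed_diff
# (simplified; threshold values here are illustrative, not pr_agent's).
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600

def compress(files: list[tuple[str, str, int]], max_tokens: int) -> list[str]:
    """files: (filename, patch, patch_tokens) triples; returns the kept patches."""
    kept, total_tokens = [], 0
    for filename, patch, patch_tokens in files:
        # Hard stop: no room left for any further patch, skip everything that follows.
        if total_tokens > max_tokens - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
            continue
        # Soft stop: this particular patch would overflow the buffer; skip it
        # but keep trying smaller patches from the remaining files.
        if total_tokens + patch_tokens > max_tokens - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
            continue
        kept.append(patch)
        total_tokens += patch_tokens
    return kept
```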