Introduce support for 'gpt-4-1106-preview' model and dynamic token limit calculation #437

Merged: 3 commits, Nov 7, 2023
1 change: 1 addition & 0 deletions pr_agent/algo/__init__.py
@@ -8,6 +8,7 @@
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
+'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
10 changes: 5 additions & 5 deletions pr_agent/algo/pr_processing.py
@@ -7,11 +7,11 @@

from github import RateLimitExceededException

-from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
+from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
from pr_agent.log import get_logger
@@ -64,7 +64,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)

# if we are under the limit, return the full diff
-if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
+if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
return "\n".join(patches_extended)

# if we are over the limit, start pruning
@@ -179,12 +179,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
new_patch_tokens = token_handler.count_tokens(patch)

# Hard Stop, no more tokens
-if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
+if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
continue

# If the patch is too large, just show the file name
-if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
@@ -403,7 +403,7 @@ def get_pr_multi_diffs(git_provider: GitProvider,

patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
-if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
+if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []
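For orientation, a minimal standalone sketch of the budget check these hunks now route through get_max_tokens(); the threshold values and the helper name classify_patch are illustrative assumptions, not pr-agent's actual constants or API:

# Illustrative sketch only; the two buffer values are assumed, not pr-agent's settings.
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000  # assumed: tokens reserved for the model's reply
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600   # assumed: bare-minimum reply budget

def classify_patch(total_tokens: int, new_patch_tokens: int, max_model_tokens: int) -> str:
    """Decide what to do with the next patch given the remaining token budget."""
    # Hard stop: the prompt is already so large that even the reply buffer is at risk.
    if total_tokens > max_model_tokens - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
        return "skip"
    # Soft limit: adding this patch would crowd the reply buffer, so only the
    # filename is recorded (pr-agent's current logic skips the patch body).
    if total_tokens + new_patch_tokens > max_model_tokens - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
        return "filename-only"
    return "include"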
11 changes: 11 additions & 0 deletions pr_agent/algo/utils.py
@@ -9,6 +9,8 @@

import yaml
from starlette_context import context

+from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger

@@ -341,3 +343,12 @@ def get_user_labels(current_labels):
if user_labels:
get_logger().info(f"Keeping user labels: {user_labels}")
return user_labels


+def get_max_tokens(model):
+    settings = get_settings()
+    max_tokens_model = MAX_TOKENS[model]
+    if settings.config.max_model_tokens:
+        max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
+    # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
+    return max_tokens_model
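A self-contained sketch of the clamping behaviour the new helper introduces; the MAX_TOKENS excerpt and the wrapper name effective_max_tokens are illustrative, since the real function reads the cap from get_settings():

# Standalone illustration of the min() clamp in get_max_tokens(); not pr-agent runtime code.
MAX_TOKENS = {"gpt-4": 8000, "gpt-4-1106-preview": 128000}  # excerpt of the real table

def effective_max_tokens(model: str, max_model_tokens=None) -> int:
    """The configured cap wins whenever it is smaller than the model's native window."""
    limit = MAX_TOKENS[model]
    if max_model_tokens:
        limit = min(max_model_tokens, limit)
    return limit

assert effective_max_tokens("gpt-4-1106-preview", 32000) == 32000  # capped by config
assert effective_max_tokens("gpt-4", 32000) == 8000                # model default already smaller
assert effective_max_tokens("gpt-4-1106-preview", None) == 128000  # no cap configured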
3 changes: 2 additions & 1 deletion pr_agent/settings/configuration.toml
@@ -1,5 +1,5 @@
[config]
model="gpt-4"
model="gpt-4" # "gpt-4-1106-preview"
fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github"
publish_output=true
@@ -10,6 +10,7 @@ use_repo_settings_file=true
ai_timeout=180
max_description_tokens = 500
max_commits_tokens = 500
+max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
patch_extra_lines = 3
secret_provider="google_cloud_storage"
cli_mode=false
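Worth noting: max_model_tokens applies to whichever model is configured, so the default of 32000 shown here also caps gpt-4-1106-preview well below its 128K native window; users who want the larger context presumably need to raise or clear that value.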
4 changes: 2 additions & 2 deletions pr_agent/tools/pr_similar_issue.py
@@ -8,8 +8,8 @@
from pinecone_datasets import Dataset, DatasetMetadata
from pydantic import BaseModel, Field

-from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import TokenHandler
+from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.log import get_logger
@@ -197,7 +197,7 @@ def _update_index_with_issues(self, issues_list, repo_name_for_index, upsert=Fal
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
-self.token_handler.count_tokens(issue_str) < MAX_TOKENS[MODEL]: # fast reject first
+self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
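The issue-indexing path above uses the same helper in its fast-reject guard; a rough sketch of the pattern (the function name fits_in_context is hypothetical):

# Simplified mirror of the "fast reject" check: the cheap length test
# short-circuits before the more expensive token count is computed.
def fits_in_context(issue_str: str, count_tokens, max_tokens: int) -> bool:
    return len(issue_str) < 8000 or count_tokens(issue_str) < max_tokens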