Introduce support for 'gpt-4-1106-preview' model and dynamic token limit calculation #437

Merged 3 commits on Nov 7, 2023
Changes from 1 commit
1 change: 1 addition & 0 deletions pr_agent/algo/__init__.py
@@ -8,6 +8,7 @@
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
(mrT23 marked a review conversation on this line as resolved.)
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
10 changes: 5 additions & 5 deletions pr_agent/algo/pr_processing.py
@@ -7,11 +7,11 @@

from github import RateLimitExceededException

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
from pr_agent.log import get_logger
@@ -64,7 +64,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)

# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
return "\n".join(patches_extended)

# if we are over the limit, start pruning
@@ -179,12 +179,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
new_patch_tokens = token_handler.count_tokens(patch)

# Hard Stop, no more tokens
if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
continue

# If the patch is too large, just show the file name
if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
@@ -403,7 +403,7 @@ def get_pr_multi_diffs(git_provider: GitProvider,

patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []
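
To make the shared pattern in these hunks concrete, here is a minimal, hypothetical sketch of the budget checks they all route through get_max_tokens. The threshold constants are defined elsewhere in pr_processing.py; the values and helper names below are assumed purely for illustration and are not part of this PR.

from pr_agent.algo.utils import get_max_tokens

OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000  # assumed value: room reserved for the model's output (soft limit)
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600   # assumed value: room reserved for the model's output (hard limit)

def patch_fits_budget(total_tokens: int, new_patch_tokens: int, model: str) -> bool:
    # Hypothetical helper: a patch is kept only if the running total plus the new patch
    # still leaves the soft output buffer free, relative to the (possibly capped) model limit.
    return total_tokens + new_patch_tokens <= get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD

def file_over_hard_limit(total_tokens: int, model: str) -> bool:
    # Hypothetical helper mirroring the "Hard Stop" check: the file is skipped entirely
    # once even the hard output buffer can no longer be preserved.
    return total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD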
10 changes: 10 additions & 0 deletions pr_agent/algo/utils.py
@@ -9,6 +9,8 @@

import yaml
from starlette_context import context

from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger

@@ -341,3 +343,11 @@ def get_user_labels(current_labels):
    if user_labels:
        get_logger().info(f"Keeping user labels: {user_labels}")
    return user_labels

def get_max_tokens(model):
    max_tokens_model = MAX_TOKENS[model]
    if get_settings().config.max_model_tokens:
        if max_tokens_model > get_settings().config.max_model_tokens:
            max_tokens_model = get_settings().config.max_model_tokens
    # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
    return max_tokens_model

hussam789 (Collaborator) commented on Nov 7, 2023:

Instead, use the min function:

    max_tokens_model = min(max_tokens_model, get_settings().config.max_model_tokens)
Comment from the PR author (Collaborator):

Suggestion: Refactor the get_max_tokens function to use a local variable for the settings object, avoiding repeated calls to get_settings(), which may be inefficient if it involves I/O or complex computation.

Suggested change:

def get_max_tokens(model):
    settings = get_settings()
    max_tokens_model = MAX_TOKENS[model]
    if settings.config.max_model_tokens and max_tokens_model > settings.config.max_model_tokens:
        max_tokens_model = settings.config.max_model_tokens
    # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
    return max_tokens_model
Comment from the PR author (Collaborator):

Suggestion: Add a guard clause to get_max_tokens to handle the case where the model is not found in the MAX_TOKENS dictionary, which would currently raise a KeyError.

Suggested change:

def get_max_tokens(model):
    if model not in MAX_TOKENS:
        raise ValueError(f"Model '{model}' not found in MAX_TOKENS dictionary.")
    max_tokens_model = MAX_TOKENS[model]
    settings = get_settings()
    if settings.config.max_model_tokens and max_tokens_model > settings.config.max_model_tokens:
        max_tokens_model = settings.config.max_model_tokens
    # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
    return max_tokens_model
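
For reference, here is a sketch that folds the three review suggestions above into a single function: the min() call, the cached settings lookup, and the guard clause. This is illustrative only and is not the code that was merged in this PR.

from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import get_settings

def get_max_tokens(model):
    # Guard clause suggested in review: fail with a clear message instead of a KeyError.
    if model not in MAX_TOKENS:
        raise ValueError(f"Model '{model}' not found in MAX_TOKENS dictionary.")
    settings = get_settings()  # cache the settings lookup instead of calling get_settings() repeatedly
    max_tokens_model = MAX_TOKENS[model]
    if settings.config.max_model_tokens:
        # Cap the model's native limit with the configured ceiling, per the min() suggestion.
        max_tokens_model = min(max_tokens_model, settings.config.max_model_tokens)
    return max_tokens_model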

3 changes: 2 additions & 1 deletion pr_agent/settings/configuration.toml
@@ -1,5 +1,5 @@
[config]
model="gpt-4"
model="gpt-4" # "gpt-4-1106-preview"
fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github"
publish_output=true
@@ -10,6 +10,7 @@ use_repo_settings_file=true
ai_timeout=180
max_description_tokens = 500
max_commits_tokens = 500
max_model_tokens = 32000 # even if a model supports more tokens, quality may degrade. Hence, enable to limit the number of tokens.
Comment from the PR author (Collaborator):

Suggestion: The max_model_tokens setting should be documented to explain its purpose and effect on the system, especially since it overrides the model's default token limit.

Suggested change:

# Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
# This can be useful to ensure consistent performance and to prevent issues related to large token counts.
max_model_tokens = 32000
patch_extra_lines = 3
secret_provider="google_cloud_storage"
cli_mode=false
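
To illustrate how this setting interacts with the new get_max_tokens helper (a sketch, assuming the defaults above are loaded from this configuration file):

from pr_agent.algo.utils import get_max_tokens

# 'gpt-4' natively allows 8000 tokens, which is below the 32000 ceiling, so the cap has no effect.
assert get_max_tokens("gpt-4") == 8000

# 'gpt-4-1106-preview' natively allows 128000 tokens, so the configured
# max_model_tokens ceiling takes over and the effective limit becomes 32000.
assert get_max_tokens("gpt-4-1106-preview") == 32000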
4 changes: 2 additions & 2 deletions pr_agent/tools/pr_similar_issue.py
@@ -8,8 +8,8 @@
from pinecone_datasets import Dataset, DatasetMetadata
from pydantic import BaseModel, Field

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.log import get_logger
@@ -197,7 +197,7 @@ def _update_index_with_issues(self, issues_list, repo_name_for_index, upsert=Fal
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
self.token_handler.count_tokens(issue_str) < MAX_TOKENS[MODEL]: # fast reject first
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,