Merge pull request #437 from Codium-ai/tr/new_gpt4
Introduce support for 'gpt-4-1106-preview' model and dynamic token limit calculation
mrT23 committed Nov 7, 2023
2 parents 5a7c118 + 54f41dd commit 6c82bc9
Showing 5 changed files with 21 additions and 8 deletions.
1 change: 1 addition & 0 deletions pr_agent/algo/__init__.py
@@ -8,6 +8,7 @@
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
+'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
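
The new entry extends the MAX_TOKENS lookup table that maps model names to context-window sizes. A minimal sketch of how a caller reads it, assuming only the dictionary shown above; as the inline comment notes, the effective limit may still be clamped by config.max_model_tokens:

from pr_agent.algo import MAX_TOKENS

# Raw context window for the newly registered model.
window = MAX_TOKENS['gpt-4-1106-preview']   # 128000
# A model name missing from this table raises KeyError, so new models must be
# added here before they can be selected.
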
10 changes: 5 additions & 5 deletions pr_agent/algo/pr_processing.py
@@ -7,11 +7,11 @@

from github import RateLimitExceededException

-from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
+from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
from pr_agent.log import get_logger
@@ -64,7 +64,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)

# if we are under the limit, return the full diff
-if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
+if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
return "\n".join(patches_extended)

# if we are over the limit, start pruning
@@ -179,12 +179,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
new_patch_tokens = token_handler.count_tokens(patch)

# Hard Stop, no more tokens
-if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
+if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
continue

# If the patch is too large, just show the file name
-if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
@@ -403,7 +403,7 @@ def get_pr_multi_diffs(git_provider: GitProvider,

patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
-if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
+if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []
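
Every site that previously indexed the static MAX_TOKENS table now calls get_max_tokens(model), so a user-configured cap is respected when deciding whether a patch still fits the prompt budget. A minimal sketch of that decision, with hypothetical buffer values standing in for the module's OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD and OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD constants:

SOFT_BUFFER = 1000   # hypothetical value
HARD_BUFFER = 600    # hypothetical value

def patch_decision(total_tokens: int, new_patch_tokens: int, max_tokens: int) -> str:
    # Hard stop: the accumulated diff already leaves too little room for the model's output.
    if total_tokens > max_tokens - HARD_BUFFER:
        return "skip file entirely"
    # Soft limit: adding this patch would overflow the budget, so only the file name is kept.
    if total_tokens + new_patch_tokens > max_tokens - SOFT_BUFFER:
        return "show file name only"
    return "include full patch"
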
11 changes: 11 additions & 0 deletions pr_agent/algo/utils.py
@@ -9,6 +9,8 @@

import yaml
from starlette_context import context

+from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger

@@ -341,3 +343,12 @@ def get_user_labels(current_labels):
if user_labels:
get_logger().info(f"Keeping user labels: {user_labels}")
return user_labels


+def get_max_tokens(model):
+    settings = get_settings()
+    max_tokens_model = MAX_TOKENS[model]
+    if settings.config.max_model_tokens:
+        max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
+        # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
+    return max_tokens_model
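
get_max_tokens returns the smaller of the model's table entry and the optional config.max_model_tokens setting, so even a 128K-context model can be held to a tighter budget. A minimal usage sketch with the settings lookup stubbed out and illustrative values:

MAX_TOKENS = {'gpt-4': 8000, 'gpt-4-1106-preview': 128000}

def get_max_tokens_sketch(model: str, max_model_tokens: int = 32000) -> int:
    # Same clamp as above: a falsy max_model_tokens disables the cap.
    limit = MAX_TOKENS[model]
    if max_model_tokens:
        limit = min(max_model_tokens, limit)
    return limit

get_max_tokens_sketch('gpt-4')               # 8000 (already below the cap)
get_max_tokens_sketch('gpt-4-1106-preview')  # 32000 (clamped down from 128000)
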
3 changes: 2 additions & 1 deletion pr_agent/settings/configuration.toml
@@ -1,5 +1,5 @@
[config]
model="gpt-4"
model="gpt-4" # "gpt-4-1106-preview"
fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github"
publish_output=true
@@ -10,6 +10,7 @@ use_repo_settings_file=true
ai_timeout=180
max_description_tokens = 500
max_commits_tokens = 500
+max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
patch_extra_lines = 3
secret_provider="google_cloud_storage"
cli_mode=false
4 changes: 2 additions & 2 deletions pr_agent/tools/pr_similar_issue.py
@@ -8,8 +8,8 @@
from pinecone_datasets import Dataset, DatasetMetadata
from pydantic import BaseModel, Field

-from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import TokenHandler
+from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.log import get_logger
@@ -197,7 +197,7 @@ def _update_index_with_issues(self, issues_list, repo_name_for_index, upsert=Fal
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
-self.token_handler.count_tokens(issue_str) < MAX_TOKENS[MODEL]: # fast reject first
+self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
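
The condition keeps the existing "fast reject first" shortcut: the cheap character-length check short-circuits the or-expression, so the tokenizer only runs for long issues, and that token count is now compared against the clamped limit rather than the raw table entry. A minimal sketch of the pattern, with count_tokens standing in for the real TokenHandler method:

def fits_in_context(issue_str: str, count_tokens, max_tokens: int) -> bool:
    # len() is cheap and lets short issues pass without tokenizing; only long
    # issues pay for count_tokens, whose result is checked against max_tokens.
    return len(issue_str) < 8000 or count_tokens(issue_str) < max_tokens
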
