Add tqdm and trange imports for progress tracking when indexing #59 #68

Status: Open. Wants to merge 2 commits into base: main.
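For context on the library this PR introduces: tqdm wraps any iterable and renders a progress bar as the iterable is consumed, and trange(n) is shorthand for tqdm(range(n)). A minimal sketch of the usage pattern, not taken from this diff:

from tqdm import tqdm, trange

# Wrapping an iterable draws a progress bar as items are consumed.
for item in tqdm(["a", "b", "c"], desc="items"):
    pass  # process item

# trange(n) is shorthand for tqdm(range(n)).
for step in trange(10, desc="steps"):
    pass  # do work for this step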
9 changes: 5 additions & 4 deletions benchmarks/retrieval/retrieve.py
@@ -7,6 +7,7 @@
 import logging
 import os
 import time
+from tqdm import tqdm, trange

 import configargparse
 from dotenv import load_dotenv
@@ -57,20 +58,20 @@ def main():
     golden_docs = []  # List of ir_measures.Qrel objects
     retrieved_docs = []  # List of ir_measures.ScoredDoc objects

-    for question_idx, item in enumerate(benchmark):
+    for question_idx, item in tqdm(enumerate(benchmark)):
         print(f"Processing question {question_idx}...")

         query_id = str(question_idx)  # Solely needed for ir_measures library.

-        for golden_filepath in item[args.gold_field]:
+        for golden_filepath in tqdm(item[args.gold_field]):
             # All the file paths in the golden answer are equally relevant for the query (i.e. the order is irrelevant),
             # so we set relevance=1 for all of them.
             golden_docs.append(Qrel(query_id=query_id, doc_id=golden_filepath, relevance=1))

         # Make a retrieval call for the current question.
         retrieved = retriever.invoke(item[args.question_field])
         item["retrieved"] = []
-        for doc_idx, doc in enumerate(retrieved):
+        for doc_idx, doc in tqdm(enumerate(retrieved)):
             # The absolute value of the scores below does not affect the metrics; it merely determines the ranking of
             # the retrieved documents. The key of the score varies depending on the underlying retriever. If there's no
             # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
@@ -99,7 +100,7 @@ def main():
     with open(output_file, "w") as f:
         json.dump(out_data, f, indent=4)

-    for key in sorted(results.keys()):
+    for key in tqdm(sorted(results.keys())):
         print(f"{key}: {results[key]}")
     print(f"Predictions and metrics saved to {output_file}")

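A hedged aside on the hunks above, not part of the proposed change: tqdm cannot infer a length from enumerate(benchmark), so the bar shows only a running count. Passing total=, or enumerating the wrapped list instead, restores the percentage display. The load step below is a hypothetical placeholder for the benchmark loading the script already does:

from tqdm import tqdm

benchmark = load_benchmark()  # hypothetical stand-in for the script's own loading code

# Option 1: tell tqdm how many items to expect.
for question_idx, item in tqdm(enumerate(benchmark), total=len(benchmark)):
    ...

# Option 2: wrap the list itself and enumerate the wrapped iterator.
for question_idx, item in enumerate(tqdm(benchmark)):
    ...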
3 changes: 2 additions & 1 deletion benchmarks/retrieval/retrieve_kaggle.py
@@ -23,6 +23,7 @@

 import sage.config
 from sage.retriever import build_retriever_from_args
+from tqdm import tqdm, trange

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
@@ -51,7 +52,7 @@ def main():
     benchmark = [row for row in benchmark]

     outputs = []
-    for question_idx, item in enumerate(benchmark):
+    for question_idx, item in tqdm(enumerate(benchmark)):
         print(f"Processing question {question_idx}...")

         retrieved = retriever.invoke(item["question"])
6 changes: 6 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

12 changes: 7 additions & 5 deletions sage/chat.py
@@ -74,7 +74,9 @@ def main():
     validator = sage_config.add_all_args(parser)
     args = parser.parse_args()
-    validator(args)
+
+    for validator in tqdm(arg_validators):
+        validator(args)

     rag_chain = build_rag_chain(args)

@@ -85,25 +87,25 @@ def source_md(file_path: str, url: str) -> str:
     async def _predict(message, history):
         """Performs one RAG operation."""
         history_langchain_format = []
-        for human, ai in history:
+        for human, ai in tqdm(history):
             history_langchain_format.append(HumanMessage(content=human))
             history_langchain_format.append(AIMessage(content=ai))
         history_langchain_format.append(HumanMessage(content=message))

         query_rewrite = ""
         response = ""
-        async for event in rag_chain.astream_events(
+        async for event in tqdm(rag_chain.astream_events)(
             {
                 "input": message,
                 "chat_history": history_langchain_format,
             },
             version="v1",
         ):
             if event["name"] == "retrieve_documents" and "output" in event["data"]:
-                sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in event["data"]["output"]]
+                sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in tqdm((event["data"]["output"]))]
                 # Deduplicate while preserving the order.
                 sources = list(dict.fromkeys(sources))
-                response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in sources]) + "\n## Response:\n"
+                response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in tqdm(sources)]) + "\n## Response:\n"

             elif event["event"] == "on_chat_model_stream":
                 chunk = event["data"]["chunk"].content
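For reference only, and not as a description of this diff: a plain tqdm(...) object is neither callable nor async-iterable, so drawing a bar around an async event stream is usually done with the tqdm.asyncio variant instead. A sketch under that assumption, reusing the names from sage/chat.py; the atqdm alias and the consume wrapper are illustrative:

from tqdm.asyncio import tqdm as atqdm

async def consume(rag_chain, message, history_langchain_format):
    # Wrap the async iterator returned by astream_events, not the method itself.
    async for event in atqdm(rag_chain.astream_events(
        {"input": message, "chat_history": history_langchain_format},
        version="v1",
    )):
        ...  # handle retrieval and streaming events as in _predict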
15 changes: 8 additions & 7 deletions sage/chunker.py
@@ -13,6 +13,7 @@
 from semchunk import chunk as chunk_via_semchunk
 from tree_sitter import Node
 from tree_sitter_language_pack import get_parser
+from tqdm import tqdm, trange

 from sage.constants import TEXT_FIELD

@@ -130,17 +131,17 @@ def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> Lis
             return self.text_chunker.chunk(file_content[node.start_byte : node.end_byte], file_metadata)

         chunks = []
-        for child in node.children:
+        for child in tqdm(node.children):
             chunks.extend(self._chunk_node(child, file_content, file_metadata))

-        for chunk in chunks:
+        for chunk in tqdm(chunks):
             # This should always be true. Otherwise there must be a bug in the code.
             assert chunk.num_tokens <= self.max_tokens

         # Merge neighboring chunks if their combined size doesn't exceed max_tokens. The goal is to avoid pathologically
         # small chunks that end up being undeservedly preferred by the retriever.
         merged_chunks = []
-        for chunk in chunks:
+        for chunk in tqdm(chunks):
             if not merged_chunks:
                 merged_chunks.append(chunk)
             elif merged_chunks[-1].num_tokens + chunk.num_tokens < self.max_tokens - 50:
@@ -160,7 +161,7 @@ def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> Lis
                 merged_chunks.append(chunk)
         chunks = merged_chunks

-        for chunk in merged_chunks:
+        for chunk in tqdm(merged_chunks):
             # This should always be true. Otherwise there's a bug worth investigating.
             assert chunk.num_tokens <= self.max_tokens

@@ -221,7 +222,7 @@ def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
             return []

         file_chunks = self._chunk_node(tree.root_node, file_content, file_metadata)
-        for chunk in file_chunks:
+        for chunk in tqdm(file_chunks):
             # Make sure that the chunk has content and doesn't exceed the max_tokens limit. Otherwise there must be
             # a bug in the code.
             assert (
@@ -250,7 +251,7 @@ def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:

         file_chunks = []
         start = 0
-        for text_chunk in text_chunks:
+        for text_chunk in tqdm(text_chunks):
             # This assertion should always be true. Otherwise there's a bug worth finding.
             assert self.count_tokens(text_chunk) <= self.max_tokens - extra_tokens

@@ -289,7 +290,7 @@ def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
         tmp_metadata = {"file_path": filename.replace(".ipynb", ".py")}
         chunks = self.code_chunker.chunk(python_code, tmp_metadata)

-        for chunk in chunks:
+        for chunk in tqdm(chunks):
             # Update filenames back to .ipynb
             chunk.metadata["file_path"] = filename
         return chunks
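As a hedged note on the small inner loops wrapped in this file, not part of the diff: tqdm accepts leave=False to clear a nested bar once its loop ends and disable=None to suppress bars when output is not a terminal, which keeps per-chunk loops from flooding logs. The helper below is an illustrative sketch, not code from sage/chunker.py:

from tqdm import tqdm

def check_chunks(chunks, max_tokens):  # hypothetical helper mirroring the assertion loops above
    for chunk in tqdm(chunks, desc="validating chunks", leave=False, disable=None):
        assert chunk.num_tokens <= max_tokens
    return chunks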
1 change: 1 addition & 0 deletions sage/config.py
@@ -10,6 +10,7 @@
 from configargparse import ArgumentParser

 from sage.reranker import RerankerProvider
+from tqdm import tqdm, trange

 # Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
 GEMINI_MAX_TOKENS_PER_CHUNK = 2048
31 changes: 11 additions & 20 deletions sage/data_manager.py
@@ -8,6 +8,7 @@

 import requests
 from git import GitCommandError, Repo
+from tqdm import tqdm, trange


 class DataManager:
@@ -130,7 +131,7 @@ def _parse_filter_file(self, file_path: str) -> bool:
             lines = f.readlines()

         parsed_data = {"ext": [], "file": [], "dir": []}
-        for line in lines:
+        for line in tqdm(lines):
             if line.startswith("#"):
                 # This is a comment line.
                 continue
@@ -149,7 +150,7 @@ def _should_include(self, file_path: str) -> bool:
             return False

         # Exclude hidden files and directories.
-        if any(part.startswith(".") for part in file_path.split(os.path.sep)):
+        if any(part.startswith(".") for part in tqdm(file_path.split(os.path.sep))):
             return False

         if not self.inclusions and not self.exclusions:
@@ -165,13 +166,13 @@ def _should_include(self, file_path: str) -> bool:
             return (
                 extension in self.inclusions.get("ext", [])
                 or file_name in self.inclusions.get("file", [])
-                or any(d in dirs for d in self.inclusions.get("dir", []))
+                or any(d in dirs for d in tqdm(self.inclusions.get("dir", [])))
             )
         elif self.exclusions:
             return (
                 extension not in self.exclusions.get("ext", [])
                 and file_name not in self.exclusions.get("file", [])
-                and all(d not in dirs for d in self.exclusions.get("dir", []))
+                and all(d not in dirs for d in tqdm(self.exclusions.get("dir", [])))
             )
         return True

@@ -193,30 +194,20 @@ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, No
             os.remove(excluded_log_file)
         logging.info("Logging excluded files at %s", excluded_log_file)

-        for root, _, files in os.walk(self.local_path):
-            file_paths = [os.path.join(root, file) for file in files]
-            included_file_paths = [f for f in file_paths if self._should_include(f)]
+        for root, _, files in tqdm(os.walk(self.local_path)):
+            file_paths = [os.path.join(root, file) for file in tqdm(files)]
+            included_file_paths = [f for f in tqdm(file_paths) if self._should_include(f)]

             with open(included_log_file, "a") as f:
-                for path in included_file_paths:
+                for path in tqdm(included_file_paths):
                     f.write(path + "\n")

             excluded_file_paths = set(file_paths).difference(set(included_file_paths))
             with open(excluded_log_file, "a") as f:
-                for path in excluded_file_paths:
+                for path in tqdm(excluded_file_paths):
                     f.write(path + "\n")

-            for file_path in included_file_paths:
-                relative_file_path = file_path[len(self.local_dir) + 1 :]
-                metadata = {
-                    "file_path": relative_file_path,
-                    "url": self.url_for_file(relative_file_path),
-                }
-
-                if not get_content:
-                    yield metadata
-                    continue
-
+            for file_path in tqdm(included_file_paths):
                 with open(file_path, "r") as f:
                     try:
                         contents = f.read()
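One last hedged sketch, not part of the change: os.walk returns a generator with no length, so tqdm falls back to a plain counter here; setting desc and unit at least labels that counter. walk_with_progress is an illustrative name, not code from sage/data_manager.py:

import os
from tqdm import tqdm

def walk_with_progress(local_path):
    # tqdm cannot show a total for a generator, but desc/unit make the counter readable.
    for root, _, files in tqdm(os.walk(local_path), desc="scanning repo", unit="dir"):
        yield root, files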