
feat(agents-api): Tweak queries for search (#685)
- **feat(agents-api): Tweak the proximity indices**
- **feat(agents-api): Tweak queries for search**


----

> [!IMPORTANT]
> Enhances search functionality by adding NLP-based query generation, adjusting search parameters, and updating database indices in the `agents-api`.
>
>   - **NLP Module**:
>     - Added `nlp.py` for keyword extraction and query building using spaCy.
>     - Functions include `extract_keywords()`, `find_proximity_groups()`, and `text_to_custom_query()`.
>   - **Search Functionality**:
>     - Updated `search_docs_by_embedding()` in `search_docs_by_embedding.py` to adjust `confidence` to 0.5 and `ef` to 32.
>     - Modified `search_docs_by_text()` in `search_docs_by_text.py` to use `paragraph_to_custom_queries()` for query generation.
>   - **Database Indices**:
>     - Migration `migrate_1729114011_tweak_proximity_indices.py` updates the LSH and FTS indices for better proximity handling.
>   - **Dependencies**:
>     - Added `spacy`, `en-core-web-sm`, and `msgpack` to `pyproject.toml`.
>     - Adjusted the `numpy` version constraint in `pyproject.toml`.
>
> <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral)<sup> for 3018c6e. It will automatically update as commits are pushed.</sup>


---------

Signed-off-by: Diwank Singh Tomer <[email protected]>
Co-authored-by: creatorrr <[email protected]>
creatorrr and creatorrr authored Oct 17, 2024
1 parent eb13894 commit aa44bfd
Showing 6 changed files with 1,241 additions and 153 deletions.
216 changes: 216 additions & 0 deletions agents-api/agents_api/common/nlp.py
@@ -0,0 +1,216 @@
import re
from collections import Counter, defaultdict

import spacy

# Load spaCy English model
spacy.prefer_gpu()
nlp = spacy.load("en_core_web_sm")


def extract_keywords(text: str, top_n: int = 10) -> list[str]:
    """
    Extracts significant keywords and phrases from the text.
    Args:
        text (str): The input text to process.
        top_n (int): Number of top keywords to extract based on frequency.
    Returns:
        List[str]: A list of extracted keywords/phrases.
    """
    doc = nlp(text)

    # Extract named entities
    entities = [
        ent.text.strip()
        for ent in doc.ents
        if ent.label_
        not in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
    ]

    # Extract nouns and proper nouns
    nouns = [
        chunk.text.strip().lower()
        for chunk in doc.noun_chunks
        if not chunk.root.is_stop
    ]

    # Combine entities and nouns
    combined = entities + nouns

    # Normalize and count frequency
    normalized = [re.sub(r"\s+", " ", kw).strip().lower() for kw in combined]
    freq = Counter(normalized)

    # Get top_n keywords
    keywords = [item for item, count in freq.most_common(top_n)]

    return keywords


def find_keyword_positions(doc, keyword: str) -> list[int]:
    """
    Finds all start indices of the keyword in the tokenized doc.
    Args:
        doc (spacy.tokens.Doc): The tokenized document.
        keyword (str): The keyword or phrase to search for.
    Returns:
        List[int]: List of starting token indices where the keyword appears.
    """
    keyword_tokens = keyword.split()
    n = len(keyword_tokens)
    positions = []
    for i in range(len(doc) - n + 1):
        window = doc[i : i + n]
        window_text = " ".join([token.text.lower() for token in window])
        if window_text == keyword:
            positions.append(i)
    return positions


def find_proximity_groups(
    text: str, keywords: list[str], n: int = 10
) -> list[set[str]]:
    """
    Groups keywords that appear within n words of each other.
    Args:
        text (str): The input text.
        keywords (List[str]): List of keywords to consider.
        n (int): The proximity window in words.
    Returns:
        List[Set[str]]: List of sets, each containing keywords that are proximate.
    """
    doc = nlp(text.lower())
    keyword_positions = defaultdict(list)

    for kw in keywords:
        positions = find_keyword_positions(doc, kw)
        keyword_positions[kw].extend(positions)

    # Initialize Union-Find structure
    parent = {}

    def find(u):
        while parent[u] != u:
            parent[u] = parent[parent[u]]
            u = parent[u]
        return u

    def union(u, v):
        u_root = find(u)
        v_root = find(v)
        if u_root == v_root:
            return
        parent[v_root] = u_root

    # Initialize each keyword as its own parent
    for kw in keywords:
        parent[kw] = kw

    # Compare all pairs of keywords
    for i in range(len(keywords)):
        for j in range(i + 1, len(keywords)):
            kw1 = keywords[i]
            kw2 = keywords[j]
            positions1 = keyword_positions[kw1]
            positions2 = keyword_positions[kw2]
            # Check if any positions are within n words
            for pos1 in positions1:
                for pos2 in positions2:
                    distance = abs(pos1 - pos2)
                    if distance <= n:
                        union(kw1, kw2)
                        break
                else:
                    continue
                break

    # Group keywords by their root parent
    groups = defaultdict(set)
    for kw in keywords:
        root = find(kw)
        groups[root].add(kw)

    # Convert to list of sets
    group_list = list(groups.values())

    return group_list


def build_query(groups: list[set[str]], keywords: list[str], n: int = 10) -> str:
    """
    Builds a query string using the custom query language.
    Args:
        groups (List[Set[str]]): List of keyword groups.
        keywords (List[str]): Original list of keywords.
        n (int): The proximity window for NEAR.
    Returns:
        str: The constructed query string.
    """
    grouped_keywords = set()
    clauses = []

    for group in groups:
        if len(group) == 1:
            clauses.append(f'"{list(group)[0]}"')
        else:
            sorted_group = sorted(
                group, key=lambda x: -len(x)
            )  # Sort by length to prioritize phrases
            escaped_keywords = [f'"{kw}"' for kw in sorted_group]
            near_clause = f"NEAR/{n}(" + " ".join(escaped_keywords) + ")"
            clauses.append(near_clause)
        grouped_keywords.update(group)

    # Identify keywords not in any group (if any)
    remaining = set(keywords) - grouped_keywords
    for kw in remaining:
        clauses.append(f'"{kw}"')

    # Combine all clauses with OR
    query = " OR ".join(clauses)

    return query


def text_to_custom_query(text: str, top_n: int = 10, proximity_n: int = 10) -> str:
    """
    Converts arbitrary text to the custom query language.
    Args:
        text (str): The input text to convert.
        top_n (int): Number of top keywords to extract.
        proximity_n (int): The proximity window for NEAR/n.
    Returns:
        str: The custom query string.
    """
    keywords = extract_keywords(text, top_n)
    if not keywords:
        return ""
    groups = find_proximity_groups(text, keywords, proximity_n)
    query = build_query(groups, keywords, proximity_n)
    return query


def paragraph_to_custom_queries(paragraph: str) -> list[str]:
    """
    Converts a paragraph to a list of custom query strings.
    Args:
        paragraph (str): The input paragraph to convert.
    Returns:
        List[str]: The list of custom query strings.
    """

    queries = [text_to_custom_query(sentence.text) for sentence in nlp(paragraph).sents]

    return queries
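
To make the shape of the output concrete, here is a hypothetical usage sketch of the new helpers (the keywords and grouping shown in the comments are illustrative; actual output depends on what `en_core_web_sm` extracts):

```python
# Hypothetical usage of the new agents_api.common.nlp helpers.
# Assumes spaCy and the en_core_web_sm model are installed
# (pip install spacy && python -m spacy download en_core_web_sm).
from agents_api.common.nlp import paragraph_to_custom_queries, text_to_custom_query

paragraph = (
    "Julep agents store document snippets in CozoDB. "
    "Full-text search over those snippets now uses a custom query language."
)

# One FTS query string per sentence, roughly of the form:
#   NEAR/10("julep agents" "document snippets" "cozodb")
#   NEAR/10("full-text search" "custom query language") OR "snippets"
for fts_query in paragraph_to_custom_queries(paragraph):
    print(fts_query)

# A single piece of text can also be converted directly:
print(text_to_custom_query("proximity search with spaCy keywords", top_n=5, proximity_n=10))
```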
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/docs/search_docs_by_embedding.py
@@ -47,8 +47,8 @@ def search_docs_by_embedding(
     owners: list[tuple[Literal["user", "agent"], UUID]],
     query_embedding: list[float],
     k: int = 3,
-    confidence: float = 0.7,
-    ef: int = 128,
+    confidence: float = 0.5,
+    ef: int = 32,
     mmr_lambda: float = 0.25,
     embedding_size: int = 1024,
 ) -> tuple[list[str], dict]:
9 changes: 5 additions & 4 deletions agents-api/agents_api/models/docs/search_docs_by_text.py
@@ -1,6 +1,5 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries."""

import json
from typing import Any, Literal, TypeVar
from uuid import UUID

Expand All @@ -10,6 +9,7 @@
from pydantic import ValidationError

from ...autogen.openapi_model import DocReference
from ...common.nlp import paragraph_to_custom_queries
from ..utils import (
cozo_query,
partialclass,
Expand Down Expand Up @@ -64,7 +64,7 @@ def search_docs_by_text(

# Need to use NEAR/3($query) to search for arbitrary text within 3 words of each other
# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
query = f"NEAR/3({json.dumps(query)})"
fts_queries = paragraph_to_custom_queries(query)

# Construct the datalog query for searching document snippets
search_query = f"""
Expand Down Expand Up @@ -112,11 +112,12 @@ def search_docs_by_text(
index,
content
|
query: $query,
query: query,
k: {k},
score_kind: 'tf_idf',
bind_score: score,
}},
query in $fts_queries,
distance = -score,
snippet_data = [index, content]
Expand Down Expand Up @@ -183,5 +184,5 @@ def search_docs_by_text(

return (
queries,
{"owners": owners, "query": query},
{"owners": owners, "query": query, "fts_queries": fts_queries},
)
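
In effect, the text search path now fans one paragraph out into an FTS clause per sentence and binds the list as a query parameter instead of a single `NEAR/3(...)` string. A hedged sketch of the new binding (the owner entry is a placeholder; the full datalog query and its columns live in this file above):

```python
# Hedged sketch of the new parameter binding (owner UUID is a placeholder).
# Before this change, one string was bound as $query and wrapped in NEAR/3(...);
# now one FTS clause per sentence is bound as the list $fts_queries, and the
# datalog body's `query in $fts_queries` runs the ~snippets:fts search once per clause.
from agents_api.common.nlp import paragraph_to_custom_queries

user_query = "How do I tune proximity search for documents?"
params = {
    "owners": [["agent", "00000000-0000-0000-0000-000000000000"]],
    "query": user_query,
    "fts_queries": paragraph_to_custom_queries(user_query),
}
```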
92 changes: 92 additions & 0 deletions migrate_1729114011_tweak_proximity_indices.py
@@ -0,0 +1,92 @@
# /usr/bin/env python3

MIGRATION_ID = "tweak_proximity_indices"
CREATED_AT = 1729114011.022733


def run(client, *queries):
    joiner = "}\n\n{"

    query = joiner.join(queries)
    query = f"{{\n{query}\n}}"
    client.run(query)


drop_snippets_lsh_index = dict(
    up="""
    ::lsh drop snippets:lsh
    """,
    down="""
    ::lsh create snippets:lsh {
        extractor: content,
        tokenizer: Simple,
        filters: [Stopwords('en')],
        n_perm: 200,
        target_threshold: 0.9,
        n_gram: 3,
        false_positive_weight: 1.0,
        false_negative_weight: 1.0,
    }
    """,
)

snippets_lsh_index = dict(
    up="""
    ::lsh create snippets:lsh {
        extractor: content,
        tokenizer: Simple,
        filters: [Lowercase, AsciiFolding, Stemmer('english'), Stopwords('en')],
        n_perm: 200,
        target_threshold: 0.5,
        n_gram: 2,
        false_positive_weight: 1.0,
        false_negative_weight: 1.0,
    }
    """,
    down="""
    ::lsh drop snippets:lsh
    """,
)

# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
drop_snippets_fts_index = dict(
    down="""
    ::fts create snippets:fts {
        extractor: content,
        tokenizer: Simple,
        filters: [Lowercase, Stemmer('english'), Stopwords('en')],
    }
    """,
    up="""
    ::fts drop snippets:fts
    """,
)

# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
snippets_fts_index = dict(
    up="""
    ::fts create snippets:fts {
        extractor: content,
        tokenizer: Simple,
        filters: [Lowercase, AsciiFolding, Stemmer('english'), Stopwords('en')],
    }
    """,
    down="""
    ::fts drop snippets:fts
    """,
)

queries_to_run = [
    drop_snippets_lsh_index,
    drop_snippets_fts_index,
    snippets_lsh_index,
    snippets_fts_index,
]


def up(client):
    run(client, *[q["up"] for q in queries_to_run])


def down(client):
    run(client, *[q["down"] for q in reversed(queries_to_run)])
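
For reference, the `run()` helper above simply wraps each snippet in braces and joins the blocks, so each migration direction executes as a single multi-block CozoScript program; an illustrative trace:

```python
# Illustrative: what run(client, *queries) assembles for the "up" direction,
# reusing queries_to_run as defined in the migration above.
queries = [q["up"] for q in queries_to_run]
program = "{\n" + "}\n\n{".join(queries) + "\n}"
# program is one script of the form:
# {
#     ::lsh drop snippets:lsh
# }
#
# {
#     ::fts drop snippets:fts
# }
#
# {
#     ::lsh create snippets:lsh { ...stemming/folding filters... }
# }
#
# {
#     ::fts create snippets:fts { ... }
# }
# client.run(program) then applies all four blocks in a single call.
```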
