Skip to content

Commit

Permalink
feat: add NestedDictInput filter and non-vector search for AstraVecto…
Browse files Browse the repository at this point in the history
…rStoreComponent (#4564)

* NestedDictInput filter and non-vector search for AstraVectorStoreComponent

* [autofix.ci] apply automated fixes

* addressing Ruff linting

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Hare <[email protected]>
  • Loading branch information
3 people authored Nov 14, 2024
1 parent 0461baf commit 44b0531
Showing 1 changed file with 63 additions and 28 deletions.
91 changes: 63 additions & 28 deletions src/backend/base/langflow/components/vectorstores/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import orjson
from astrapy.admin import parse_api_endpoint
from langchain_astradb import AstraDBVectorStore

from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers import docs_to_data
from langflow.inputs import DictInput, FloatInput, MessageTextInput
from langflow.inputs import DictInput, FloatInput, MessageTextInput, NestedDictInput
from langflow.io import (
BoolInput,
DataInput,
Expand All @@ -26,6 +27,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
name = "AstraDB"
icon: str = "AstraDB"

_cached_vector_store: AstraDBVectorStore | None = None

VECTORIZE_PROVIDERS_MAPPING = {
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
Expand Down Expand Up @@ -201,11 +204,17 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
value=0,
advanced=True,
),
DictInput(
name="search_filter",
NestedDictInput(
name="advanced_search_filter",
display_name="Search Metadata Filter",
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
),
DictInput(
name="search_filter",
display_name="[DEPRECATED] Search Metadata Filter",
info="Deprecated: use advanced_search_filter. Optional dictionary of filters to apply to the search query.",
advanced=True,
is_list=True,
),
]
Expand Down Expand Up @@ -482,43 +491,69 @@ def _map_search_type(self) -> str:
return "similarity"

def _build_search_args(self):
args = {
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}
query = self.search_input if isinstance(self.search_input, str) and self.search_input.strip() else None
search_filter = (
{k: v for k, v in self.search_filter.items() if k and v and k.strip()} if self.search_filter else None
)

if query:
args = {
"query": query,
"search_type": self._map_search_type(),
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}
elif self.advanced_search_filter or search_filter:
args = {
"n": self.number_of_results,
}
else:
return {}

filter_arg = self.advanced_search_filter or {}

if search_filter:
self.log(self.log(f"`search_filter` is deprecated. Use `advanced_search_filter`. Cleaned: {search_filter}"))
filter_arg.update(search_filter)

if filter_arg:
args["filter"] = filter_arg

if self.search_filter:
clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
if len(clean_filter) > 0:
args["filter"] = clean_filter
return args

def search_documents(self, vector_store=None) -> list[Data]:
if not vector_store:
vector_store = self.build_vector_store()
vector_store = vector_store or self.build_vector_store()

self.log(f"Search input: {self.search_input}")
self.log(f"Search type: {self.search_type}")
self.log(f"Number of results: {self.number_of_results}")

if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
try:
search_type = self._map_search_type()
search_args = self._build_search_args()
try:
search_args = self._build_search_args()
except Exception as e:
msg = f"Error in AstraDBVectorStore._build_search_args: {e}"
raise ValueError(msg) from e

docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
except Exception as e:
msg = f"Error performing search in AstraDBVectorStore: {e}"
raise ValueError(msg) from e
if not search_args:
self.log("No search input or filters provided. Skipping search.")
return []

docs = []
search_method = "search" if "query" in search_args else "metadata_search"

try:
self.log(f"Calling vector_store.{search_method} with args: {search_args}")
docs = getattr(vector_store, search_method)(**search_args)
except Exception as e:
msg = f"Error performing {search_method} in AstraDBVectorStore: {e}"
raise ValueError(msg) from e

self.log(f"Retrieved documents: {len(docs)}")
self.log(f"Retrieved documents: {len(docs)}")

data = docs_to_data(docs)
self.log(f"Converted documents to data: {len(data)}")
self.status = data
return data
self.log("No search input provided. Skipping search.")
return []
data = docs_to_data(docs)
self.log(f"Converted documents to data: {len(data)}")
self.status = data
return data

def get_retriever_kwargs(self):
search_args = self._build_search_args()
Expand Down

0 comments on commit 44b0531

Please sign in to comment.