Show tqdm output, fail on empty results
jhamon committed Oct 18, 2024
1 parent 345822c commit ac00676
Showing 2 changed files with 41 additions and 20 deletions.
25 changes: 19 additions & 6 deletions pinecone/grpc/index_grpc_asyncio.py
@@ -264,33 +264,46 @@ async def composite_query(
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
show_progress: Optional[bool] = True,
max_concurrent_requests: Optional[int] = None,
semaphore: Optional[asyncio.Semaphore] = None,
**kwargs,
) -> Awaitable[CompositeQueryResults]:
aggregator_lock = asyncio.Lock()
semaphore = self._get_semaphore(max_concurrent_requests, semaphore)

# The caller may only want the top_k=1 result across all queries,
# but we need at least 2 results from each subquery in order to
# aggregate them correctly. So we temporarily bump the subquery topK
# up to 2, and then take only the top_k results the caller asked for
# from the aggregated set.
aggregator = QueryResultsAggregator(top_k=top_k)
subquery_topk = top_k if top_k > 2 else 2
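# Illustrative values (not part of this commit's diff): top_k=1 and
# top_k=2 both yield subquery_topk=2, while top_k=10 keeps subquery_topk=10.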

target_namespaces = set(namespaces) # dedup namespaces
query_tasks = [
self._query(
vector=vector,
namespace=ns,
top_k=top_k,
top_k=subquery_topk,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
semaphore=semaphore,
**kwargs,
)
for ns in namespaces
for ns in target_namespaces
]

for query_task in asyncio.as_completed(query_tasks):
response = await query_task
async with aggregator_lock:
aggregator.add_results(response)
with tqdm(
total=len(query_tasks), disable=not show_progress, desc="Querying namespaces"
) as pbar:
for query_task in asyncio.as_completed(query_tasks):
response = await query_task
pbar.update(1)
async with aggregator_lock:
aggregator.add_results(response)

final_results = aggregator.get_results()
return final_results
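
For reference, here is a minimal standalone sketch of the progress-reporting pattern the hunk above adds: tqdm wraps asyncio.as_completed so the bar ticks as each per-namespace query finishes, a semaphore caps concurrency, and a lock guards the shared accumulator. Everything here (fake_query, the sleep, the list accumulator) is a hypothetical stand-in, not code from this commit.

import asyncio
import random

from tqdm import tqdm


async def fake_query(namespace: str, semaphore: asyncio.Semaphore) -> dict:
    # Hypothetical stand-in for the real per-namespace _query call.
    async with semaphore:
        await asyncio.sleep(random.random() / 10)
        return {"namespace": namespace, "matches": []}


async def main() -> None:
    namespaces = ["ns1", "ns2", "ns3", "ns4"]
    semaphore = asyncio.Semaphore(2)  # cap concurrent requests
    results_lock = asyncio.Lock()
    results = []

    tasks = [fake_query(ns, semaphore) for ns in namespaces]
    with tqdm(total=len(tasks), desc="Querying namespaces") as pbar:
        for task in asyncio.as_completed(tasks):
            response = await task
            pbar.update(1)
            async with results_lock:
                results.append(response)


asyncio.run(main())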
36 changes: 22 additions & 14 deletions pinecone/grpc/query_results_aggregator.py
@@ -84,7 +84,7 @@ def __repr__(self):
class QueryResultsAggregationEmptyResultsError(Exception):
def __init__(self, namespace: str):
super().__init__(
f"Cannot infer metric type from empty query results. Query result for namespace '{namespace}' is empty. Have you spelled the namespace name correctly?"
f"Query results for namespace '{namespace}' were empty. Check that you have upserted vectors into this namespace (see describe_index_stats) and that the namespace name is spelled correctly."
)
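
To illustrate the "fail on empty results" behavior this commit adds (see the add_results change further down), here is a small hedged example. The dict-shaped response is a stand-in for the real QueryResponse object, assuming add_results accepts any mapping with the same keys, as the visible .get() calls suggest.

from pinecone.grpc.query_results_aggregator import (
    QueryResultsAggregationEmptyResultsError,
    QueryResultsAggregator,
)

aggregator = QueryResultsAggregator(top_k=3)
try:
    # An empty namespace (e.g. a typo in the name) now raises instead of
    # silently contributing nothing to the aggregated results.
    aggregator.add_results({"namespace": "ns-typo", "matches": [], "usage": {"readUnits": 1}})
except QueryResultsAggregationEmptyResultsError as e:
    print(e)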


@@ -111,7 +111,7 @@ def __init__(self, top_k: int):
self.is_dotproduct = None
self.read = False

def __is_dotproduct_index(self, matches):
def _is_dotproduct_index(self, matches):
# The interpretation of the score depends on the similarity metric used.
# Unlike other index types, in indexes configured for dotproduct,
# a higher score is better. We have to infer this is the case by inspecting
@@ -121,6 +121,20 @@ def __is_dotproduct_index(self, matches):
return False
return True

def _dotproduct_heap_item(self, match, ns):
return (match.get("score"), -self.insertion_counter, match, ns)

def _non_dotproduct_heap_item(self, match, ns):
return (-match.get("score"), -self.insertion_counter, match, ns)

def _process_matches(self, matches, ns, heap_item_fn):
for match in matches:
self.insertion_counter += 1
if len(self.heap) < self.top_k:
heapq.heappush(self.heap, heap_item_fn(match, ns))
else:
heapq.heappushpop(self.heap, heap_item_fn(match, ns))
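
As a standalone illustration of the heap pattern _process_matches relies on (not code from this commit), the snippet below keeps the top_k best matches with heapq, assuming the non-dotproduct case where, per the comment above, a lower score is better. The hypothetical matches and ids are made up for the example.

import heapq

# Hypothetical matches for a metric where a lower score is better.
matches = [
    {"id": "a", "score": 0.9},
    {"id": "b", "score": 0.1},
    {"id": "c", "score": 0.4},
    {"id": "d", "score": 0.2},
]

top_k = 2
heap = []
insertion_counter = 0

for match in matches:
    insertion_counter += 1
    # Negating the score makes heappushpop evict the largest (worst) score;
    # the counter breaks ties by insertion order, mirroring the helpers above.
    item = (-match["score"], -insertion_counter, match)
    if len(heap) < top_k:
        heapq.heappush(heap, item)
    else:
        heapq.heappushpop(heap, item)

print(sorted(m["id"] for _, _, m in heap))  # ['b', 'd'] -- the two best matches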

def add_results(self, results: QueryResponse):
if self.read:
# This is mainly just to sanity check in test cases which get quite confusing
@@ -132,24 +146,18 @@ def add_results(self, results: QueryResponse):
ns = results.get("namespace")
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

if len(matches) == 0:
raise QueryResultsAggregationEmptyResultsError(ns)

if self.is_dotproduct is None:
if len(matches) == 0:
raise QueryResultsAggregationEmptyResultsError(ns)
if len(matches) == 1:
raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches))
self.is_dotproduct = self.__is_dotproduct_index(matches)
self.is_dotproduct = self._is_dotproduct_index(matches)

print("is_dotproduct:", self.is_dotproduct)
if self.is_dotproduct:
raise NotImplementedError("Dotproduct indexes are not yet supported.")
self._process_matches(matches, ns, self._dotproduct_heap_item)
else:
for match in matches:
self.insertion_counter += 1
score = match.get("score")
if len(self.heap) < self.top_k:
heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
else:
heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))
self._process_matches(matches, ns, self._non_dotproduct_heap_item)

def get_results(self) -> CompositeQueryResults:
if self.read:
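
Finally, a minimal end-to-end sketch of how a caller could drive the aggregator directly. Assumptions: plain dicts are accepted anywhere the visible code calls .get(), and get_results() returns the merged results (its body is collapsed in this diff); the responses and ids below are hypothetical.

from pinecone.grpc.query_results_aggregator import QueryResultsAggregator

# Hypothetical per-namespace responses shaped like what _query returns.
response_ns1 = {
    "namespace": "ns1",
    "usage": {"readUnits": 5},
    "matches": [{"id": "a", "score": 0.11}, {"id": "b", "score": 0.23}],
}
response_ns2 = {
    "namespace": "ns2",
    "usage": {"readUnits": 5},
    "matches": [{"id": "c", "score": 0.05}, {"id": "d", "score": 0.42}],
}

aggregator = QueryResultsAggregator(top_k=3)
aggregator.add_results(response_ns1)
aggregator.add_results(response_ns2)

# Under the stated assumptions, `final` holds the top_k matches merged across
# both namespaces; the aggregator has also summed the read units (5 + 5).
final = aggregator.get_results()
print(final)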
