diff --git a/pinecone/grpc/index_grpc_asyncio.py b/pinecone/grpc/index_grpc_asyncio.py index 593fd3c3..0f7f906e 100644 --- a/pinecone/grpc/index_grpc_asyncio.py +++ b/pinecone/grpc/index_grpc_asyncio.py @@ -264,19 +264,28 @@ async def composite_query( include_values: Optional[bool] = None, include_metadata: Optional[bool] = None, sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + show_progress: Optional[bool] = True, max_concurrent_requests: Optional[int] = None, semaphore: Optional[asyncio.Semaphore] = None, **kwargs, ) -> Awaitable[CompositeQueryResults]: aggregator_lock = asyncio.Lock() semaphore = self._get_semaphore(max_concurrent_requests, semaphore) + + # The caller may request a small top_k (even 1) across all queries, + # but we need at least 2 results from each query in order to infer the + # index metric and aggregate correctly. So we'll raise the subquery top_k + # to a minimum of 2, and then trim the aggregated results back down to + # the caller's requested top_k. aggregator = QueryResultsAggregator(top_k=top_k) + subquery_topk = top_k if top_k > 2 else 2 + target_namespaces = set(namespaces) # dedup namespaces query_tasks = [ self._query( vector=vector, namespace=ns, - top_k=top_k, + top_k=subquery_topk, filter=filter, include_values=include_values, include_metadata=include_metadata, @@ -284,13 +293,17 @@ async def composite_query( semaphore=semaphore, **kwargs, ) - for ns in namespaces + for ns in target_namespaces ] - for query_task in asyncio.as_completed(query_tasks): - response = await query_task - async with aggregator_lock: - aggregator.add_results(response) + with tqdm( + total=len(query_tasks), disable=not show_progress, desc="Querying namespaces" + ) as pbar: + for query_task in asyncio.as_completed(query_tasks): + response = await query_task + pbar.update(1) + async with aggregator_lock: + aggregator.add_results(response) final_results = aggregator.get_results() return final_results diff --git a/pinecone/grpc/query_results_aggregator.py
b/pinecone/grpc/query_results_aggregator.py index 238cff17..345f006d 100644 --- a/pinecone/grpc/query_results_aggregator.py +++ b/pinecone/grpc/query_results_aggregator.py @@ -84,7 +84,7 @@ def __repr__(self): class QueryResultsAggregationEmptyResultsError(Exception): def __init__(self, namespace: str): super().__init__( - f"Cannot infer metric type from empty query results. Query result for namespace '{namespace}' is empty. Have you spelled the namespace name correctly?" + f"Query results for namespace '{namespace}' were empty. Check that you have upserted vectors into this namespace (see describe_index_stats) and that the namespace name is spelled correctly." ) @@ -111,7 +111,7 @@ def __init__(self, top_k: int): self.is_dotproduct = None self.read = False - def __is_dotproduct_index(self, matches): + def _is_dotproduct_index(self, matches): # The interpretation of the score depends on the similar metric used. # Unlike other index types, in indexes configured for dotproduct, # a higher score is better. 
We have to infer this is the case by inspecting @@ -121,6 +121,20 @@ def __is_dotproduct_index(self, matches): return False return True + def _dotproduct_heap_item(self, match, ns): + return (match.get("score"), -self.insertion_counter, match, ns) + + def _non_dotproduct_heap_item(self, match, ns): + return (-match.get("score"), -self.insertion_counter, match, ns) + + def _process_matches(self, matches, ns, heap_item_fn): + for match in matches: + self.insertion_counter += 1 + if len(self.heap) < self.top_k: + heapq.heappush(self.heap, heap_item_fn(match, ns)) + else: + heapq.heappushpop(self.heap, heap_item_fn(match, ns)) + def add_results(self, results: QueryResponse): if self.read: # This is mainly just to sanity check in test cases which get quite confusing @@ -132,24 +146,18 @@ def add_results(self, results: QueryResponse): ns = results.get("namespace") self.usage_read_units += results.get("usage", {}).get("readUnits", 0) + if len(matches) == 0: + raise QueryResultsAggregationEmptyResultsError(ns) + if self.is_dotproduct is None: - if len(matches) == 0: - raise QueryResultsAggregationEmptyResultsError(ns) if len(matches) == 1: raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches)) - self.is_dotproduct = self.__is_dotproduct_index(matches) + self.is_dotproduct = self._is_dotproduct_index(matches) - print("is_dotproduct:", self.is_dotproduct) if self.is_dotproduct: - raise NotImplementedError("Dotproduct indexes are not yet supported.") + self._process_matches(matches, ns, self._dotproduct_heap_item) else: - for match in matches: - self.insertion_counter += 1 - score = match.get("score") - if len(self.heap) < self.top_k: - heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns)) - else: - heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns)) + self._process_matches(matches, ns, self._non_dotproduct_heap_item) def get_results(self) -> CompositeQueryResults: if self.read: