From 73751ce3655be2f17ceb78e799ac274294064950 Mon Sep 17 00:00:00 2001 From: MrPresent-Han Date: Sat, 6 May 2023 20:14:13 +0800 Subject: [PATCH] support iterator for query and search Signed-off-by: MrPresent-Han --- examples/iterator.py | 145 +++++++++++++++++++++++ pymilvus/client/constants.py | 1 + pymilvus/client/prepare.py | 6 +- pymilvus/orm/collection.py | 29 ++++- pymilvus/orm/constants.py | 12 +- pymilvus/orm/iterator.py | 218 +++++++++++++++++++++++++++++++++++ 6 files changed, 404 insertions(+), 7 deletions(-) create mode 100644 examples/iterator.py create mode 100644 pymilvus/orm/iterator.py diff --git a/examples/iterator.py b/examples/iterator.py new file mode 100644 index 000000000..f763faf33 --- /dev/null +++ b/examples/iterator.py @@ -0,0 +1,145 @@ +import numpy as np +import random +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, +) + +HOST = "localhost" +PORT = "19530" +COLLECTION_NAME = "test_iterator" +USER_ID = "id" +MAX_LENGTH = 65535 +AGE = "age" +DEPOSIT = "deposit" +PICTURE = "picture" +CONSISTENCY_LEVEL = "Eventually" +LIMIT = 5 +NUM_ENTITIES = 1000 +DIM = 8 +CLEAR_EXIST = False + + +def re_create_collection(): + if utility.has_collection(COLLECTION_NAME) and CLEAR_EXIST: + utility.drop_collection(COLLECTION_NAME) + print(f"dropped existed collection{COLLECTION_NAME}") + + fields = [ + FieldSchema(name=USER_ID, dtype=DataType.VARCHAR, is_primary=True, + auto_id=False, max_length=MAX_LENGTH), + FieldSchema(name=AGE, dtype=DataType.INT64), + FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE), + FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM) + ] + + schema = CollectionSchema(fields) + print(f"Create collection {COLLECTION_NAME}") + collection = Collection(COLLECTION_NAME, schema, consistency_level=CONSISTENCY_LEVEL) + return collection + + +def insert_data(collection): + rng = np.random.default_rng(seed=19530) + batch_count = 5 + for i in range(batch_count): + entities = [ + [str(random.randint(NUM_ENTITIES * i, NUM_ENTITIES * (i + 1))) for ni in range(NUM_ENTITIES)], + [int(ni % 100) for ni in range(NUM_ENTITIES)], + [float(ni) for ni in range(NUM_ENTITIES)], + rng.random((NUM_ENTITIES, DIM)), + ] + collection.insert(entities) + collection.flush() + print(f"Finish insert batch{i}, number of entities in Milvus: {collection.num_entities}") + +def prepare_index(collection): + index = { + "index_type": "IVF_FLAT", + "metric_type": "L2", + "params": {"nlist": 128}, + } + + collection.create_index(PICTURE, index) + print("Finish Creating index IVF_FLAT") + collection.load() + print("Finish Loading index IVF_FLAT") + + +def prepare_data(collection): + insert_data(collection) + prepare_index(collection) + return collection + + +def query_iterate_collection_no_offset(collection): + expr = f"10 <= {AGE} <= 14" + query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE], + offset=0, limit=5, consistency_level=CONSISTENCY_LEVEL, + iteration_extension_reduce_rate=10) + page_idx = 0 + while True: + res = query_iterator.next() + if len(res) == 0: + print("query iteration finished, close") + query_iterator.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + +def query_iterate_collection_with_offset(collection): + expr = f"10 <= {AGE} <= 14" + query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE], + offset=10, limit=5, consistency_level=CONSISTENCY_LEVEL, + iteration_extension_reduce_rate=10) + page_idx = 0 + while True: + res = query_iterator.next() + if len(res) == 0: + print("query iteration finished, close") + query_iterator.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + +def search_iterator_collection(collection): + SEARCH_NQ = 1 + DIM = 8 + rng = np.random.default_rng(seed=19530) + vectors_to_search = rng.random((SEARCH_NQ, DIM)) + search_params = { + "metric_type": "L2", + "params": {"nprobe": 10, "radius": 1.0}, + } + search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, limit=5, + output_fields=[USER_ID]) + page_idx = 0 + while True: + res = search_iterator.next() + if len(res[0]) == 0: + print("query iteration finished, close") + search_iterator.close() + break + for i in range(len(res[0])): + print(res[0][i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + + +def main(): + connections.connect("default", host=HOST, port=PORT) + collection = re_create_collection() + collection = prepare_data(collection) + query_iterate_collection_no_offset(collection) + query_iterate_collection_with_offset(collection) + search_iterator_collection(collection) + + +if __name__ == '__main__': + main() diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index a0b05a2e8..b65cc34bf 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -8,3 +8,4 @@ BOUNDED_TS = 2 DEFAULT_CONSISTENCY_LEVEL = ConsistencyLevel.Bounded DEFAULT_RESOURCE_GROUP = "__default_resource_group" +ITERATION_EXTENSION_REDUCE_RATE = "iteration_extension_reduce_rate" diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 48c04b710..cf19c6060 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -9,7 +9,7 @@ from .check import check_pass_param, is_legal_collection_properties from .types import DataType, PlaceholderType, get_consistency_level from .utils import traverse_info -from .constants import DEFAULT_CONSISTENCY_LEVEL +from .constants import DEFAULT_CONSISTENCY_LEVEL, ITERATION_EXTENSION_REDUCE_RATE from ..exceptions import ParamError, DataNotMatchException, ExceptionsMessage from ..orm.schema import CollectionSchema @@ -613,6 +613,10 @@ def query_request(cls, collection_name, expr, output_fields, partition_names, ** ignore_growing = kwargs.get("ignore_growing", False) req.query_params.append(common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing))) + + use_iteration_extension_reduce_rate = kwargs.get(ITERATION_EXTENSION_REDUCE_RATE, 0) + req.query_params.append(common_types.KeyValuePair(key=ITERATION_EXTENSION_REDUCE_RATE, + value=str(use_iteration_extension_reduce_rate))) return req @classmethod diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index ce7dbb328..3d9dd8643 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -43,11 +43,11 @@ from ..settings import Config from ..client.types import CompactionState, CompactionPlans, Replica, get_consistency_level, cmp_consistency_level from ..client.constants import DEFAULT_CONSISTENCY_LEVEL - +from .iterator import QueryIterator, SearchIterator class Collection: - def __init__(self, name: str, schema: CollectionSchema=None, using: str="default", **kwargs): + def __init__(self, name: str, schema: CollectionSchema = None, using: str = "default", **kwargs): """ Constructs a collection by name, schema and other parameters. Args: @@ -394,7 +394,8 @@ def release(self, timeout=None, **kwargs): conn = self._get_connection() conn.release_collection(self._name, timeout=timeout, **kwargs) - def insert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: + def insert(self, data: Union[List, pandas.DataFrame], partition_name: str = None, timeout=None, + **kwargs) -> MutationResult: """ Insert data into the collection. Args: @@ -483,7 +484,8 @@ def delete(self, expr, partition_name=None, timeout=None, **kwargs): return MutationFuture(res) return MutationResult(res) - def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: + def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, + **kwargs) -> MutationResult: """ Upsert data into the collection. Args: @@ -523,7 +525,7 @@ def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, conn = self._get_connection() res = conn.upsert(self._name, entities, partition_name, - timeout=timeout, schema=self._schema_dict, **kwargs) + timeout=timeout, schema=self._schema_dict, **kwargs) if kwargs.get("_async", False): return MutationFuture(res) @@ -670,6 +672,15 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None return SearchFuture(res) return SearchResult(res) + def search_iterator(self, data, anns_field, param, limit, expr=None, partition_names=None, + output_fields=None, timeout=None, round_decimal=-1, **kwargs): + if expr is not None and not isinstance(expr, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) + conn = self._get_connection() + iterator = SearchIterator(conn, self._name, data, anns_field, param, limit, expr, partition_names, + output_fields, timeout, round_decimal, schema=self._schema_dict, **kwargs) + return iterator + def query(self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): """ Query with expressions @@ -751,6 +762,14 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, ** timeout=timeout, schema=self._schema_dict, **kwargs) return res + def query_iterator(self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): + if not isinstance(expr, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) + conn = self._get_connection() + iterator = QueryIterator(conn, self._name, expr, output_fields, partition_names, + timeout=timeout, schema=self._schema_dict, **kwargs) + return iterator + @property def partitions(self, **kwargs) -> List[Partition]: """ List[Partition]: List of Partition object. diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index 633244313..7aa30070e 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -12,7 +12,6 @@ COMMON_TYPE_PARAMS = ("dim", "max_length") - CALC_DIST_IDS = "ids" CALC_DIST_FLOAT_VEC = "float_vectors" CALC_DIST_BIN_VEC = "bin_vectors" @@ -23,3 +22,14 @@ CALC_DIST_TANIMOTO = "TANIMOTO" CALC_DIST_SQRT = "sqrt" CALC_DIST_DIM = "dim" + +OFFSET = "offset" +LIMIT = "limit" +ID = "id" +METRIC_TYPE = "metric_type" +PARAMS = "params" +DISTANCE = "distance" +RADIUS = "radius" +RANGE_FILTER = "range_filter" +FIELDS = "fields" +ITERATION_EXTENSION_REDUCE_RATE = "iteration_extension_reduce_rate" diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py new file mode 100644 index 000000000..cf79e001b --- /dev/null +++ b/pymilvus/orm/iterator.py @@ -0,0 +1,218 @@ +from .constants import OFFSET, LIMIT, ID, FIELDS, RANGE_FILTER, RADIUS, PARAMS, ITERATION_EXTENSION_REDUCE_RATE +from .types import DataType +from ..exceptions import ( + MilvusException, +) + + +class QueryIterator: + + def __init__(self, connection, collection_name, expr, output_fields=None, partition_names=None, schema=None, + timeout=None, **kwargs): + self._conn = connection + self._collection_name = collection_name + self._expr = expr + self._output_fields = output_fields + self._partition_names = partition_names + self._schema = schema + self._timeout = timeout + self._kwargs = kwargs + self.__setup__pk_is_str() + self.__seek() + self._cache_id_in_use = NO_CACHE_ID + + def __seek(self): + self._cache_id_in_use = NO_CACHE_ID + if self._kwargs.get(OFFSET, 0) == 0: + self._next_id = None + return + + first_cursor_kwargs = self._kwargs.copy() + first_cursor_kwargs[OFFSET] = 0 + # offset may be too large, needed to seek in multiple times + first_cursor_kwargs[LIMIT] = self._kwargs[OFFSET] + first_cursor_kwargs[ITERATION_EXTENSION_REDUCE_RATE] = 0 + + res = self._conn.query(self._collection_name, self._expr, self._output_fields, self._partition_names, + timeout=self._timeout, **first_cursor_kwargs) + self.__update_cursor(res) + self._kwargs[OFFSET] = 0 + + def __maybe_cache(self, result): + if len(result) < 2 * self._kwargs[LIMIT]: + return + start = self._kwargs[LIMIT] + cache_result = result[start:] + cache_id = iteratorCache.cache(cache_result, NO_CACHE_ID) + self._cache_id_in_use = cache_id + + def __is_res_sufficient(self, res): + return res is not None and len(res) >= self._kwargs[LIMIT] + + def next(self): + cached_res = iteratorCache.fetch_cache(self._cache_id_in_use) + ret = None + if self.__is_res_sufficient(cached_res): + ret = cached_res[0:self._kwargs[LIMIT]] + res_to_cache = cached_res[self._kwargs[LIMIT]:] + iteratorCache.cache(res_to_cache, self._cache_id_in_use) + else: + iteratorCache.release_cache(self._cache_id_in_use) + current_expr = self.__setup_next_expr() + res = self._conn.query(self._collection_name, current_expr, self._output_fields, self._partition_names, + timeout=self._timeout, **self._kwargs) + self.__maybe_cache(res) + ret = res[0:min(self._kwargs[LIMIT], len(res))] + self.__update_cursor(ret) + return ret + + def __setup__pk_is_str(self): + fields = self._schema[FIELDS] + for field in fields: + if field['is_primary']: + if field['type'] == DataType.VARCHAR: + self._pk_str = True + else: + self._pk_str = False + break + + def __setup_next_expr(self): + current_expr = self._expr + if self._next_id is None: + return current_expr + if self._next_id is not None: + if self._pk_str: + current_expr = self._expr + f" and id > \"{self._next_id}\"" + else: + current_expr = self._expr + f" and id > {self._next_id}" + return current_expr + + def __update_cursor(self, res): + if len(res) == 0: + return + self._next_id = res[-1][ID] + + def close(self): + # release cache in use + iteratorCache.release_cache(self._cache_id_in_use) + return + + +class SearchIterator: + + def __init__(self, connection, collection_name, data, ann_field, param, limit, expr=None, partition_names=None, + output_fields=None, timeout=None, round_decimal=-1, schema=None, **kwargs): + if len(data) > 1: + raise MilvusException("Not support multiple vector iterator at present") + self._conn = connection + self._iterator_params = {'collection_name': collection_name, "data": data, + "ann_field": ann_field, "limit": limit, + "output_fields": output_fields, "partition_names": partition_names, + "timeout": timeout, "round_decimal": round_decimal} + self._expr = expr + self._param = param + self._kwargs = kwargs + self._distance_cursor = [0.0] + self._filtered_ids = [] + self._schema = schema + self.__check_radius() + self.__seek() + self.__setup__pk_is_str() + + def __setup__pk_is_str(self): + fields = self._schema[FIELDS] + for field in fields: + if field['is_primary']: + if field['type'] == DataType.VARCHAR: + self._pk_str = True + else: + self._pk_str = False + break + + def __check_radius(self): + if self._param[PARAMS][RADIUS] is None: + raise MilvusException(message="must provide radius parameter when using search iterator") + + def __seek(self): + if self._kwargs.get(OFFSET, 0) != 0: + raise MilvusException("Not support offset when searching iteration") + + def __update_cursor(self, res): + if len(res[0]) == 0: + return + last_hit = res[0][-1] + self._distance_cursor[0] = last_hit.distance + self._filtered_ids = [] + for hit in res[0]: + if hit.distance == last_hit.distance: + self._filtered_ids.append(hit.id) + + def next(self): + next_params = self.__next_params() + next_expr = self.__filtered_duplicated_result_expr(self._expr) + res = self._conn.search(self._iterator_params['collection_name'], + self._iterator_params['data'], + self._iterator_params['ann_field'], + next_params, + self._iterator_params['limit'], + next_expr, + self._iterator_params['partition_names'], + self._iterator_params['output_fields'], + self._iterator_params['round_decimal'], + timeout=self._iterator_params['timeout'], + schema=self._schema, **self._kwargs) + self.__update_cursor(res) + return res + + # at present, the range_filter parameter means 'larger/less and equal', + # so there would be vectors with same distances returned multiple times in different pages + # we need to refine and remove these results before returning + def __filtered_duplicated_result_expr(self, expr): + if len(self._filtered_ids) == 0: + return expr + + filtered_ids_str = "" + for filtered_id in self._filtered_ids: + if self._pk_str: + filtered_ids_str += f"\"{filtered_id}\", " + else: + filtered_ids_str += f"{filtered_id}, " + + filter_expr = f"id not in [{filtered_ids_str}]" + if expr is not None: + return expr + filter_expr + return filter_expr + + def __next_params(self): + next_params = self._param.copy() + next_params[PARAMS][RANGE_FILTER] = self._distance_cursor[0] + return next_params + + def close(self): + pass + + +class IteratorCache: + + def __init__(self): + self._cache_id = 0 + self._cache_map = {} + + def cache(self, result, cache_id): + if cache_id == NO_CACHE_ID: + self._cache_id += 1 + cache_id = self._cache_id + self._cache_map[cache_id] = result + return cache_id + + def fetch_cache(self, cache_id): + return self._cache_map.get(cache_id, None) + + def release_cache(self, cache_id): + if self._cache_map.get(cache_id, None) is not None: + self._cache_map.pop(cache_id) + + +NO_CACHE_ID = -1 +# Singleton Mode in Python +iteratorCache = IteratorCache()