Skip to content

Commit

Permalink
support iterator for query and search
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed May 18, 2023
1 parent a5c79ba commit 73751ce
Show file tree
Hide file tree
Showing 6 changed files with 404 additions and 7 deletions.
145 changes: 145 additions & 0 deletions examples/iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import numpy as np
import random
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)

# Connection target for the local Milvus instance.
HOST = "localhost"
PORT = "19530"
COLLECTION_NAME = "test_iterator"
USER_ID = "id"  # primary-key field name (client-supplied VARCHAR)
MAX_LENGTH = 65535  # max length for the VARCHAR primary key
AGE = "age"  # scalar INT64 field name
DEPOSIT = "deposit"  # scalar DOUBLE field name
PICTURE = "picture"  # FLOAT_VECTOR field name
CONSISTENCY_LEVEL = "Eventually"
LIMIT = 5  # page size used by the iterator examples
NUM_ENTITIES = 1000  # rows inserted per batch
DIM = 8  # vector dimension of the PICTURE field
CLEAR_EXIST = False  # when True, drop any pre-existing collection first


def re_create_collection():
    """Optionally drop the demo collection, then (re)create it.

    When ``CLEAR_EXIST`` is set and the collection already exists it is
    dropped first; otherwise the existing collection is reused as-is.

    Returns:
        Collection: handle to the collection named ``COLLECTION_NAME``.
    """
    # Check the cheap local flag first so we skip the server round-trip
    # entirely when clearing is disabled.
    if CLEAR_EXIST and utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
        # Fixed message: was "dropped existed collection{name}" (missing
        # space, wrong tense).
        print(f"Dropped existing collection {COLLECTION_NAME}")

    fields = [
        # Client-assigned string primary key (auto_id disabled).
        FieldSchema(name=USER_ID, dtype=DataType.VARCHAR, is_primary=True,
                    auto_id=False, max_length=MAX_LENGTH),
        FieldSchema(name=AGE, dtype=DataType.INT64),
        FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE),
        FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM),
    ]

    schema = CollectionSchema(fields)
    print(f"Create collection {COLLECTION_NAME}")
    collection = Collection(COLLECTION_NAME, schema, consistency_level=CONSISTENCY_LEVEL)
    return collection


def insert_data(collection):
    """Insert ``batch_count`` batches of generated rows, flushing after each.

    Each batch carries NUM_ENTITIES rows: a random string id, an age in
    [0, 100), a float deposit, and a DIM-dimensional random vector.
    """
    rng = np.random.default_rng(seed=19530)
    num_batches = 5
    for batch in range(num_batches):
        lo, hi = NUM_ENTITIES * batch, NUM_ENTITIES * (batch + 1)
        ids = [str(random.randint(lo, hi)) for _ in range(NUM_ENTITIES)]
        ages = [int(row % 100) for row in range(NUM_ENTITIES)]
        deposits = [float(row) for row in range(NUM_ENTITIES)]
        vectors = rng.random((NUM_ENTITIES, DIM))
        collection.insert([ids, ages, deposits, vectors])
        collection.flush()
        print(f"Finish insert batch{batch}, number of entities in Milvus: {collection.num_entities}")

def prepare_index(collection):
    """Build an IVF_FLAT (L2) index on the vector field and load the collection."""
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128},
    }
    collection.create_index(PICTURE, index_params)
    print("Finish Creating index IVF_FLAT")
    # Loading makes the indexed data available for query/search.
    collection.load()
    print("Finish Loading index IVF_FLAT")


def prepare_data(collection):
    """Populate the collection with demo rows, index it, and return it."""
    insert_data(collection)
    prepare_index(collection)
    return collection


def query_iterate_collection_no_offset(collection):
    """Page through every entity matching the age filter, starting at offset 0."""
    expr = f"10 <= {AGE} <= 14"
    iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
                                         offset=0, limit=5, consistency_level=CONSISTENCY_LEVEL,
                                         iteration_extension_reduce_rate=10)
    page = 0
    while True:
        batch = iterator.next()
        # An empty batch signals the iterator is exhausted.
        if len(batch) == 0:
            print("query iteration finished, close")
            iterator.close()
            break
        for row in batch:
            print(row)
        page += 1
        print(f"page{page}-------------------------")

def query_iterate_collection_with_offset(collection):
    """Same paging as the no-offset variant, but skipping the first 10 hits."""
    expr = f"10 <= {AGE} <= 14"
    iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
                                         offset=10, limit=5, consistency_level=CONSISTENCY_LEVEL,
                                         iteration_extension_reduce_rate=10)
    page = 0
    while True:
        batch = iterator.next()
        # An empty batch signals the iterator is exhausted.
        if len(batch) == 0:
            print("query iteration finished, close")
            iterator.close()
            break
        for row in batch:
            print(row)
        page += 1
        print(f"page{page}-------------------------")

def search_iterator_collection(collection):
    """Page through vector search hits for one query vector via search_iterator."""
    SEARCH_NQ = 1  # number of query vectors
    # Use the module-level DIM; the original re-declared a shadowing local
    # DIM = 8, which would silently diverge if the constant ever changed.
    rng = np.random.default_rng(seed=19530)
    vectors_to_search = rng.random((SEARCH_NQ, DIM))
    search_params = {
        "metric_type": "L2",
        # radius bounds hits to L2 distance <= 1.0 (range search).
        "params": {"nprobe": 10, "radius": 1.0},
    }
    search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, limit=5,
                                                 output_fields=[USER_ID])
    page_idx = 0
    while True:
        res = search_iterator.next()
        # Results are nested per query vector; nq == 1 here, so res[0].
        if len(res[0]) == 0:
            # Fixed message: this is the search iterator, not the query one.
            print("search iteration finished, close")
            search_iterator.close()
            break
        for i in range(len(res[0])):
            print(res[0][i])
        page_idx += 1
        print(f"page{page_idx}-------------------------")


def main():
    """Connect to Milvus, prepare demo data, then exercise both iterator APIs."""
    connections.connect("default", host=HOST, port=PORT)
    collection = prepare_data(re_create_collection())
    query_iterate_collection_no_offset(collection)
    query_iterate_collection_with_offset(collection)
    search_iterator_collection(collection)


# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
1 change: 1 addition & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
BOUNDED_TS = 2
DEFAULT_CONSISTENCY_LEVEL = ConsistencyLevel.Bounded
DEFAULT_RESOURCE_GROUP = "__default_resource_group"
# Query-param key that lets the server extend the iterator's effective batch
# size; forwarded verbatim as a KeyValuePair in query requests.
ITERATION_EXTENSION_REDUCE_RATE = "iteration_extension_reduce_rate"
6 changes: 5 additions & 1 deletion pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .check import check_pass_param, is_legal_collection_properties
from .types import DataType, PlaceholderType, get_consistency_level
from .utils import traverse_info
from .constants import DEFAULT_CONSISTENCY_LEVEL
from .constants import DEFAULT_CONSISTENCY_LEVEL, ITERATION_EXTENSION_REDUCE_RATE
from ..exceptions import ParamError, DataNotMatchException, ExceptionsMessage
from ..orm.schema import CollectionSchema

Expand Down Expand Up @@ -613,6 +613,10 @@ def query_request(cls, collection_name, expr, output_fields, partition_names, **

ignore_growing = kwargs.get("ignore_growing", False)
req.query_params.append(common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing)))

use_iteration_extension_reduce_rate = kwargs.get(ITERATION_EXTENSION_REDUCE_RATE, 0)
req.query_params.append(common_types.KeyValuePair(key=ITERATION_EXTENSION_REDUCE_RATE,
value=str(use_iteration_extension_reduce_rate)))
return req

@classmethod
Expand Down
29 changes: 24 additions & 5 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
from ..settings import Config
from ..client.types import CompactionState, CompactionPlans, Replica, get_consistency_level, cmp_consistency_level
from ..client.constants import DEFAULT_CONSISTENCY_LEVEL

from .iterator import QueryIterator, SearchIterator


class Collection:
def __init__(self, name: str, schema: CollectionSchema=None, using: str="default", **kwargs):
def __init__(self, name: str, schema: CollectionSchema = None, using: str = "default", **kwargs):
""" Constructs a collection by name, schema and other parameters.
Args:
Expand Down Expand Up @@ -394,7 +394,8 @@ def release(self, timeout=None, **kwargs):
conn = self._get_connection()
conn.release_collection(self._name, timeout=timeout, **kwargs)

def insert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult:
def insert(self, data: Union[List, pandas.DataFrame], partition_name: str = None, timeout=None,
**kwargs) -> MutationResult:
""" Insert data into the collection.
Args:
Expand Down Expand Up @@ -483,7 +484,8 @@ def delete(self, expr, partition_name=None, timeout=None, **kwargs):
return MutationFuture(res)
return MutationResult(res)

def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult:
def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None,
**kwargs) -> MutationResult:
""" Upsert data into the collection.
Args:
Expand Down Expand Up @@ -523,7 +525,7 @@ def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None,

conn = self._get_connection()
res = conn.upsert(self._name, entities, partition_name,
timeout=timeout, schema=self._schema_dict, **kwargs)
timeout=timeout, schema=self._schema_dict, **kwargs)

if kwargs.get("_async", False):
return MutationFuture(res)
Expand Down Expand Up @@ -670,6 +672,15 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
return SearchFuture(res)
return SearchResult(res)

def search_iterator(self, data, anns_field, param, limit, expr=None, partition_names=None,
                    output_fields=None, timeout=None, round_decimal=-1, **kwargs):
    """Return a SearchIterator that pages through vector search results.

    Parameters mirror ``search``. ``expr``, when provided, must be a boolean
    expression string; any other type raises DataTypeNotMatchException.
    """
    if expr is not None and not isinstance(expr, str):
        raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
    conn = self._get_connection()
    # The iterator drives paged search calls itself via the raw connection.
    iterator = SearchIterator(conn, self._name, data, anns_field, param, limit, expr, partition_names,
                              output_fields, timeout, round_decimal, schema=self._schema_dict, **kwargs)
    return iterator

def query(self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs):
""" Query with expressions
Expand Down Expand Up @@ -751,6 +762,14 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, **
timeout=timeout, schema=self._schema_dict, **kwargs)
return res

def query_iterator(self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs):
    """Return a QueryIterator that pages through entities matching ``expr``.

    ``expr`` must be a boolean expression string; any other type raises
    DataTypeNotMatchException. Remaining parameters mirror ``query``.
    """
    if not isinstance(expr, str):
        raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
    conn = self._get_connection()
    # The iterator drives paged query calls itself via the raw connection.
    iterator = QueryIterator(conn, self._name, expr, output_fields, partition_names,
                             timeout=timeout, schema=self._schema_dict, **kwargs)
    return iterator

@property
def partitions(self, **kwargs) -> List[Partition]:
""" List[Partition]: List of Partition object.
Expand Down
12 changes: 11 additions & 1 deletion pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

COMMON_TYPE_PARAMS = ("dim", "max_length")


CALC_DIST_IDS = "ids"
CALC_DIST_FLOAT_VEC = "float_vectors"
CALC_DIST_BIN_VEC = "bin_vectors"
Expand All @@ -23,3 +22,14 @@
CALC_DIST_TANIMOTO = "TANIMOTO"
CALC_DIST_SQRT = "sqrt"
CALC_DIST_DIM = "dim"

# Keys shared by the iterator implementation for request parameters and for
# fields of returned result entities.
OFFSET = "offset"
LIMIT = "limit"
ID = "id"
METRIC_TYPE = "metric_type"
PARAMS = "params"
DISTANCE = "distance"
RADIUS = "radius"
RANGE_FILTER = "range_filter"
FIELDS = "fields"
ITERATION_EXTENSION_REDUCE_RATE = "iteration_extension_reduce_rate"
Loading

0 comments on commit 73751ce

Please sign in to comment.