Skip to content

Commit

Permalink
fix conflict for ef and batch size
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed Sep 22, 2023
1 parent bac3195 commit dbfba67
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
RANGE_FILTER = "range_filter"
FIELDS = "fields"
ITERATION_EXTENSION_REDUCE_RATE = "iteration_extension_reduce_rate"
EF = "ef"
DEFAULT_MAX_L2_DISTANCE = 99999999.0
DEFAULT_MIN_IP_DISTANCE = -99999999.0
DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0
Expand Down
19 changes: 17 additions & 2 deletions pymilvus/orm/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CALC_DIST_L2,
CALC_DIST_TANIMOTO,
DEFAULT_SEARCH_EXTENSION_RATE,
EF,
FIELDS,
INT64_MAX,
ITERATION_EXTENSION_REDUCE_RATE,
Expand All @@ -39,7 +40,11 @@
SearchIterator = TypeVar("SearchIterator")


def extend_batch_size(batch_size: int) -> int:
def extend_batch_size(batch_size: int, next_param: dict) -> int:
if EF in next_param[PARAMS]:
return min(
MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE, next_param[PARAMS][EF]
)
return min(MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE)


Expand Down Expand Up @@ -294,6 +299,7 @@ def __init__(
}
self._expr = expr
self.__check_set_params(param)
self.__check_for_special_index_param()
self._kwargs = kwargs
self._filtered_ids = []
self._filtered_distance = None
Expand Down Expand Up @@ -337,6 +343,15 @@ def __check_set_params(self, param: Dict):
if PARAMS not in self._param:
self._param[PARAMS] = {}

def __check_for_special_index_param(self):
if (
EF in self._param[PARAMS]
and self._param[PARAMS][EF] < self._iterator_params[BATCH_SIZE]
):
raise MilvusException(
message="When using hnsw index, provided ef must be larger than or equal to batch size"
)

def __setup__pk_prop(self):
fields = self._schema[FIELDS]
for field in fields:
Expand Down Expand Up @@ -472,7 +487,7 @@ def __execute_next_search(self, next_params: dict, next_expr: str) -> SearchPage
self._iterator_params["data"],
self._iterator_params["ann_field"],
next_params,
extend_batch_size(self._iterator_params[BATCH_SIZE]),
extend_batch_size(self._iterator_params[BATCH_SIZE], next_params),
next_expr,
self._iterator_params["partition_names"],
self._iterator_params["output_fields"],
Expand Down

0 comments on commit dbfba67

Please sign in to comment.