Skip to content

Commit

Permalink
add hybrid_search for MilvusClient (#2259)
Browse files Browse the repository at this point in the history
Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored Sep 10, 2024
1 parent 07052c5 commit c7de801
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

has = utility.has_collection("hello_milvus")
print(f"Does collection hello_milvus exist in Milvus: {has}")
if has:
utility.drop_collection("hello_milvus")

fields = [
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
Expand Down
75 changes: 75 additions & 0 deletions examples/milvus_client/hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np
from pymilvus import (
MilvusClient,
DataType,
AnnSearchRequest, RRFRanker, WeightedRanker,
)

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim = 3000, 8

collection_name = "hello_milvus"
milvus_client = MilvusClient("http://localhost:19530")

has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
milvus_client.drop_collection(collection_name)

schema = milvus_client.create_schema(auto_id=False, description="hello_milvus is the simplest demo to introduce the APIs")
schema.add_field("pk", DataType.VARCHAR, is_primary=True, max_length=100)
schema.add_field("random", DataType.DOUBLE)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
schema.add_field("embeddings2", DataType.FLOAT_VECTOR, dim=dim)

index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name = "embeddings", index_type = "IVF_FLAT", metric_type="L2", nlist=128)
index_params.add_index(field_name = "embeddings2",index_type = "IVF_FLAT", metric_type="L2", nlist=128)

print(fmt.format("Create collection `hello_milvus`"))

milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")

print(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)
entities = [
# provide the pk field because `auto_id` is set to False
[str(i) for i in range(num_entities)],
rng.random(num_entities).tolist(), # field random, only supports list
rng.random((num_entities, dim)), # field embeddings, supports numpy.ndarray and list
rng.random((num_entities, dim)), # field embeddings2, supports numpy.ndarray and list
]

rows = [ {"pk": entities[0][i], "random": entities[1][i], "embeddings": entities[2][i], "embeddings2": entities[3][i]} for i in range (num_entities)]

insert_result = milvus_client.insert(collection_name, rows)


print(fmt.format("Start loading"))
milvus_client.load_collection(collection_name)

field_names = ["embeddings", "embeddings2"]
field_names = ["embeddings"]

req_list = []
nq = 1
default_limit = 5
vectors_to_search = []

for i in range(len(field_names)):
# 4. generate search data
vectors_to_search = rng.random((nq, dim))
search_param = {
"data": vectors_to_search,
"anns_field": field_names[i],
"param": {"metric_type": "L2"},
"limit": default_limit,
"expr": "random > 0.5"}
req = AnnSearchRequest(**search_param)
req_list.append(req)

print("rank by RRFRanker")
hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"])
for hits in hybrid_res:
for hit in hits:
print(f" hybrid search hit: {hit}")
72 changes: 72 additions & 0 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Optional, Union
from uuid import uuid4

from pymilvus.client.abstract import AnnSearchRequest, BaseRanker
from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL
from pymilvus.client.types import (
ExceptionsMessage,
Expand Down Expand Up @@ -282,6 +283,77 @@ def upsert(
}
)

def hybrid_search(
self,
collection_name: str,
reqs: List[AnnSearchRequest],
ranker: BaseRanker,
limit: int = 10,
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
**kwargs,
) -> List[List[dict]]:
"""Conducts multi vector similarity search with a rerank for rearrangement.
Args:
collection_name(``string``): The name of collection.
reqs (``List[AnnSearchRequest]``): The vector search requests.
ranker (``BaseRanker``): The ranker for rearrange nummer of limit results.
limit (``int``): The max number of returned record, also known as `topk`.
partition_names (``List[str]``, optional): The names of partitions to search on.
output_fields (``List[str]``, optional):
The name of fields to return in the search result. Can only get scalar fields.
round_decimal (``int``, optional):
The specified number of decimal places of returned distance.
Defaults to -1 means no round to returned distance.
timeout (``float``, optional): A duration of time in seconds to allow for the RPC.
If timeout is set to None, the client keeps waiting until the server
responds or an error occurs.
**kwargs (``dict``): Optional search params
* *offset* (``int``, optinal)
offset for pagination.
* *consistency_level* (``str/int``, optional)
Which consistency level to use when searching in the collection.
Options of consistency level: Strong, Bounded, Eventually, Session, Customized.
Note: this parameter overwrites the same one specified when creating collection,
if no consistency level was specified, search will use the
consistency level when you create the collection.
Returns:
List[List[dict]]: A nested list of dicts containing the result data.
Raises:
MilvusException: If anything goes wrong
"""

conn = self._get_connection()
try:
res = conn.hybrid_search(
collection_name,
reqs,
ranker,
limit=limit,
partition_names=partition_names,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
except Exception as ex:
logger.error("Failed to hybrid search collection: %s", collection_name)
raise ex from ex

ret = []
for hits in res:
ret.append([hit.to_dict() for hit in hits])

return ExtraList(ret, extra=construct_cost_extra(res.cost))

def search(
self,
collection_name: str,
Expand Down

0 comments on commit c7de801

Please sign in to comment.