diff --git a/pymilvus/milvus_client/__init__.py b/pymilvus/milvus_client/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pymilvus/milvus_client/defaults.py b/pymilvus/milvus_client/defaults.py
new file mode 100644
index 000000000..446e938d6
--- /dev/null
+++ b/pymilvus/milvus_client/defaults.py
@@ -0,0 +1,12 @@
+"""Default MilvusClient args."""
+
+DEFAULT_SEARCH_PARAMS = {
+    "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
+    "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
+    "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
+    "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
+    "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
+    "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
+    "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
+    "AUTOINDEX": {"metric_type": "L2", "params": {}},
+}
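These defaults are keyed by index type and are applied per query unless the caller overrides them. A minimal sketch of overriding them at search time via the `search_params` argument of `search_data` (defined later in this diff); the collection and field names are hypothetical and a local Milvus server is assumed:

```python
# Hypothetical usage sketch: per-query override of the defaults above.
# Assumes a local Milvus server; "demo", "embedding", and "tag" are made-up names.
from pymilvus.milvus_client.milvus_client import MilvusClient

client = MilvusClient(collection_name="demo", vector_field="embedding")
client.insert_data([{"embedding": [0.1, 0.2, 0.3], "tag": "a"}])
results = client.search_data(
    [0.1, 0.2, 0.3],
    top_k=5,
    # A larger ef than the HNSW default {"ef": 10} trades speed for recall.
    search_params={"metric_type": "L2", "params": {"ef": 64}},
)
```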
diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py
new file mode 100644
index 000000000..106b550b0
--- /dev/null
+++ b/pymilvus/milvus_client/milvus_client.py
@@ -0,0 +1,869 @@
+"""MilvusClient for dealing with simple workflows."""
+import logging
+import threading
+from typing import Any, Union, List, Dict
+from uuid import uuid4
+from tqdm import tqdm
+
+from pymilvus.exceptions import MilvusException
+from pymilvus.milvus_client.defaults import DEFAULT_SEARCH_PARAMS
+from pymilvus.orm import utility
+from pymilvus.orm.collection import Collection, CollectionSchema, FieldSchema
+from pymilvus.orm.connections import connections
+from pymilvus.orm.types import DataType, infer_dtype_bydata
+
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+
+class MilvusClient:
+    """The Milvus Client."""
+
+    # pylint: disable=logging-too-many-args, too-many-instance-attributes
+
+    def __init__(
+        self,
+        collection_name: str = "ClientCollection",
+        pk_field: str = None,
+        vector_field: str = None,
+        uri: str = None,
+        shard_num: int = None,
+        partitions: List[str] = None,
+        consistency_level: str = "Bounded",
+        replica_number: int = 1,
+        index_params: dict = None,
+        timeout: int = None,
+        overwrite: bool = False,
+    ):
+        """A client for the common Milvus use case.
+
+        This client attempts to hide away the complexity of using Pymilvus. In a lot of
+        cases what the user wants is a simple wrapper that supports adding data, deleting
+        data, and searching. This wrapper can infer the schema from a previous collection
+        or from newly inserted data, can update the partitions, can query, and can delete
+        by pk.
+
+        Args:
+            pk_field (str, optional): Which entry in data is considered the primary key.
+                If None, an auto-id will be created. Will be overwritten if loading from a
+                previous collection. Defaults to None.
+            vector_field (str, optional): Which entry in the data is considered the vector
+                field. Will be overwritten if loading from a previous collection. Required
+                if not using an already-made collection.
+            uri (str, optional): The connection address to use to connect to the
+                instance. Defaults to "http://localhost:19530". Another example:
+                "https://username:password@in01-12a.aws-us-west-2.vectordb.zillizcloud.com:19538"
+            shard_num (int, optional): The number of shards to use for the collection.
+                Unless dealing with huge scale, it is recommended to keep the default.
+                Defaults to None, which lets the server decide.
+            partitions (List[str], optional): Which partitions to create for the
+                collection. Defaults to None.
+            consistency_level (str, optional): Which consistency level to use for the
+                client. The options are "Strong", "Bounded", "Eventually", "Session".
+                Defaults to "Bounded".
+            replica_number (int, optional): The number of in-memory replicas to use.
+                Defaults to 1.
+            index_params (dict, optional): Which index parameters to use for the
+                collection. If None, a default index will be used. If the collection
+                already exists, its index will be overwritten with this one.
+            timeout (int, optional): What timeout to use for function calls. Defaults
+                to None.
+            overwrite (bool, optional): Whether to overwrite the collection if it already
+                exists. Defaults to False.
+        """
+        self.uri = uri
+        self.collection_name = collection_name
+        self.shard_num = shard_num
+        self.partitions = partitions
+        self.consistency_level = consistency_level
+        self.replica_number = replica_number
+        self.index_params = index_params
+        self.timeout = timeout
+        self.pk_field = pk_field
+        self.vector_field = vector_field
+
+        # TODO: Figure out thread safety
+        # self.concurrent_counter = 0
+        self.concurrent_lock = threading.RLock()
+        self.default_search_params = None
+        self.collection = None
+        self.fields = None
+
+        self.alias = self._create_connection()
+        self.is_self_hosted = bool(
+            utility.get_server_type(using=self.alias) == "milvus"
+        )
+        if overwrite and utility.has_collection(self.collection_name, using=self.alias):
+            utility.drop_collection(self.collection_name, using=self.alias)
+
+        self._init(None)
+
+    def __len__(self):
+        return self.num_entities()
+
+    def num_entities(self):
+        """Return the number of rows in the collection.
+
+        Returns:
+            int: Number of rows.
+        """
+        if self.collection is None:
+            return 0
+
+        self.collection.flush()
+        return self.collection.num_entities
+
+    def insert_data(
+        self,
+        data: List[Dict[str, Any]],
+        timeout: int = None,
+        batch_size: int = 100,
+        partition: str = None,
+        progress_bar: bool = False,
+    ) -> List[Union[str, int]]:
+        """Insert data into the collection.
+
+        If the MilvusClient was initiated without an existing collection, the first dict
+        passed in will be used to initiate the collection.
+
+        Args:
+            data (List[Dict[str, Any]]): A list of dicts to insert.
+            timeout (int, optional): The timeout to use, will override the init timeout.
+                Defaults to None.
+            batch_size (int, optional): The batch size to perform inserts with. Defaults
+                to 100.
+            partition (str, optional): Which partition to insert into. Defaults to None.
+            progress_bar (bool, optional): Whether to display a progress bar for the
+                insert. Defaults to False.
+
+        Raises:
+            KeyError: If the data has missing fields an exception will be thrown.
+            MilvusException: General Milvus error on insert.
+
+        Returns:
+            List[Union[str, int]]: A list of primary keys that were inserted.
+        """
+        # If no data provided, we cannot insert anything
+        if len(data) == 0:
+            return []
+
+        if batch_size < 1:
+            logger.error("Invalid batch size provided for insert.")
+            raise ValueError("Invalid batch size provided for insert.")
+
+        # If the collection hasn't been initialized, initialize it
+        with self.concurrent_lock:
+            if self.collection is None:
+                self._init(data[0])
+
+        # Don't include the primary key if auto_id is true and it was included in the data
+        ignore_pk = self.pk_field if self.collection.schema.auto_id else None
+        insert_dict = {}
+        pks = []
+
+        for entry in data:
+            for key, value in entry.items():
+                if key in self.fields:
+                    insert_dict.setdefault(key, []).append(value)
+
+        # Insert the data in batches
+        for i in tqdm(range(0, len(data), batch_size), disable=not progress_bar):
+            # Convert dict to list of lists batch for insertion
+            try:
+                insert_batch = [
+                    insert_dict[key][i : i + batch_size]
+                    for key in self.fields
+                    if key != ignore_pk
+                ]
+            except KeyError as ex:
+                logger.error(
+                    "Malformed data, at least one of the inserts does not contain all"
+                    " the required fields."
+                )
+                raise KeyError(
+                    f"Malformed data, at least one of the inserts does not contain all"
+                    f" the required fields: {ex}",
+                ) from ex
+            # Insert into the collection.
+            try:
+                res = self.collection.insert(
+                    insert_batch,
+                    timeout=timeout or self.timeout,
+                    partition_name=partition,
+                )
+                pks.extend(res.primary_keys)
+            except MilvusException as ex:
+                logger.error(
+                    "Failed to insert batch starting at entity: %s/%s",
+                    str(i),
+                    str(len(data)),
+                )
+                raise ex
+        return pks
+
+    def upsert_data(
+        self,
+        data: List[Dict[str, Any]],
+        timeout: int = None,
+        batch_size: int = 100,
+        partition: str = None,
+        progress_bar: bool = False,
+    ) -> List[Union[str, int]]:
+        """WARNING: SLOW AND NOT ATOMIC. Will be updated for the 2.3 release.
+
+        Upsert the data into the collection.
+
+        If the MilvusClient was initiated without an existing collection, the first dict
+        passed in will be used to initiate the collection.
+
+        Args:
+            data (List[Dict[str, Any]]): A list of dicts to upsert.
+            timeout (int, optional): The timeout to use, will override the init timeout.
+                Defaults to None.
+            batch_size (int, optional): The batch size to perform inserts with. Defaults
+                to 100.
+            partition (str, optional): Which partition to insert into. Defaults to None.
+            progress_bar (bool, optional): Whether to display a progress bar for the
+                insert. Defaults to False.
+
+        Returns:
+            List[Union[str, int]]: A list of primary keys that were upserted.
+        """
+        # If the collection exists we need to first delete the values
+        if self.collection is not None:
+            pks = [x[self.pk_field] for x in data]
+            self.delete_by_pk(pks, timeout)
+
+        ret = self.insert_data(
+            data=data,
+            timeout=timeout,
+            batch_size=batch_size,
+            partition=partition,
+            progress_bar=progress_bar,
+        )
+
+        return ret
+
+    def search_data(
+        self,
+        data: Union[List[list], list],
+        top_k: int = 10,
+        filter_expression: str = None,
+        return_fields: List[str] = None,
+        partitions: List[str] = None,
+        search_params: dict = None,
+        timeout: int = None,
+    ) -> List[dict]:
+        """Search for a query vector/vectors.
+
+        In order for the search to process, a collection needs to have been either
+        provided at init or data needs to have been inserted.
+
+        Args:
+            data (Union[List[list], list]): The vector/vectors to search.
+            top_k (int, optional): How many results to return per search. Defaults to 10.
+            filter_expression (str, optional): A filter to use for the search. Defaults
+                to None.
+            return_fields (List[str], optional): List of which field values to return. If
+                None specified, all fields excluding the vector field will be returned.
+            partitions (List[str], optional): Which partitions to search within. Defaults
+                to searching through all.
+            search_params (dict, optional): The search params to use for the search.
+                Defaults to the default set for the client.
+            timeout (int, optional): Timeout to use, overrides the client level timeout
+                assigned at init. Defaults to None.
+
+        Raises:
+            ValueError: The collection being searched doesn't exist. Need to insert data
+                first.
+
+        Returns:
+            List[dict]: A list of dicts containing the score and the result data.
+                Embeddings are not included in the result data.
+        """
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter += 1
+
+        if self.collection is None:
+            logger.error("Collection does not exist: %s", self.collection_name)
+            raise ValueError(
+                "Missing collection. Make sure data inserted or initialized on existing collection."
+            )
+
+        if not isinstance(data[0], list):
+            data = [data]
+        if return_fields is None or len(return_fields) == 0:
+            return_fields = list(self.fields.keys())
+            return_fields.remove(self.vector_field)
+
+        try:
+            res = self.collection.search(
+                data,
+                anns_field=self.vector_field,
+                expr=filter_expression,
+                param=search_params or self.default_search_params,
+                limit=top_k,
+                partition_names=partitions,
+                output_fields=return_fields,
+                timeout=timeout or self.timeout,
+            )
+        except Exception as ex:
+            logger.error("Failed to search collection: %s", self.collection_name)
+            raise ex
+
+        ret = []
+        for hits in res:
+            query_result = []
+            for hit in hits:
+                ret_dict = {x: hit.entity.get(x) for x in return_fields}
+                query_result.append({"score": hit.score, "data": ret_dict})
+            ret.append(query_result)
+
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter -= 1
+        return ret
+
+    def query_data(
+        self,
+        filter_expression: str,
+        return_fields: List[str] = None,
+        partitions: List[str] = None,
+        timeout: int = None,
+    ) -> List[dict]:
+        """Query for entries in the collection.
+
+        Args:
+            filter_expression (str): The filter to use for the query.
+            return_fields (List[str], optional): List of which field values to return. If
+                None specified, all fields excluding the vector field will be returned.
+            partitions (List[str], optional): Which partitions to perform the query in.
+                Defaults to None.
+            timeout (int, optional): Timeout to use, overrides the client level timeout
+                assigned at init. Defaults to None.
+
+        Raises:
+            ValueError: Missing collection.
+
+        Returns:
+            List[dict]: A list of result dicts, vectors are not included.
+        """
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter += 1
+
+        if self.collection is None:
+            logger.error("Collection does not exist: %s", self.collection_name)
+            raise ValueError(
+                "Missing collection. Make sure data inserted or initialized on existing collection."
+            )
+
+        if return_fields is None or len(return_fields) == 0:
+            return_fields = list(self.fields.keys())
+            return_fields.remove(self.vector_field)
+
+        res = self.collection.query(
+            expr=filter_expression,
+            partition_names=partitions,
+            output_fields=return_fields,
+            timeout=timeout or self.timeout,
+        )
+
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter -= 1
+
+        return res
+
+    def get_vectors_by_pk(
+        self,
+        pks: Union[list, str, int],
+        timeout: int = None,
+    ) -> List[dict]:
+        """Grab the inserted vectors using the primary key from the collection.
+
+        Due to current implementations, grabbing a large amount of vectors is slow.
+
+        Args:
+            pks (list, str, int): The pk's to get vectors for. Depending on the pk_field
+                type it can be int or str or a list of either.
+            timeout (int, optional): Timeout to use, overrides the client level timeout
+                assigned at init. Defaults to None.
+
+        Raises:
+            ValueError: Missing collection.
+
+        Returns:
+            List[dict]: A list of result dicts with keys {pk_field, vector_field}
+        """
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter += 1
+
+        if self.collection is None:
+            logger.error("Collection does not exist: %s", self.collection_name)
+            raise ValueError(
+                "Missing collection. Make sure data inserted or initialized on existing collection."
+            )
+
+        if not isinstance(pks, list):
+            pks = [pks]
+
+        if len(pks) == 0:
+            return []
+
+        # Varchar pks need double quotes around the values
+        if self.fields[self.pk_field] == DataType.VARCHAR:
+            ids = ['"' + str(entry) + '"' for entry in pks]
+            expr = f"""{self.pk_field} in [{','.join(ids)}]"""
+        else:
+            ids = [str(entry) for entry in pks]
+            expr = f"{self.pk_field} in [{','.join(ids)}]"
+
+        res = self.collection.query(
+            expr=expr,
+            output_fields=[self.vector_field],
+            timeout=timeout or self.timeout,
+        )
+
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter -= 1
+
+        return res
+
+    def delete_by_pk(
+        self,
+        pks: Union[list, str, int],
+        timeout: int = None,
+    ) -> None:
+        """Delete entries in the collection by their pk.
+
+        Delete all the entries based on the pk. If unsure of the pk you can first query
+        the collection to grab the corresponding data. Then you can delete using the
+        pk_field.
+
+        Args:
+            pks (list, str, int): The pk's to delete. Depending on the pk_field type it
+                can be int or str or a list of either.
+            timeout (int, optional): Timeout to use, overrides the client level timeout
+                assigned at init. Defaults to None.
+        """
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter += 1
+
+        if self.collection is None:
+            logger.error("Collection does not exist: %s", self.collection_name)
+            return
+
+        if not isinstance(pks, list):
+            pks = [pks]
+
+        if len(pks) == 0:
+            return
+
+        # Varchar pks need double quotes around the values
+        if self.fields[self.pk_field] == DataType.VARCHAR:
+            ids = ['"' + str(entry) + '"' for entry in pks]
+            expr = f"""{self.pk_field} in [{','.join(ids)}]"""
+        else:
+            ids = [str(entry) for entry in pks]
+            expr = f"{self.pk_field} in [{','.join(ids)}]"
+
+        self.collection.delete(expr=expr, timeout=timeout or self.timeout)
+
+        # TODO: Figure out thread safety
+        # with self.concurrent_lock:
+        #     self.concurrent_counter -= 1
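Both `get_vectors_by_pk` and `delete_by_pk` funnel the pks into a boolean `in` expression, quoting the values only when the pk field is a VARCHAR. A standalone sketch of the strings this produces (no server needed; `pk_expr` is a hypothetical helper mirroring the inline logic above):

```python
# Standalone mirror of the pk -> filter-expression logic used above.
def pk_expr(pk_field: str, pks: list, is_varchar: bool) -> str:
    # Varchar pks need double quotes around the values.
    if is_varchar:
        ids = ['"' + str(entry) + '"' for entry in pks]
    else:
        ids = [str(entry) for entry in pks]
    return f"{pk_field} in [{','.join(ids)}]"

assert pk_expr("id", [1, 2], is_varchar=False) == "id in [1,2]"
assert pk_expr("name", ["a", "b"], is_varchar=True) == 'name in ["a","b"]'
```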
+
+    def add_partitions(self, input_partitions: List[str]):
+        """Add partitions to the collection.
+
+        Add a list of partition names to the collection. If the collection is loaded
+        it will first be unloaded, then the partitions will be added, and then reloaded.
+
+        Args:
+            input_partitions (List[str]): The list of partition names to be added.
+
+        Raises:
+            MilvusException: Unable to add the partition.
+        """
+        if self.collection is not None and self.is_self_hosted:
+            # Calculate which partitions need to be added
+            input_partitions = set(input_partitions)
+            current_partitions = {
+                partition.name for partition in self.collection.partitions
+            }
+            new_partitions = input_partitions.difference(current_partitions)
+            # If partitions need to be added, add them
+            if len(new_partitions) != 0:
+                # TODO: Remove with Milvus 2.3
+                # Try to unload the collection
+                self.collection.release()
+                try:
+                    for part in new_partitions:
+                        self.collection.create_partition(part)
+                    logger.debug(
+                        "Successfully added partitions to collection: %s partitions: %s",
+                        self.collection_name,
+                        ",".join(part for part in list(new_partitions)),
+                    )
+                    # TODO: Remove with Milvus 2.3
+                    self._load()
+                except MilvusException as ex:
+                    logger.debug(
+                        "Failed to add partitions to: %s", self.collection_name
+                    )
+                    # TODO: Remove with Milvus 2.3
+                    # Even if the add failed, attempt to reload the collection
+                    self._load()
+                    raise ex
+            else:
+                logger.debug(
+                    "No partitions to add for collection: %s", self.collection_name
+                )
+        else:
+            logger.debug(
+                "Collection either on Zilliz or nonexistent for collection: %s",
+                self.collection_name,
+            )
+
+    def delete_partitions(self, remove_partitions: List[str]):
+        """Remove partitions from the collection.
+
+        Remove a list of partition names from the collection. If the collection is loaded
+        it will first be unloaded, then the partitions will be removed, and then reloaded.
+
+        Args:
+            remove_partitions (List[str]): The list of partition names to be removed.
+
+        Raises:
+            MilvusException: Unable to remove the partition.
+        """
+        if self.collection is not None and self.is_self_hosted:
+            # Calculate which partitions need to be removed
+            remove_partitions = set(remove_partitions)
+            current_partitions = {
+                partition.name for partition in self.collection.partitions
+            }
+            removal_partitions = remove_partitions.intersection(current_partitions)
+            # If partitions need to be removed, remove them
+            if len(removal_partitions) != 0:
+                # TODO: Remove with Milvus 2.3
+                # Try to unload the collection
+                self.collection.release()
+                try:
+                    for part in removal_partitions:
+                        self.collection.drop_partition(part)
+                    logger.debug(
+                        "Successfully deleted partitions from collection: %s partitions: %s",
+                        self.collection_name,
+                        ",".join(part for part in list(removal_partitions)),
+                    )
+                    # TODO: Remove with Milvus 2.3
+                    self._load()
+                except MilvusException as ex:
+                    logger.debug(
+                        "Failed to delete partitions from: %s", self.collection_name
+                    )
+                    # TODO: Remove with Milvus 2.3
+                    # Even if the delete failed, attempt to reload the collection
+                    self._load()
+                    raise ex
+            else:
+                logger.debug(
+                    "No partitions to delete for collection: %s",
+                    self.collection_name,
+                )
+
+    def delete_collection(self):
+        """Delete the collection stored in this object."""
+        with self.concurrent_lock:
+            if self.collection is None:
+                return
+            self.collection.drop()
+            self.collection = None
+
+    def close(self, delete_collection=False):
+        """Close the connection, optionally dropping the collection first."""
+        if delete_collection:
+            self.delete_collection()
+        connections.disconnect(self.alias)
+
+    def _create_connection(self) -> str:
+        """Create the connection to the Milvus server."""
+        # TODO: Implement reuse with new uri style
+        alias = uuid4().hex
+        try:
+            connections.connect(alias=alias, uri=self.uri)
+            logger.debug("Created new connection using: %s", alias)
+            return alias
+        except MilvusException as ex:
+            logger.error("Failed to create new connection using: %s", alias)
+            raise ex
+
+    def _init(self, input_data: dict):
+        """Create or connect to the collection."""
+        # If no input data and the collection exists, use that
+        if input_data is None and utility.has_collection(
+            self.collection_name, using=self.alias
+        ):
+            self.collection = Collection(self.collection_name, using=self.alias)
+            # Grab the field information from the existing collection
+            self._extract_fields()
+        # If data is supplied we can create a new collection
+        elif input_data is not None:
+            self._create_collection(input_data)
+        # Nothing to init from
+        else:
+            logger.debug(
+                "No information to perform init from for collection %s",
+                self.collection_name,
+            )
+            return
+
+        # TODO: Make sure this drops the correct index
+        if self.index_params is not None:
+            self.collection.drop_index()
+
+        self._create_index()
+        # Partitions only allowed on Milvus at the moment
+        if self.is_self_hosted and self.partitions is not None:
+            self.add_partitions(self.partitions)
+        self._create_default_search_params()
+        self._load()
+
+    def _create_collection(self, data: dict) -> None:
+        """Create the collection by inferring the schema from the data."""
+        fields = self._infer_fields(data)
+
+        if self.vector_field is None:
+            logger.error(
+                "vector_field not supplied at init(), cannot infer schema from data collection: %s",
+                self.collection_name,
+            )
+            raise ValueError(
+                "vector_field not supplied at init(), cannot infer schema."
+            )
+
+        if self.vector_field not in fields:
+            logger.error(
+                "Missing vector_field: %s in data for collection: %s",
+                self.vector_field,
+                self.collection_name,
+            )
+            raise ValueError(
+                "vector_field missing in inserted data, cannot infer schema."
+            )
+
+        if fields[self.vector_field]["dtype"] not in (
+            DataType.BINARY_VECTOR,
+            DataType.FLOAT_VECTOR,
+        ):
+            logger.error(
+                "vector_field: %s does not correspond with vector dtype in data for collection: %s",
+                self.vector_field,
+                self.collection_name,
+            )
+            raise ValueError("vector_field does not correspond to vector dtype.")
+
+        if fields[self.vector_field]["dtype"] == DataType.BINARY_VECTOR:
+            dim = 8 * len(data[self.vector_field])
+        elif fields[self.vector_field]["dtype"] == DataType.FLOAT_VECTOR:
+            dim = len(data[self.vector_field])
+        # Attach dim kwarg to vector field
+        fields[self.vector_field]["dim"] = dim
+
+        # If pk not provided, create an auto-id pk
+        if self.pk_field is None:
+            # Generate a unique auto-id field
+            self.pk_field = "internal_pk_" + uuid4().hex[:4]
+            # Create a new field for the pk
+            fields[self.pk_field] = {}
+            fields[self.pk_field]["name"] = self.pk_field
+            fields[self.pk_field]["dtype"] = DataType.INT64
+            fields[self.pk_field]["auto_id"] = True
+            fields[self.pk_field]["is_primary"] = True
+            logger.debug(
+                "Missing pk_field, creating auto-id pk for collection: %s",
+                self.collection_name,
+            )
+        # If pk_field given, we assume it will be provided for all inputs
+        else:
+            try:
+                fields[self.pk_field]["auto_id"] = False
+                fields[self.pk_field]["is_primary"] = True
+            except KeyError as ex:
+                logger.error(
+                    "Missing pk_field: %s in data for collection: %s",
+                    self.pk_field,
+                    self.collection_name,
+                )
+                raise ex
+        try:
+            # Create the field schemas
+            fieldschemas = []
+            # TODO: Assuming ordered dicts for 3.7
+            self.fields = {}
+            for field_dict in fields.values():
+                fieldschemas.append(FieldSchema(**field_dict))
+                self.fields[field_dict["name"]] = field_dict["dtype"]
+            # Create the schema for the collection
+            schema = CollectionSchema(fieldschemas)
+            # Create the collection
+            self.collection = Collection(
+                name=self.collection_name,
+                schema=schema,
+                consistency_level=self.consistency_level,
+                shards_num=self.shard_num,
+                using=self.alias,
+            )
+            logger.debug("Successfully created collection: %s", self.collection_name)
+        except MilvusException as ex:
+            logger.error("Failed to create collection: %s", self.collection_name)
+            raise ex
+
+    def _infer_fields(self, data):
+        """Infer all the fields based on the input data."""
+        # TODO: Assuming ordered dicts for 3.7
+        fields = {}
+        # Figure out each datatype of the input.
+        for key, value in data.items():
+            # Infer the corresponding datatype of the metadata
+            dtype = infer_dtype_bydata(value)
+            # Datatype isn't compatible
+            if dtype in (DataType.UNKNOWN, DataType.NONE):
+                logger.error(
+                    "Failed to parse schema for collection %s, unrecognized dtype for key: %s",
+                    self.collection_name,
+                    key,
+                )
+                raise ValueError(f"Unrecognized datatype for {key}.")
+
+            # Create an entry under the field name
+            fields[key] = {}
+            fields[key]["name"] = key
+            fields[key]["dtype"] = dtype
+
+            # Area for attaching kwargs for certain datatypes
+            if dtype == DataType.VARCHAR:
+                fields[key]["max_length"] = 65_535
+
+        return fields
+
+    def _extract_fields(self) -> None:
+        """Grab the existing fields from the collection."""
+        self.fields = {}
+        schema = self.collection.schema
+        for field in schema.fields:
+            field_dict = field.to_dict()
+            if field_dict.get("is_primary", None) is not None:
+                logger.debug("Updating pk_field with one from collection.")
+                self.pk_field = field_dict["name"]
+            if field_dict["type"] in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
+                logger.debug("Updating vector_field with one from collection.")
+                self.vector_field = field_dict["name"]
+            self.fields[field_dict["name"]] = field_dict["type"]
+
+        logger.debug(
+            "Successfully extracted fields from collection: %s, total fields: %s, "
+            "pk_field: %s, vector_field: %s",
+            self.collection_name,
+            len(self.fields),
+            self.pk_field,
+            self.vector_field,
+        )
+
+    def _create_index(self) -> None:
+        """Create an index on the collection."""
+        if self._get_index() is None:
+            # If no index params, use a default HNSW-based one
+            if self.index_params is None:
+                # TODO: Once segment normalization lands we can default to IP
+                metric_type = (
+                    "L2"
+                    if self.fields[self.vector_field] == DataType.FLOAT_VECTOR
+                    else "JACCARD"
+                )
+                # TODO: Once the AUTOINDEX type is supported by Milvus we can default
+                # to AUTOINDEX always
+                index_type = "HNSW" if self.is_self_hosted else "AUTOINDEX"
+                params = {"M": 8, "efConstruction": 64} if self.is_self_hosted else {}
+                self.index_params = {
+                    "metric_type": metric_type,
+                    "index_type": index_type,
+                    "params": params,
+                }
+            try:
+                self.collection.create_index(
+                    self.vector_field,
+                    index_params=self.index_params,
+                    using=self.alias,
+                    timeout=self.timeout,
+                )
+                logger.debug(
+                    "Successfully created an index on collection: %s",
+                    self.collection_name,
+                )
+            except MilvusException as ex:
+                logger.error(
+                    "Failed to create an index on collection: %s", self.collection_name
+                )
+                raise ex
+        else:
+            logger.debug(
+                "Index already exists for collection: %s", self.collection_name
+            )
+
+    def _get_index(self):
+        """Return the index on the vector field if one exists."""
+        for index in self.collection.indexes:
+            if index.field_name == self.vector_field:
+                return index
+        return None
+
+    def _create_default_search_params(self) -> None:
+        """Generate search params based on the current index type."""
+        index = self._get_index()
+        if index is not None:
+            index_dict = index.to_dict()
+            index_type = index_dict["index_param"]["index_type"]
+            metric_type = index_dict["index_param"]["metric_type"]
+            # Copy so per-collection tweaks don't mutate the shared defaults
+            self.default_search_params = dict(DEFAULT_SEARCH_PARAMS[index_type])
+            self.default_search_params["metric_type"] = metric_type
+
+    def _load(self):
+        """Load the collection."""
+        if self._get_index() is not None:
+            if self.is_self_hosted:
+                try:
+                    self.collection.load(replica_number=self.replica_number)
+                    logger.debug(
+                        "Collection loaded: %s",
+                        self.collection_name,
+                    )
+                # If the replica count is incorrect, release and reload the collection
+                except MilvusException:
+                    try:
+                        self.collection.release(timeout=self.timeout)
+                        self.collection.load(replica_number=self.replica_number)
+                        logger.debug(
+                            "Successfully reloaded collection: %s",
+                            self.collection_name,
+                        )
+                    except MilvusException as ex:
+                        logger.error(
+                            "Failed to load collection: %s",
+                            self.collection_name,
+                        )
+                        raise ex
+            else:
+                try:
+                    self.collection.load(replica_number=1)
+                    logger.debug(
+                        "Collection loaded: %s",
+                        self.collection_name,
+                    )
+                # If both loads fail, raise the exception
+                except MilvusException as ex:
+                    logger.error("Failed to load collection: %s", self.collection_name)
+                    raise ex
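Taken together, the client above supports a create-insert-search-delete round trip without touching the ORM layer directly. A minimal end-to-end sketch, assuming a Milvus server at the default localhost URI; the collection and field names are hypothetical:

```python
# Hypothetical end-to-end round trip with the client defined above.
from pymilvus.milvus_client.milvus_client import MilvusClient

client = MilvusClient(
    collection_name="quickstart",
    pk_field="id",
    vector_field="embedding",
    overwrite=True,
    consistency_level="Session",
)
# The first insert infers the schema from this dict.
pks = client.insert_data([{"id": 1, "title": "hello", "embedding": [0.1, 0.2, 0.3]}])
hits = client.search_data([0.1, 0.2, 0.3], top_k=1)
print(hits[0][0]["data"]["title"])  # -> "hello"
client.delete_by_pk(pks)
client.close(delete_collection=True)
```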
diff --git a/pymilvus/milvus_client/milvus_client_tests.py b/pymilvus/milvus_client/milvus_client_tests.py
new file mode 100644
index 000000000..83823267f
--- /dev/null
+++ b/pymilvus/milvus_client/milvus_client_tests.py
@@ -0,0 +1,355 @@
+"""Test the MilvusClient."""
+import logging
+import random
+import sys
+from uuid import uuid4
+import numpy as np
+
+from pymilvus import (
+    FieldSchema,
+    DataType,
+    CollectionSchema,
+    connections,
+    utility,
+    Collection,
+)
+from pymilvus.milvus_client.milvus_client import MilvusClient
+
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+MILVUS_URI = None
+COLLECTION_NAME = "test"
+
+
+def valid_data(seed: int):
+    """Generate valid data."""
+    datas = []
+    count = 10
+    for cur in range(count):
+        float_num = seed + (cur / 10)
+        int_num = (seed * 10) + cur
+        temp = {
+            "varchar": str(float_num)[:5],
+            "float": np.float32(float_num),
+            "int": int_num,
+            "float_vector": [float_num] * 3,
+        }
+        datas.append(temp)
+
+    return datas
+
+
+def invalid_data(seed: int):
+    """Generate wrongly keyed data."""
+    datas = []
+    count = 10
+    for cur in range(count):
+        float_num = seed + (cur / 10)
+        int_num = (seed * 10) + cur
+        temp = {
+            "varcha": str(float_num)[:5],
+            "floa": np.float32(float_num),
+            "in": int_num,
+            "float_vecto": [float_num] * 3,
+        }
+        datas.append(temp)
+
+    return datas
+
+
+def create_existing_collection(uri, collection_name):
+    """Create a collection outside of the client to test loading from existing."""
+    alias = uuid4().hex
+    connections.connect(uri=uri, alias=alias)
+    if utility.has_collection(collection_name=collection_name, using=alias):
+        utility.drop_collection(collection_name=collection_name, using=alias)
+    fields = [
+        FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=3),
+        FieldSchema(name="int", dtype=DataType.INT64, is_primary=True, auto_id=True),
+        FieldSchema(name="float", dtype=DataType.FLOAT),
+        FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65_535),
+    ]
+    schema = CollectionSchema(fields)
+
+    ret = {
+        "col": Collection(collection_name, schema, using=alias),
+        "fields": ["float_vector", "int", "float", "varchar"],
+        "primary_field": "int",
+        "vector_field": "float_vector",
+    }
+
+    return ret
+
+
+class TestMilvusClient:
+    """
+    Tests to run:
+        Construct from nonexistent collection
+        Construct from existing collection
+        Insert data into existing collection
+        Insert data into nonexistent collection
+        Insert non-matching data into existing collection
+        Insert non-matching data into nonexistent collection
+        Insert data with pk field into auto_id collection
+        Insert data without pk field into non-auto_id collection
+        Test search
+        Test query
+        Test get vector
+        Test delete vector
+        Test add partition
+        Test remove partition
+    """
+
+    @staticmethod
+    def test_construct_from_existing_collection():
+        info = create_existing_collection(MILVUS_URI, COLLECTION_NAME)
+        client = MilvusClient(collection_name=COLLECTION_NAME, uri=MILVUS_URI)
+        assert list(client.fields.keys()) == info["fields"]
+        assert client.pk_field == info["primary_field"]
+        assert client.vector_field == info["vector_field"]
+
+    @staticmethod
+    def test_construct_from_nonexistent_collection():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME, uri=MILVUS_URI, overwrite=True
+        )
+        assert client.fields is None
+        assert client.pk_field is None
+        assert client.vector_field is None
+
+    @staticmethod
+    def test_insert_in_existing_collection_valid():
+        create_existing_collection(MILVUS_URI, COLLECTION_NAME)
+        client = MilvusClient(collection_name=COLLECTION_NAME, uri=MILVUS_URI)
+        client.insert_data(valid_data(1))
+        assert len(client) == 10
+
+    @staticmethod
+    def test_insert_in_nonexistent_collection_valid():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            vector_field="float_vector",
+            overwrite=True,
+        )
+        client.insert_data(valid_data(1))
+        assert len(client) == 10
+
+    @staticmethod
+    def test_insert_in_existing_collection_invalid():
+        create_existing_collection(MILVUS_URI, COLLECTION_NAME)
+        client = MilvusClient(collection_name=COLLECTION_NAME, uri=MILVUS_URI)
+        try:
+            client.insert_data(invalid_data(1))
+            raise AssertionError("Failed")
+        except KeyError:
+            return
+
+    @staticmethod
+    def test_insert_in_nonexistent_collection_invalid():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            vector_field="float_vector",
+            overwrite=True,
+        )
+        try:
+            client.insert_data(invalid_data(1))
+            raise AssertionError("Failed")
+        except ValueError:
+            return
+
+    @staticmethod
+    def test_insert_pk_with_autoid():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            vector_field="float_vector",
+            overwrite=True,
+        )
+        client.insert_data(valid_data(1))
+        pk = client.pk_field
+        data = valid_data(2)
+        for d in data:
+            d[pk] = int(random.random() * 100)
+        client.insert_data(data)
+        assert len(client) == 20
+
+    @staticmethod
+    def test_insert_with_missing_pk_without_autoid():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            pk_field="int",
+            vector_field="float_vector",
+            overwrite=True,
+        )
+        client.insert_data(valid_data(1))
+        data = valid_data(2)
+        for d in data:
+            d.pop("int")
+        try:
+            client.insert_data(data)
+            raise AssertionError("Failed")
+        except KeyError:
+            return
+
+    @staticmethod
+    def test_custom_index_existing():
+        col_info = create_existing_collection(MILVUS_URI, COLLECTION_NAME)
+        col: Collection = col_info["col"]
+        col.create_index(
+            field_name=col_info["vector_field"],
+            index_params={
+                "index_type": "IVF_FLAT",
+                "metric_type": "L2",
+                "params": {"nlist": 128},
+            },
+        )
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            consistency_level="Session",
+            index_params={
+                "index_type": "IVF_SQ8",
+                "metric_type": "L2",
+                "params": {"nlist": 128},
+            },
+        )
+
+        assert client.collection.indexes[0].params["index_type"] == "IVF_SQ8"
+        assert client.default_search_params["params"] == {"nprobe": 10}
+
+    @staticmethod
+    def test_search_default_params():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            pk_field="int",
+            vector_field="float_vector",
+            overwrite=True,
+            consistency_level="Session",
+        )
+        client.insert_data(valid_data(1))
+        res = client.search_data([0, 0, 0], top_k=3)
+        assert len(res[0]) == 3
+        assert res[0][0]["data"]["int"] == 10
+
+    @staticmethod
+    def test_query():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            pk_field="int",
+            vector_field="float_vector",
+            overwrite=True,
+            consistency_level="Session",
+        )
+        client.insert_data(valid_data(1))
+        res = client.query_data('varchar in ["1.1"]')
+        assert res[0]["int"] == 11
+
+    @staticmethod
+    def test_delete_by_pk():
+        # Test int pk
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            pk_field="int",
+            vector_field="float_vector",
+            overwrite=True,
+            consistency_level="Session",
+        )
+        client.insert_data(valid_data(1))
+        res = client.query_data('varchar in ["1.1"]')
+        key = res[0]["int"]
+        client.delete_by_pk(key)
+        res = client.query_data('varchar in ["1.1"]')
+        assert len(res) == 0
+        # Test varchar pk
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            pk_field="varchar",
+            vector_field="float_vector",
+            overwrite=True,
+            consistency_level="Session",
+        )
+        client.insert_data(valid_data(1))
+        res = client.query_data('varchar in ["1.1"]')
+        key = res[0]["varchar"]
+        client.delete_by_pk(key)
+        res = client.query_data('varchar in ["1.1"]')
+        assert len(res) == 0
+
+    @staticmethod
+    def test_get_vector_by_pk():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            pk_field="int",
+            vector_field="float_vector",
+            overwrite=True,
+            consistency_level="Session",
+        )
+        client.insert_data(valid_data(1))
+        res = client.query_data('varchar in ["1.1"]')
+        vector = client.get_vectors_by_pk(res[0]["int"])
+        assert str(vector[0]["float_vector"]) == "[1.1, 1.1, 1.1]"
+
+    @staticmethod
+    def test_partition_modification():
+        client = MilvusClient(
+            collection_name=COLLECTION_NAME,
+            uri=MILVUS_URI,
+            pk_field="int",
+            vector_field="float_vector",
+            overwrite=True,
+            consistency_level="Session",
+            partitions=["2"],
+        )
+        client.insert_data(valid_data(1))
+        assert len(client.collection.partitions) == 2
+        client.add_partitions(["lol"])
+        assert len(client.collection.partitions) == 3
+        client.delete_partitions(["lol"])
+        assert len(client.collection.partitions) == 2
"https://username:pass@in01-bbd08105d3a44f9.aws-us-west-2.vectordb.zillizcloud.com:19538" + # TestMilvusClient.test_construct_from_existing_collection() + # TestMilvusClient.test_construct_from_nonexistant_collection() + # TestMilvusClient.test_insert_in_existing_collection_valid() + # TestMilvusClient.test_insert_in_nonexistant_collection_valid() + # TestMilvusClient.test_insert_in_existing_collection_invalid() + # TestMilvusClient.test_insert_in_nonexistant_collection_invalid() + # TestMilvusClient.test_insert_pk_with_autoid() + # TestMilvusClient.test_insert_with_missing_pk_without_autoid() + # TestMilvusClient.test_search_default_params() + # # TestMilvusClient.test_custom_index_existing() + # TestMilvusClient.test_query() + # TestMilvusClient.test_delete_by_pk() + # TestMilvusClient.test_get_vector_by_pk() + # # TestMilvusClient.test_partition_modification() diff --git a/requirements.txt b/requirements.txt index dd32fd799..4657a6030 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,6 +27,7 @@ sphinxcontrib-qthelp sphinxcontrib-serializinghtml sphinxcontrib-napoleon sphinxcontrib-prettyspecialmethods +tqdm==4.65.0 pytest>=5.3.4 pytest-cov==2.8.1 pytest-timeout==1.3.4