From 9902427f31f65cfdfdb56c1f97c7ffc8c0d6bd69 Mon Sep 17 00:00:00 2001 From: yangxuan Date: Fri, 5 May 2023 14:50:32 +0800 Subject: [PATCH] Unify num_shards and mark shards_num as deprecated The user can still use shards_num to create a collection. But **num_shards** is the recommended way to create a collection and the only way to get it from a collection's property See also: milvus-io/milvus#23853 Signed-off-by: yangxuan --- pymilvus/client/prepare.py | 38 ++++++++++++++----------- pymilvus/client/stub.py | 8 ++++-- pymilvus/milvus_client/milvus_client.py | 8 +++--- pymilvus/orm/collection.py | 4 ++- tests/test_prepare.py | 25 ++++++++++++++++ 5 files changed, 59 insertions(+), 24 deletions(-) diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 4dd926a51..8a1ba25f5 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -26,18 +26,20 @@ class Prepare: def create_collection_request(cls, collection_name: str, fields: Union[Dict[str, Iterable], CollectionSchema], **kwargs) -> milvus_types.CreateCollectionRequest: """ - :type fields: Union(Dict[str, Iterable], CollectionSchema) - :param fields: (Required) - - `{"fields": [ - {"name": "A", "type": DataType.INT32} - {"name": "B", "type": DataType.INT64, "auto_id": True, "is_primary": True}, - {"name": "C", "type": DataType.FLOAT}, - {"name": "Vec", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}] - }` - - :return: milvus_types.CreateCollectionRequest + Args: + fields (Union(Dict[str, Iterable], CollectionSchema)). 
+ + {"fields": [ + {"name": "A", "type": DataType.INT32} + {"name": "B", "type": DataType.INT64, "auto_id": True, "is_primary": True}, + {"name": "C", "type": DataType.FLOAT}, + {"name": "Vec", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}] + } + + Returns: + milvus_types.CreateCollectionRequest """ + if isinstance(fields, CollectionSchema): schema = cls.get_schema_from_collection_schema(collection_name, fields, **kwargs) else: @@ -54,11 +56,15 @@ def create_collection_request(cls, collection_name: str, fields: Union[Dict[str, properties = [common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in properties.items()] req.properties.extend(properties) - shards_num = kwargs.get("shards_num") - if shards_num is not None: - if not isinstance(shards_num, int) or isinstance(shards_num, bool): - raise ParamError(message=f"invalid shards_num type, got {type(shards_num)}, expected int") - req.shards_num = shards_num + same_key = set(kwargs.keys()).intersection({"num_shards", "shards_num"}) + if len(same_key) > 0: + if len(same_key) > 1: + raise ParamError(message="got both num_shards and shards_num in kwargs, expected only one of them") + + num_shards = kwargs[list(same_key)[0]] + if not isinstance(num_shards, int): + raise ParamError(message=f"invalid num_shards type, got {type(num_shards)}, expected int") + req.shards_num=num_shards num_partitions = kwargs.get("num_partitions", None) if num_partitions is not None: diff --git a/pymilvus/client/stub.py b/pymilvus/client/stub.py index 20fd7939b..3f5faefc9 100644 --- a/pymilvus/client/stub.py +++ b/pymilvus/client/stub.py @@ -59,7 +59,7 @@ def close(self): self.handler.close() self._handler = None - def create_collection(self, collection_name, fields, shards_num=None, timeout=None, **kwargs): + def create_collection(self, collection_name, fields, timeout=None, **kwargs): """ Creates a collection. :param collection_name: The name of the collection. 
A collection name can only include @@ -83,7 +83,9 @@ def create_collection(self, collection_name, fields, shards_num=None, timeout=No :type timeout: float :param kwargs: - * *shards_num* (``int``) -- + * *num_shards* (``int``) -- + How wide to scale collection. Corresponds to how many active datanodes can be used on insert. + * *shards_num* (``int``, deprecated) -- How wide to scale collection. Corresponds to how many active datanodes can be used on insert. * *consistency_level* (``str/int``) -- Which consistency level to use when searching in the collection. For details, see @@ -100,7 +102,7 @@ def create_collection(self, collection_name, fields, shards_num=None, timeout=No :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.create_collection(collection_name, fields, shards_num=shards_num, timeout=timeout, **kwargs) + return handler.create_collection(collection_name, fields, timeout=timeout, **kwargs) def drop_collection(self, collection_name, timeout=None): """ diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index 923c56ab7..704d18579 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -26,7 +26,7 @@ def __init__( pk_field: str = None, vector_field: str = None, uri: str = "http://localhost:19530", - shard_num: int = None, + num_shards: int = None, partitions: List[str] = None, consistency_level: str = "Session", replica_number: int = 1, @@ -52,7 +52,7 @@ def __init__( uri (str, optional): The connection address to use to connect to the instance. Defaults to "http://localhost:19530". Another example: "https://username:password@in01-12a.aws-us-west-2.vectordb.zillizcloud.com:19538 - shard_num (int, optional): The amount of shards to use for the collection. Unless + num_shards (int, optional): The amount of shards to use for the collection. 
Unless dealing with huge scale, recommended to keep at default. Defaults to None and allows server to set. partitions (List[str], optional): Which paritions to create for the collection. @@ -81,7 +81,7 @@ def __init__( self.uri = uri self.collection_name = collection_name - self.shard_num = shard_num + self.num_shards = num_shards self.partitions = partitions self.consistency_level = consistency_level self.replica_number = replica_number @@ -725,7 +725,7 @@ def _create_collection(self, data: dict) -> None: name=self.collection_name, schema=schema, consistency_level=self.consistency_level, - shards_num=self.shard_num, + num_shards=self.num_shards, num_partitions=self.num_partitions, using=self.alias, ) diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 1c00458dc..fdda90af8 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -56,7 +56,8 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def using (``str``, optional): Milvus connection alias name, defaults to 'default'. **kwargs (``dict``): - * *shards_num (``int``, optional): how many shards will the insert data be divided. + * *num_shards (``int``, optional): how many shards will the insert data be divided. + * *shards_num (``int``, optional, deprecated): how many shards will the insert data be divided. * *consistency_level* (``int/ str``) Which consistency level to use when searching in the collection. Options of consistency level: Strong, Bounded, Eventually, Session, Customized. @@ -70,6 +71,7 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def An optional duration of time in seconds to allow for the RPCs. If timeout is not set, the client keeps waiting until the server responds or an error occurs. + Raises: SchemaNotReadyException: if the schema is wrong. 
diff --git a/tests/test_prepare.py b/tests/test_prepare.py index ef5729491..cdca5b8b0 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -113,3 +113,28 @@ def test_get_schema_from_collection_schema(self): assert c_schema.fields[1].is_primary_key is True assert c_schema.fields[1].autoID is True assert len(c_schema.fields[1].type_params) == 0 + + @pytest.mark.parametrize("kv", [ + {"shards_num": 1}, + {"num_shards": 2}, + ]) + def test_create_collection_request_num_shards(self, kv): + schema = CollectionSchema([ + FieldSchema("field_vector", DataType.FLOAT_VECTOR, dim=8), + FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True) + ]) + req = Prepare.create_collection_request("c_name", schema, **kv) + assert req.shards_num == list(kv.values())[0] + + @pytest.mark.parametrize("kv", [ + {"shards_num": 1, "num_shards": 1}, + {"num_shards": "2"}, + ]) + def test_create_collection_request_num_shards_error(self, kv): + schema = CollectionSchema([ + FieldSchema("field_vector", DataType.FLOAT_VECTOR, dim=8), + FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True) + ]) + + with pytest.raises(MilvusException): + req = Prepare.create_collection_request("c_name", schema, **kv)