Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify num_shards and mark shards_num as deprecated #1412

Merged
merged 1 commit into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@ class Prepare:
def create_collection_request(cls, collection_name: str, fields: Union[Dict[str, Iterable], CollectionSchema],
**kwargs) -> milvus_types.CreateCollectionRequest:
"""
:type fields: Union(Dict[str, Iterable], CollectionSchema)
:param fields: (Required)

`{"fields": [
{"name": "A", "type": DataType.INT32}
{"name": "B", "type": DataType.INT64, "auto_id": True, "is_primary": True},
{"name": "C", "type": DataType.FLOAT},
{"name": "Vec", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}]
}`

:return: milvus_types.CreateCollectionRequest
Args:
fields (Union(Dict[str, Iterable], CollectionSchema)).

{"fields": [
{"name": "A", "type": DataType.INT32},
{"name": "B", "type": DataType.INT64, "auto_id": True, "is_primary": True},
{"name": "C", "type": DataType.FLOAT},
{"name": "Vec", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}]
}

Returns:
milvus_types.CreateCollectionRequest
"""

if isinstance(fields, CollectionSchema):
schema = cls.get_schema_from_collection_schema(collection_name, fields, **kwargs)
else:
Expand All @@ -54,11 +56,15 @@ def create_collection_request(cls, collection_name: str, fields: Union[Dict[str,
properties = [common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in properties.items()]
req.properties.extend(properties)

shards_num = kwargs.get("shards_num")
if shards_num is not None:
if not isinstance(shards_num, int) or isinstance(shards_num, bool):
raise ParamError(message=f"invalid shards_num type, got {type(shards_num)}, expected int")
req.shards_num = shards_num
same_key = set(kwargs.keys()).intersection({"num_shards", "shards_num"})
if len(same_key) > 0:
if len(same_key) > 1:
raise ParamError(message="got both num_shards and shards_num in kwargs, expected only one of them")

num_shards = kwargs[list(same_key)[0]]
if not isinstance(num_shards, int):
raise ParamError(message=f"invalid num_shards type, got {type(num_shards)}, expected int")
req.shards_num=num_shards

num_partitions = kwargs.get("num_partitions", None)
if num_partitions is not None:
Expand Down
8 changes: 5 additions & 3 deletions pymilvus/client/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def close(self):
self.handler.close()
self._handler = None

def create_collection(self, collection_name, fields, shards_num=None, timeout=None, **kwargs):
def create_collection(self, collection_name, fields, timeout=None, **kwargs):
""" Creates a collection.

:param collection_name: The name of the collection. A collection name can only include
Expand All @@ -83,7 +83,9 @@ def create_collection(self, collection_name, fields, shards_num=None, timeout=No
:type timeout: float

:param kwargs:
* *shards_num* (``int``) --
* *num_shards* (``int``) --
How wide to scale collection. Corresponds to how many active datanodes can be used on insert.
* *shards_num* (``int``, deprecated) --
How wide to scale collection. Corresponds to how many active datanodes can be used on insert.
* *consistency_level* (``str/int``) --
Which consistency level to use when searching in the collection. For details, see
Expand All @@ -100,7 +102,7 @@ def create_collection(self, collection_name, fields, shards_num=None, timeout=No
:raises MilvusException: If the return result from server is not ok
"""
with self._connection() as handler:
return handler.create_collection(collection_name, fields, shards_num=shards_num, timeout=timeout, **kwargs)
return handler.create_collection(collection_name, fields, timeout=timeout, **kwargs)

def drop_collection(self, collection_name, timeout=None):
"""
Expand Down
8 changes: 4 additions & 4 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
pk_field: str = None,
vector_field: str = None,
uri: str = "http://localhost:19530",
shard_num: int = None,
num_shards: int = None,
partitions: List[str] = None,
consistency_level: str = "Session",
replica_number: int = 1,
Expand All @@ -52,7 +52,7 @@ def __init__(
uri (str, optional): The connection address to use to connect to the
instance. Defaults to "http://localhost:19530". Another example:
"https://username:password@<host>:19538".
shard_num (int, optional): The amount of shards to use for the collection. Unless
num_shards (int, optional): The amount of shards to use for the collection. Unless
dealing with huge scale, recommended to keep at default. Defaults to None and allows
server to set.
partitions (List[str], optional): Which partitions to create for the collection.
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(

self.uri = uri
self.collection_name = collection_name
self.shard_num = shard_num
self.num_shards = num_shards
self.partitions = partitions
self.consistency_level = consistency_level
self.replica_number = replica_number
Expand Down Expand Up @@ -725,7 +725,7 @@ def _create_collection(self, data: dict) -> None:
name=self.collection_name,
schema=schema,
consistency_level=self.consistency_level,
shards_num=self.shard_num,
num_shards=self.num_shards,
num_partitions=self.num_partitions,
using=self.alias,
)
Expand Down
4 changes: 3 additions & 1 deletion pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def
using (``str``, optional): Milvus connection alias name, defaults to 'default'.
**kwargs (``dict``):

* *shards_num (``int``, optional): how many shards will the insert data be divided.
* *num_shards* (``int``, optional): how many shards the inserted data will be divided into.
* *shards_num* (``int``, optional, deprecated): how many shards the inserted data will be divided into.
* *consistency_level* (``int/ str``)
Which consistency level to use when searching in the collection.
Options of consistency level: Strong, Bounded, Eventually, Session, Customized.
Expand All @@ -70,6 +71,7 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def
An optional duration of time in seconds to allow for the RPCs.
If timeout is not set, the client keeps waiting until the server responds or an error occurs.


Raises:
SchemaNotReadyException: if the schema is wrong.

Expand Down
25 changes: 25 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,28 @@ def test_get_schema_from_collection_schema(self):
assert c_schema.fields[1].is_primary_key is True
assert c_schema.fields[1].autoID is True
assert len(c_schema.fields[1].type_params) == 0

@pytest.mark.parametrize("kv", [
    {"shards_num": 1},
    {"num_shards": 2},
])
def test_create_collection_request_num_shards(self, kv):
    """Check that either kwarg spelling populates ``req.shards_num``.

    Each parametrized ``kv`` holds exactly one of ``shards_num`` /
    ``num_shards``; the request built from it must carry that value.
    """
    vector_field = FieldSchema("field_vector", DataType.FLOAT_VECTOR, dim=8)
    primary_field = FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True)
    schema = CollectionSchema([vector_field, primary_field])

    # Every case passes a single kwarg, so its first value is the expectation.
    expected_shards = next(iter(kv.values()))

    req = Prepare.create_collection_request("c_name", schema, **kv)
    assert req.shards_num == expected_shards

@pytest.mark.parametrize("kv", [
    {"shards_num": 1, "num_shards": 1},
    {"num_shards": "2"},
])
def test_create_collection_request_num_shards_error(self, kv):
    """Check that invalid shard kwargs are rejected with ``MilvusException``.

    Covers two failure modes: passing both ``shards_num`` and ``num_shards``
    at once, and passing a non-int shard count.
    """
    schema = CollectionSchema([
        FieldSchema("field_vector", DataType.FLOAT_VECTOR, dim=8),
        FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True),
    ])

    # Only the raised exception matters here, so the result is not bound.
    with pytest.raises(MilvusException):
        Prepare.create_collection_request("c_name", schema, **kv)