Skip to content

Commit

Permalink
fix: remove limitation clustering key can not be primary key
Browse files Browse the repository at this point in the history
Signed-off-by: wayblink <[email protected]>
  • Loading branch information
wayblink committed Jul 23, 2024
1 parent 93be83a commit c232145
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 10 deletions.
1 change: 0 additions & 1 deletion pymilvus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ class ExceptionsMessage:
"Ambiguous parameter, either ids or filter should be specified, cannot support both."
)
JSONKeyMustBeStr = "JSON key must be str."
ClusteringKeyNotPrimary = "Clustering key field should not be primary field"
ClusteringKeyType = (
"Clustering key field type must be DataType.INT8, DataType.INT16, "
"DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE, "
Expand Down
12 changes: 3 additions & 9 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,8 @@ def validate_partition_key(
)


def validate_clustering_key(
clustering_key_field_name: Any, clustering_key_field: Any, primary_field_name: Any
):
def validate_clustering_key(clustering_key_field_name: Any, clustering_key_field: Any):
if clustering_key_field is not None:
if clustering_key_field.name == primary_field_name:
raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyNotPrimary)
if clustering_key_field.dtype not in [
DataType.INT8,
DataType.INT16,
Expand All @@ -82,7 +78,7 @@ def validate_clustering_key(
raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyType)
elif clustering_key_field_name is not None:
raise ClusteringKeyException(
message=ExceptionsMessage.PartitionKeyFieldNotExist % clustering_key_field_name
message=ExceptionsMessage.ClusteringKeyFieldNotExist % clustering_key_field_name
)


Expand Down Expand Up @@ -171,9 +167,7 @@ def _check_fields(self):
validate_partition_key(
partition_key_field_name, self._partition_key_field, self._primary_field.name
)
validate_clustering_key(
clustering_key_field_name, self._clustering_key_field, self._primary_field.name
)
validate_clustering_key(clustering_key_field_name, self._clustering_key_field)

auto_id = self._kwargs.get("auto_id", False)
if auto_id:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_create_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,45 @@ def test_create_bf16_collection(self, collection_name):
return_value = future.result()
assert return_value.code == 0
assert return_value.reason == "success"

def test_create_clustering_key_collection(self, collection_name):
id_field = {
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
"is_clustering_key": True,
}
vector_field = {
"name": "embedding",
"type": DataType.FLOAT_VECTOR,
"metric_type": "L2",
"params": {"dim": "4"},
}
fields = {"fields": [id_field, vector_field]}
future = self._milvus.create_collection(
collection_name=collection_name, fields=fields, _async=True
)

invocation_metadata, request, rpc = self._real_time_channel.take_unary_unary(
self._servicer.methods_by_name["CreateCollection"]
)
rpc.send_initial_metadata(())
rpc.terminate(
common_pb2.Status(
code=ErrorCode.SUCCESS, error_code=common_pb2.Success, reason="success"
),
(),
grpc.StatusCode.OK,
"",
)

request_schema = schema_pb2.CollectionSchema()
request_schema.ParseFromString(request.schema)

assert request.collection_name == collection_name
assert Fields.equal(request_schema.fields, fields["fields"])

return_value = future.result()
assert return_value.code == 0
assert return_value.reason == "success"

0 comments on commit c232145

Please sign in to comment.