Skip to content

Commit

Permalink
enhance: Update the template expression proto to improve transmission…
Browse files Browse the repository at this point in the history
… efficiency (milvus-io#2334)

milvus pr: milvus-io/milvus#37484

Signed-off-by: Cai Zhang <[email protected]>
Signed-off-by: NamCaoHai <[email protected]>
  • Loading branch information
xiaocai2333 authored and CaoHaiNam committed Nov 7, 2024
1 parent b20e842 commit 518e18e
Show file tree
Hide file tree
Showing 10 changed files with 537 additions and 458 deletions.
57 changes: 38 additions & 19 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,12 +808,48 @@ def _prepare_placeholder_str(cls, data: Any):

@classmethod
def prepare_expression_template(cls, values: Dict) -> Any:
def all_elements_same_type(lst: List):
return all(isinstance(item, type(lst[0])) for item in lst)

def add_array_data(v: List) -> schema_types.TemplateArrayValue:
data = schema_types.TemplateArrayValue()
if len(v) == 0:
return data
element_type = (
infer_dtype_by_scalar_data(v[0]) if all_elements_same_type(v) else schema_types.JSON
)
if element_type in (schema_types.Bool,):
data.bool_data.data.extend(v)
return data
if element_type in (
schema_types.Int8,
schema_types.Int16,
schema_types.Int32,
schema_types.Int64,
):
data.long_data.data.extend(v)
return data
if element_type in (schema_types.Float, schema_types.Double):
data.double_data.data.extend(v)
return data
if element_type in (schema_types.VarChar, schema_types.String):
data.string_data.data.extend(v)
return data
if element_type in (schema_types.Array,):
for e in v:
data.array_data.data.append(add_array_data(e))
return data
if element_type in (schema_types.JSON,):
for e in v:
data.json_data.data.append(entity_helper.convert_to_json(e))
return data
raise ParamError(message=f"Unsupported element type: {element_type}")

def add_data(v: Any) -> schema_types.TemplateValue:
dtype = infer_dtype_by_scalar_data(v)
data = schema_types.TemplateValue()
if dtype in (schema_types.Bool,):
data.bool_val = v
data.type = schema_types.Bool
return data
if dtype in (
schema_types.Int8,
Expand All @@ -822,38 +858,21 @@ def add_data(v: Any) -> schema_types.TemplateValue:
schema_types.Int64,
):
data.int64_val = v
data.type = schema_types.Int64
return data
if dtype in (schema_types.Float, schema_types.Double):
data.float_val = v
data.type = schema_types.Double
return data
if dtype in (schema_types.VarChar, schema_types.String):
data.string_val = v
data.type = schema_types.VarChar
return data
if dtype in (schema_types.Array,):
element_datas = schema_types.TemplateArrayValue()
same_type = True
element_type = None
for element in v:
rdata = add_data(element)
element_datas.array.append(rdata)
if element_type is None:
element_type = rdata.type
elif element_type != rdata.type:
same_type = False
element_datas.element_type = element_type if same_type else schema_types.JSON
element_datas.same_type = same_type
data.array_val.CopyFrom(element_datas)
data.type = schema_types.Array
data.array_val.CopyFrom(add_array_data(v))
return data
raise ParamError(message=f"Unsupported element type: {dtype}")

expression_template_values = {}
for k, v in values.items():
expression_template_values[k] = add_data(v)

return expression_template_values

@classmethod
Expand Down
36 changes: 18 additions & 18 deletions pymilvus/grpc_gen/common_pb2.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pymilvus/grpc_gen/common_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class MsgType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
Upsert: _ClassVar[MsgType]
ManualFlush: _ClassVar[MsgType]
FlushSegment: _ClassVar[MsgType]
CreateSegment: _ClassVar[MsgType]
Search: _ClassVar[MsgType]
SearchResult: _ClassVar[MsgType]
GetIndexState: _ClassVar[MsgType]
Expand Down Expand Up @@ -453,6 +454,7 @@ ResendSegmentStats: MsgType
Upsert: MsgType
ManualFlush: MsgType
FlushSegment: MsgType
CreateSegment: MsgType
Search: MsgType
SearchResult: MsgType
GetIndexState: MsgType
Expand Down
724 changes: 364 additions & 360 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

26 changes: 22 additions & 4 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,34 @@ class DropCollectionRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ...) -> None: ...

class AlterCollectionRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "collectionID", "properties")
__slots__ = ("base", "db_name", "collection_name", "collectionID", "properties", "delete_keys")
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTIONID_FIELD_NUMBER: _ClassVar[int]
PROPERTIES_FIELD_NUMBER: _ClassVar[int]
DELETE_KEYS_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
collectionID: int
properties: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., collectionID: _Optional[int] = ..., properties: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...
delete_keys: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., collectionID: _Optional[int] = ..., properties: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., delete_keys: _Optional[_Iterable[str]] = ...) -> None: ...

class AlterCollectionFieldRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "field_name", "properties")
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
FIELD_NAME_FIELD_NUMBER: _ClassVar[int]
PROPERTIES_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
field_name: str
properties: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., field_name: _Optional[str] = ..., properties: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...

class HasCollectionRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "time_stamp")
Expand Down Expand Up @@ -531,18 +547,20 @@ class CreateIndexRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., field_name: _Optional[str] = ..., extra_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., index_name: _Optional[str] = ...) -> None: ...

class AlterIndexRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "index_name", "extra_params")
__slots__ = ("base", "db_name", "collection_name", "index_name", "extra_params", "delete_keys")
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
INDEX_NAME_FIELD_NUMBER: _ClassVar[int]
EXTRA_PARAMS_FIELD_NUMBER: _ClassVar[int]
DELETE_KEYS_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
index_name: str
extra_params: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., index_name: _Optional[str] = ..., extra_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...
delete_keys: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., index_name: _Optional[str] = ..., extra_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., delete_keys: _Optional[_Iterable[str]] = ...) -> None: ...

class DescribeIndexRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "field_name", "index_name", "timestamp")
Expand Down
17 changes: 10 additions & 7 deletions pymilvus/grpc_gen/rg_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions pymilvus/grpc_gen/rg_pb2.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import common_pb2 as _common_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
Expand All @@ -17,14 +18,22 @@ class ResourceGroupTransfer(_message.Message):
resource_group: str
def __init__(self, resource_group: _Optional[str] = ...) -> None: ...

class ResourceGroupNodeFilter(_message.Message):
__slots__ = ("node_labels",)
NODE_LABELS_FIELD_NUMBER: _ClassVar[int]
node_labels: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, node_labels: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...

class ResourceGroupConfig(_message.Message):
__slots__ = ("requests", "limits", "transfer_from", "transfer_to")
__slots__ = ("requests", "limits", "transfer_from", "transfer_to", "node_filter")
REQUESTS_FIELD_NUMBER: _ClassVar[int]
LIMITS_FIELD_NUMBER: _ClassVar[int]
TRANSFER_FROM_FIELD_NUMBER: _ClassVar[int]
TRANSFER_TO_FIELD_NUMBER: _ClassVar[int]
NODE_FILTER_FIELD_NUMBER: _ClassVar[int]
requests: ResourceGroupLimit
limits: ResourceGroupLimit
transfer_from: _containers.RepeatedCompositeFieldContainer[ResourceGroupTransfer]
transfer_to: _containers.RepeatedCompositeFieldContainer[ResourceGroupTransfer]
def __init__(self, requests: _Optional[_Union[ResourceGroupLimit, _Mapping]] = ..., limits: _Optional[_Union[ResourceGroupLimit, _Mapping]] = ..., transfer_from: _Optional[_Iterable[_Union[ResourceGroupTransfer, _Mapping]]] = ..., transfer_to: _Optional[_Iterable[_Union[ResourceGroupTransfer, _Mapping]]] = ...) -> None: ...
node_filter: ResourceGroupNodeFilter
def __init__(self, requests: _Optional[_Union[ResourceGroupLimit, _Mapping]] = ..., limits: _Optional[_Union[ResourceGroupLimit, _Mapping]] = ..., transfer_from: _Optional[_Iterable[_Union[ResourceGroupTransfer, _Mapping]]] = ..., transfer_to: _Optional[_Iterable[_Union[ResourceGroupTransfer, _Mapping]]] = ..., node_filter: _Optional[_Union[ResourceGroupNodeFilter, _Mapping]] = ...) -> None: ...
Loading

0 comments on commit 518e18e

Please sign in to comment.