diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 5ae5e9a83..5400c6530 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -38,7 +38,8 @@ from .utils import ( check_invalid_binary_vector, - len_of + len_of, + get_server_type, ) from ..settings import DefaultConfig as config @@ -192,6 +193,9 @@ def server_address(self): """ Server network address """ return self._address + def get_server_type(self): + return get_server_type(self.server_address) + def reset_password(self, user, old_password, new_password, timeout=None): """ reset password and then setup the grpc channel. diff --git a/pymilvus/client/stub.py b/pymilvus/client/stub.py index 0a2552083..17022c59e 100644 --- a/pymilvus/client/stub.py +++ b/pymilvus/client/stub.py @@ -47,6 +47,9 @@ def name(self): def handler(self): return self._handler + def get_server_type(self): + return self._handler.get_server_type() + def reset_password(self, user, old_password, new_password): self._handler.reset_password(user, old_password, new_password) diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 64ae19e76..3facee3d8 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -1,5 +1,7 @@ import datetime +from urllib.parse import urlparse + from .types import DataType from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK from ..exceptions import ParamError, MilvusException @@ -167,7 +169,7 @@ def traverse_info(fields_info, entities): if field_name == entity_name: if field_type != entity_type: raise ParamError(message=f"Collection field type is {field_type}" - f", but entities field type is {entity_type}") + f", but entities field type is {entity_type}") entity_dim, field_dim = 0, 0 if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: @@ -176,11 +178,11 @@ def traverse_info(fields_info, entities): if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim: raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim}") + f", but entities field dim is {entity_dim}") if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim: raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim * 8}") + f", but entities field dim is {entity_dim * 8}") location[field["name"]] = i match_flag = True @@ -191,3 +193,21 @@ def traverse_info(fields_info, entities): message=f"Field {field['name']} don't match in entities") return location, primary_key_loc, auto_id_loc + + +def get_protocol_and_domain(host): + o = urlparse(host) + return o.scheme, o.hostname + + +def get_server_type(host): + protocol, hostname = get_protocol_and_domain(host) + if protocol != "https": + return "milvus" + splits = hostname.split('.') + len_of_splits = len(splits) + if len_of_splits >= 2 and \ + splits[len_of_splits - 2].lower() == "zillizcloud" and \ + splits[len_of_splits - 1].lower() == "com": + return "zilliz" + return "milvus" diff --git a/pymilvus/orm/utility.py b/pymilvus/orm/utility.py index 6e3ee7177..48bd230ad 100644 --- a/pymilvus/orm/utility.py +++ b/pymilvus/orm/utility.py @@ -1037,6 +1037,7 @@ def transfer_replica(source_group, target_group, collection_name, num_replicas, """ return _get_connection(using).transfer_replica(source_group, target_group, collection_name, num_replicas, timeout) + def flush_all(using="default", timeout=None, **kwargs): """ Flush all collections. All insertions, deletions, and upserts before `flush_all` will be synced. @@ -1064,3 +1065,16 @@ def flush_all(using="default", timeout=None, **kwargs): >>> future.done() # flush_all finished """ return _get_connection(using).flush_all(timeout=timeout, **kwargs) + + +def get_server_type(using="default"): + """ Get the server type. Now, it will return "zilliz" if the connection related to an instance on the zilliz cloud, + otherwise "milvus" will be returned. + + :param using: Alias to the connection. Default connection is used if this is not specified. + :type: str + + :return: The server type. + :rtype: str + """ + return _get_connection(using).get_server_type() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..aa07c223c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,21 @@ +from pymilvus.client import utils + + +class TestUtils: + def test_get_server_type(self): + url1 = 'in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com' + assert utils.get_server_type(url1) == "milvus" + + url2 = 'https://in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com' + assert utils.get_server_type(url2) == "zilliz" + + url3 = 'http://in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com' + assert utils.get_server_type(url3) == "milvus" + + url4 = 'https://something.notzillizcloud.com' + assert utils.get_server_type(url4) == "milvus" + + url5 = 'https://something.zillizcloud.not.com' + assert utils.get_server_type(url5) == "milvus" + +