Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an utility to get server type #1381

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

from .utils import (
check_invalid_binary_vector,
len_of
len_of,
get_server_type,
)

from ..settings import DefaultConfig as config
Expand Down Expand Up @@ -192,6 +193,9 @@ def server_address(self):
""" Server network address """
return self._address

def get_server_type(self):
return get_server_type(self.server_address)

def reset_password(self, user, old_password, new_password, timeout=None):
"""
reset password and then setup the grpc channel.
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def name(self):
def handler(self):
return self._handler

def get_server_type(self):
return self._handler.get_server_type()

def reset_password(self, user, old_password, new_password):
self._handler.reset_password(user, old_password, new_password)

Expand Down
26 changes: 23 additions & 3 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime

from urllib.parse import urlparse

from .types import DataType
from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK
from ..exceptions import ParamError, MilvusException
Expand Down Expand Up @@ -167,7 +169,7 @@ def traverse_info(fields_info, entities):
if field_name == entity_name:
if field_type != entity_type:
raise ParamError(message=f"Collection field type is {field_type}"
f", but entities field type is {entity_type}")
f", but entities field type is {entity_type}")

entity_dim, field_dim = 0, 0
if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
Expand All @@ -176,11 +178,11 @@ def traverse_info(fields_info, entities):

if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim:
raise ParamError(message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim}")
f", but entities field dim is {entity_dim}")

if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim:
raise ParamError(message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim * 8}")
f", but entities field dim is {entity_dim * 8}")

location[field["name"]] = i
match_flag = True
Expand All @@ -191,3 +193,21 @@ def traverse_info(fields_info, entities):
message=f"Field {field['name']} don't match in entities")

return location, primary_key_loc, auto_id_loc


def get_protocol_and_domain(host):
o = urlparse(host)
return o.scheme, o.hostname


def get_server_type(host):
protocol, hostname = get_protocol_and_domain(host)
if protocol != "https":
return "milvus"
splits = hostname.split('.')
len_of_splits = len(splits)
if len_of_splits >= 2 and \
splits[len_of_splits - 2].lower() == "zillizcloud" and \
splits[len_of_splits - 1].lower() == "com":
return "zilliz"
return "milvus"
14 changes: 14 additions & 0 deletions pymilvus/orm/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,7 @@ def transfer_replica(source_group, target_group, collection_name, num_replicas,
"""
return _get_connection(using).transfer_replica(source_group, target_group, collection_name, num_replicas, timeout)


def flush_all(using="default", timeout=None, **kwargs):
""" Flush all collections. All insertions, deletions, and upserts before `flush_all` will be synced.

Expand Down Expand Up @@ -1064,3 +1065,16 @@ def flush_all(using="default", timeout=None, **kwargs):
>>> future.done() # flush_all finished
"""
return _get_connection(using).flush_all(timeout=timeout, **kwargs)


def get_server_type(using="default"):
""" Get the server type. Now, it will return "zilliz" if the connection related to an instance on the zilliz cloud,
otherwise "milvus" will be returned.

:param using: Alias to the connection. Default connection is used if this is not specified.
:type: str

:return: The server type.
:rtype: str
"""
return _get_connection(using).get_server_type()
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pymilvus.client import utils


class TestUtils:
def test_get_server_type(self):
url1 = 'in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com'
assert utils.get_server_type(url1) == "milvus"

url2 = 'https://in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com'
assert utils.get_server_type(url2) == "zilliz"

url3 = 'http://in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com'
assert utils.get_server_type(url3) == "milvus"

url4 = 'https://something.notzillizcloud.com'
assert utils.get_server_type(url4) == "milvus"

url5 = 'https://something.zillizcloud.not.com'
assert utils.get_server_type(url5) == "milvus"