Skip to content

Commit

Permalink
last few fixes (#1408)
Browse files Browse the repository at this point in the history
Signed-off-by: Filip Haltmayer <[email protected]>
  • Loading branch information
filip-halt authored May 10, 2023
1 parent 29d25b5 commit a5c79ba
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
6 changes: 5 additions & 1 deletion pymilvus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
from .orm.future import SearchFuture, MutationFuture
from .orm.role import Role

from .milvus_client.milvus_client import MilvusClient

__all__ = [
'Collection', 'Index', 'Partition',
'connections',
Expand All @@ -91,5 +93,7 @@

'Milvus', 'Prepare', 'Status', 'DataType',
'MilvusException',
'__version__'
'__version__',

'MilvusClient'
]
21 changes: 13 additions & 8 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import Dict, List, Union
from uuid import uuid4

from tqdm import tqdm

from pymilvus.exceptions import MilvusException
from pymilvus.milvus_client.defaults import DEFAULT_SEARCH_PARAMS
from pymilvus.orm import utility
Expand All @@ -20,17 +18,17 @@
class MilvusClient:
"""The Milvus Client"""

# pylint: disable=logging-too-many-args, too-many-instance-attributes
# pylint: disable=logging-too-many-args, too-many-instance-attributes, import-outside-toplevel

def __init__(
self,
collection_name: str = "ClientCollection",
pk_field: str = None,
vector_field: str = None,
uri: str = None,
uri: str = "http://localhost:19530",
shard_num: int = None,
partitions: List[str] = None,
consistency_level: str = "Bounded",
consistency_level: str = "Session",
replica_number: int = 1,
index_params: dict = None,
timeout: int = None,
Expand Down Expand Up @@ -70,6 +68,14 @@ def __init__(
overwrite (bool, optional): Whether to overwrite existing collection if exists.
Defaults to False
"""
# Optionial TQDM import
try:
import tqdm
self.tqdm = tqdm.tqdm
except ImportError:
logger.debug("tqdm not found")
self.tqdm = (lambda x, disable: x)

self.uri = uri
self.collection_name = collection_name
self.shard_num = shard_num
Expand Down Expand Up @@ -166,8 +172,7 @@ def insert_data(
if key in self.fields:
insert_dict.setdefault(key, []).append(value)

# Insert the data in batches
for i in tqdm(range(0, len(data), batch_size), disable=not progress_bar):
for i in self.tqdm(range(0, len(data), batch_size), disable=not progress_bar):
# Convert dict to list of lists batch for insertion
try:
insert_batch = [
Expand Down Expand Up @@ -379,7 +384,7 @@ def get_vectors_by_pk(
self,
pks: Union[list, str, int],
timeout: int = None,
) -> None:
) -> List[List[float]]:
"""Grab the inserted vectors using the primary key from the Collection.
Due to current implementations, grabbing a large amount of vectors is slow.
Expand Down

0 comments on commit a5c79ba

Please sign in to comment.