Bulkinsert writer
Signed-off-by: yhmo <[email protected]>
yhmo committed Sep 5, 2023
1 parent c4dbd5e commit eb2a5b5
Showing 11 changed files with 958 additions and 0 deletions.
197 changes: 197 additions & 0 deletions examples/example_bulkwriter.py
@@ -0,0 +1,197 @@
# Copyright (C) 2019-2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import os
import json
import random
import threading

import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example_bulkwriter")

from pymilvus import (
    connections,
    FieldSchema, CollectionSchema, DataType,
    Collection,
    utility,
    LocalBulkWriter,
    RemoteBulkWriter,
    BulkFileType,
    bulk_import,
    get_import_progress,
    list_import_jobs,
)

# minio
MINIO_ADDRESS = "0.0.0.0:9000"
MINIO_SECRET_KEY = "minioadmin"
MINIO_ACCESS_KEY = "minioadmin"

# milvus
HOST = '127.0.0.1'
PORT = '19530'

COLLECTION_NAME = "test_abc"
DIM = 256

def create_connection():
    print(f"\nCreate connection...")
    connections.connect(host=HOST, port=PORT)
    print(f"\nConnected")


def build_collection():
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)

    field1 = FieldSchema(name="id", dtype=DataType.INT64, auto_id=True, is_primary=True)
    field2 = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=DIM)
    field3 = FieldSchema(name="desc", dtype=DataType.VARCHAR, max_length=100)
    schema = CollectionSchema(fields=[field1, field2, field3])
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    print("Collection created")
    return collection.schema

def test_local_writer_json(schema: CollectionSchema):
    local_writer = LocalBulkWriter(
        schema=schema,
        local_path="/tmp/bulk_data",
        segment_size=4*1024*1024,
        file_type=BulkFileType.JSON_RB,
    )
    for i in range(10):
        local_writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

    local_writer.commit()
    print("test local writer done!")
    print(local_writer.data_path)
    return local_writer.data_path

def test_local_writer_npy(schema: CollectionSchema):
    local_writer = LocalBulkWriter(schema=schema, local_path="/tmp/bulk_data", segment_size=4*1024*1024)
    for i in range(10000):
        local_writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

    local_writer.commit()
    print("test local writer done!")
    print(local_writer.data_path)
    return local_writer.data_path


def _append_row(writer: LocalBulkWriter, begin: int, end: int):
    for i in range(begin, end):
        writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

def test_parallel_append(schema: CollectionSchema):
    local_writer = LocalBulkWriter(
        schema=schema,
        local_path="/tmp/bulk_data",
        segment_size=1000 * 1024 * 1024,
        file_type=BulkFileType.JSON_RB,
    )
    threads = []
    thread_count = 100
    rows_per_thread = 1000
    for k in range(thread_count):
        x = threading.Thread(target=_append_row, args=(local_writer, k*rows_per_thread, (k+1)*rows_per_thread,))
        threads.append(x)
        x.start()
        print(f"Thread '{x.name}' started")

    for th in threads:
        th.join()

    local_writer.commit()
    print(f"Append finished, {thread_count*rows_per_thread} rows")
    file_path = os.path.join(local_writer.data_path, "1.json")
    with open(file_path, 'r') as file:
        data = json.load(file)

    print("Verify the output content...")
    rows = data['rows']
    assert len(rows) == thread_count*rows_per_thread
    for i in range(len(rows)):
        row = rows[i]
        assert row['desc'] == f"description_{row['id']}"


def test_remote_writer(schema: CollectionSchema):
    remote_writer = RemoteBulkWriter(
        schema=schema,
        remote_path="bulk_data",
        local_path="/tmp/bulk_data",
        connect_param=RemoteBulkWriter.ConnectParam(
            endpoint=MINIO_ADDRESS,
            access_key=MINIO_ACCESS_KEY,
            secret_key=MINIO_SECRET_KEY,
            bucket_name="a-bucket",
        ),
        segment_size=50 * 1024 * 1024,
    )

    for i in range(10000):
        if i % 1000 == 0:
            logger.info(f"{i} rows have been appended to remote writer")
        remote_writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

    remote_writer.commit()
    print("test remote writer done!")
    print(remote_writer.data_path)
    return remote_writer.data_path


def test_cloud_bulkinsert():
    url = "https://_your_cloud_server_url_"
    cluster_id = "_your_cloud_instance_id_"

    print(f"===================== import files to cloud vectordb ====================")
    object_url = "_your_object_storage_service_url_"
    object_url_access_key = "_your_object_storage_service_access_key_"
    object_url_secret_key = "_your_object_storage_service_secret_key_"
    resp = bulk_import(
        url=url,
        object_url=object_url,
        access_key=object_url_access_key,
        secret_key=object_url_secret_key,
        cluster_id=cluster_id,
        collection_name=COLLECTION_NAME,
    )
    print(resp)

    print(f"===================== get import job progress ====================")
    job_id = resp['data']['jobId']
    resp = get_import_progress(
        url=url,
        job_id=job_id,
        cluster_id=cluster_id,
    )
    print(resp)

    print(f"===================== list import jobs ====================")
    resp = list_import_jobs(
        url=url,
        cluster_id=cluster_id,
        page_size=10,
        current_page=1,
    )
    print(resp)


if __name__ == '__main__':
    create_connection()
    schema = build_collection()

    test_local_writer_json(schema)
    test_local_writer_npy(schema)
    test_remote_writer(schema)
    test_parallel_append(schema)

    # test_cloud_bulkinsert()
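
Editor's note: the example above only verifies the row-based JSON output (in test_parallel_append). Below is a minimal, hypothetical sketch of how the column-based NPY output from test_local_writer_npy could be inspected; the exact sub-directory layout under the writer's data_path is an assumption here, so the sketch simply globs for the generated .npy files.

# Hypothetical verification of the column-based output (not part of this commit).
# Assumes test_local_writer_npy() has run and left one <field>.npy file per schema
# field somewhere under /tmp/bulk_data (the writer prints its exact data_path).
from pathlib import Path

import numpy as np

data_root = Path("/tmp/bulk_data")
for npy_file in sorted(data_root.rglob("*.npy")):
    column = np.load(npy_file)  # e.g. id.npy, vector.npy, desc.npy
    print(npy_file, column.shape, column.dtype)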

22 changes: 22 additions & 0 deletions pymilvus/__init__.py
@@ -10,6 +10,22 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from .bulk_writer.bulk_import import (
    bulk_import,
    get_import_progress,
    list_import_jobs,
)

# bulk writer
from .bulk_writer.constants import (
    BulkFileType,
)
from .bulk_writer.local_bulk_writer import (
    LocalBulkWriter,
)
from .bulk_writer.remote_bulk_writer import (
    RemoteBulkWriter,
)
from .client import __version__
from .client.prepare import Prepare
from .client.stub import Milvus
@@ -124,4 +140,10 @@
"ResourceGroupInfo",
"Connections",
"IndexType",
"BulkFileType",
"LocalBulkWriter",
"RemoteBulkWriter",
"bulk_import",
"get_import_progress",
"list_import_jobs",
]
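
Editor's note: the bulk_import/get_import_progress/list_import_jobs helpers exported here target a cloud deployment (see test_cloud_bulkinsert above). For a self-hosted Milvus, the files produced by the writers would typically be handed to the pre-existing utility.do_bulk_insert API instead; the sketch below is a hypothetical follow-up that assumes the row-based JSON file has already been uploaded into the object-storage bucket Milvus reads from, and the file path is a placeholder.

# Hypothetical self-hosted import path (not introduced by this commit).
from pymilvus import connections, utility

connections.connect(host="127.0.0.1", port="19530")
task_id = utility.do_bulk_insert(
    collection_name="test_abc",
    files=["bulk_data/1.json"],  # placeholder path inside the Milvus bucket
)
print(utility.get_bulk_insert_state(task_id=task_id))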
Empty file.
148 changes: 148 additions & 0 deletions pymilvus/bulk_writer/buffer.py
@@ -0,0 +1,148 @@
# Copyright (C) 2019-2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import json
import logging
from pathlib import Path

import numpy as np

from pymilvus.exceptions import MilvusException
from pymilvus.orm.schema import CollectionSchema

from .constants import (
    DYNAMIC_FIELD_NAME,
    BulkFileType,
)

logger = logging.getLogger("bulk_buffer")
logger.setLevel(logging.DEBUG)


class Buffer:
    def __init__(
        self,
        schema: CollectionSchema,
        file_type: BulkFileType = BulkFileType.NPY,
    ):
        self._buffer = {}
        self._file_type = file_type
        for field in schema.fields:
            self._buffer[field.name] = []

        if len(self._buffer) == 0:
            self._throw("Illegal collection schema: fields list is empty")

        # dynamic field, internal name is '$meta'
        if schema.enable_dynamic_field:
            self._buffer[DYNAMIC_FIELD_NAME] = []

    @property
    def row_count(self) -> int:
        if len(self._buffer) == 0:
            return 0

        for k in self._buffer:
            return len(self._buffer[k])
        return None

    def _throw(self, msg: str):
        logger.error(msg)
        raise MilvusException(message=msg)

    def append_row(self, row: dict):
        dynamic_values = {}
        if DYNAMIC_FIELD_NAME in row and not isinstance(row[DYNAMIC_FIELD_NAME], dict):
            self._throw(f"Dynamic field '{DYNAMIC_FIELD_NAME}' value should be JSON format")

        for k in row:
            if k == DYNAMIC_FIELD_NAME:
                dynamic_values.update(row[k])
                continue

            if k not in self._buffer:
                dynamic_values[k] = row[k]
            else:
                self._buffer[k].append(row[k])

        if DYNAMIC_FIELD_NAME in self._buffer:
            self._buffer[DYNAMIC_FIELD_NAME].append(json.dumps(dynamic_values))

    def persist(self, local_path: str) -> list:
        # verify that the row counts of all fields are equal
        row_count = -1
        for k in self._buffer:
            if row_count < 0:
                row_count = len(self._buffer[k])
            elif row_count != len(self._buffer[k]):
                self._throw(
                    "Column `{}` row count {} doesn't equal the first column row count {}".format(
                        k, len(self._buffer[k]), row_count
                    )
                )

        # output files
        if self._file_type == BulkFileType.NPY:
            return self._persist_npy(local_path)
        if self._file_type == BulkFileType.JSON_RB:
            return self._persist_json_rows(local_path)

        self._throw(f"Unsupported file type: {self._file_type}")
        return []

    def _persist_npy(self, local_path: str):
        Path(local_path).mkdir(exist_ok=True)

        file_list = []
        for k in self._buffer:
            full_file_name = Path(local_path).joinpath(k + ".npy")
            file_list.append(full_file_name)
            try:
                np.save(full_file_name, self._buffer[k])
            except Exception as e:
                self._throw(f"Failed to persist column-based file {full_file_name}, error: {e}")

            logger.info(f"Successfully persisted column-based file {full_file_name}")

        if len(file_list) != len(self._buffer):
            logger.error("Some fields were not persisted successfully, aborting the files")
            for f in file_list:
                Path(f).unlink()
            Path(local_path).rmdir()
            file_list.clear()
            self._throw("Some fields were not persisted successfully, aborting the files")

        return file_list

    def _persist_json_rows(self, local_path: str):
        rows = []
        row_count = len(next(iter(self._buffer.values())))
        row_index = 0
        while row_index < row_count:
            row = {}
            for k, v in self._buffer.items():
                row[k] = v[row_index]
            rows.append(row)
            row_index = row_index + 1

        data = {
            "rows": rows,
        }
        file_path = Path(local_path + ".json")
        try:
            with file_path.open("w") as json_file:
                json.dump(data, json_file, indent=2)
        except Exception as e:
            self._throw(f"Failed to persist row-based file {file_path}, error: {e}")

        logger.info(f"Successfully persisted row-based file {file_path}")
        return [file_path]
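
Editor's note: a short, hypothetical illustration of how this Buffer behaves (not part of the commit). Keys that match schema fields are appended to their own columns, any unknown key is routed into the '$meta' dynamic field when enable_dynamic_field is set, and persist() writes either one .npy file per column or a single row-based .json file depending on file_type. The schema, field names, and paths below are assumptions for illustration only.

# Hypothetical direct use of the internal Buffer class (illustration only).
from pathlib import Path

from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.bulk_writer.buffer import Buffer
from pymilvus.bulk_writer.constants import BulkFileType

schema = CollectionSchema(
    fields=[
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=4),
    ],
    enable_dynamic_field=True,
)
buf = Buffer(schema=schema, file_type=BulkFileType.JSON_RB)

# "color" is not a schema field, so append_row routes it into "$meta".
buf.append_row({"id": 1, "vector": [0.1, 0.2, 0.3, 0.4], "color": "red"})
print(buf.row_count)  # 1

Path("/tmp/bulk_data").mkdir(parents=True, exist_ok=True)
print(buf.persist("/tmp/bulk_data/1"))  # JSON_RB writes /tmp/bulk_data/1.json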