diff --git a/examples/example_bulkwriter.py b/examples/example_bulkwriter.py new file mode 100644 index 000000000..3fb96eda8 --- /dev/null +++ b/examples/example_bulkwriter.py @@ -0,0 +1,197 @@ +# Copyright (C) 2019-2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +import os +import json +import random +import threading + +import logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("example_bulkwriter") + +from pymilvus import ( + connections, + FieldSchema, CollectionSchema, DataType, + Collection, + utility, + LocalBulkWriter, + RemoteBulkWriter, + BulkFileType, + bulk_import, + get_import_progress, + list_import_jobs, +) + +# minio +MINIO_ADDRESS = "0.0.0.0:9000" +MINIO_SECRET_KEY = "minioadmin" +MINIO_ACCESS_KEY = "minioadmin" + +# milvus +HOST = '127.0.0.1' +PORT = '19530' + +COLLECTION_NAME = "test_abc" +DIM = 256 + +def create_connection(): + print(f"\nCreate connection...") + connections.connect(host=HOST, port=PORT) + print(f"\nConnected") + + +def build_collection(): + if utility.has_collection(COLLECTION_NAME): + utility.drop_collection(COLLECTION_NAME) + + field1 = FieldSchema(name="id", dtype=DataType.INT64, auto_id=True, is_primary=True) + field2 = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=DIM) + field3 = FieldSchema(name="desc", dtype=DataType.VARCHAR, max_length=100) + schema = CollectionSchema(fields=[field1, field2, field3]) + collection = Collection(name=COLLECTION_NAME, schema=schema) + print("Collection created") + return collection.schema + +def test_local_writer_json(schema: CollectionSchema): + local_writer = LocalBulkWriter(schema=schema, + local_path="/tmp/bulk_data", + segment_size=4*1024*1024, + file_type=BulkFileType.JSON_RB, + ) + for i in range(10): + local_writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"}) + + local_writer.commit() + print("test local writer done!") + print(local_writer.data_path) + return local_writer.data_path + +def test_local_writer_npy(schema: CollectionSchema): + local_writer = LocalBulkWriter(schema=schema, local_path="/tmp/bulk_data", segment_size=4*1024*1024) + for i in range(10000): + local_writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"}) + + local_writer.commit() + print("test local writer done!") + print(local_writer.data_path) + return local_writer.data_path + + +def _append_row(writer: LocalBulkWriter, begin: int, end: int): + for i in range(begin, end): + writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"}) + +def test_parallel_append(schema: CollectionSchema): + local_writer = LocalBulkWriter(schema=schema, + local_path="/tmp/bulk_data", + segment_size=1000 * 1024 * 1024, + file_type=BulkFileType.JSON_RB, + ) + threads = [] + thread_count = 100 + rows_per_thread = 1000 + for k in range(thread_count): + x = threading.Thread(target=_append_row, args=(local_writer, k*rows_per_thread, (k+1)*rows_per_thread,)) + threads.append(x) + x.start() + print(f"Thread '{x.name}' started") + + for th in threads: + th.join() + + local_writer.commit() + print(f"Append finished, {thread_count*rows_per_thread} rows") + file_path = os.path.join(local_writer.data_path, "1.json") + with open(file_path, 'r') as file: + data = json.load(file) + + print("Verify the output content...") + rows = data['rows'] + assert len(rows) == thread_count*rows_per_thread + for i in range(len(rows)): + row = rows[i] + assert row['desc'] == f"description_{row['id']}" + + +def test_remote_writer(schema: CollectionSchema): + remote_writer = RemoteBulkWriter(schema=schema, + remote_path="bulk_data", + local_path="/tmp/bulk_data", + connect_param=RemoteBulkWriter.ConnectParam( + endpoint=MINIO_ADDRESS, + access_key=MINIO_ACCESS_KEY, + secret_key=MINIO_SECRET_KEY, + bucket_name="a-bucket", + ), + segment_size=50 * 1024 * 1024, + ) + + for i in range(10000): + if i % 1000 == 0: + logger.info(f"{i} rows has been append to remote writer") + remote_writer.append({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"}) + + remote_writer.commit() + print("test remote writer done!") + print(remote_writer.data_path) + return remote_writer.data_path + + +def test_cloud_bulkinsert(): + url = "https://_your_cloud_server_url_" + cluster_id = "_your_cloud_instance_id_" + + print(f"===================== import files to cloud vectordb ====================") + object_url = "_your_object_storage_service_url_" + object_url_access_key = "_your_object_storage_service_access_key_" + object_url_secret_key = "_your_object_storage_service_secret_key_" + resp = bulk_import( + url=url, + object_url=object_url, + access_key=object_url_access_key, + secret_key=object_url_secret_key, + cluster_id=cluster_id, + collection_name=COLLECTION_NAME, + ) + print(resp) + + print(f"===================== get import job progress ====================") + job_id = resp['data']['jobId'] + resp = get_import_progress( + url=url, + job_id=job_id, + cluster_id=cluster_id, + ) + print(resp) + + print(f"===================== list import jobs ====================") + resp = list_import_jobs( + url=url, + cluster_id=cluster_id, + page_size=10, + current_page=1, + ) + print(resp) + + +if __name__ == '__main__': + create_connection() + schema = build_collection() + + test_local_writer_json(schema) + test_local_writer_npy(schema) + test_remote_writer(schema) + test_parallel_append(schema) + + # test_cloud_bulkinsert() + diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py index cc3a19597..2c160fbd8 100644 --- a/pymilvus/__init__.py +++ b/pymilvus/__init__.py @@ -10,6 +10,22 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. +from .bulk_writer.bulk_import import ( + bulk_import, + get_import_progress, + list_import_jobs, +) + +# bulk writer +from .bulk_writer.constants import ( + BulkFileType, +) +from .bulk_writer.local_bulk_writer import ( + LocalBulkWriter, +) +from .bulk_writer.remote_bulk_writer import ( + RemoteBulkWriter, +) from .client import __version__ from .client.prepare import Prepare from .client.stub import Milvus @@ -124,4 +140,10 @@ "ResourceGroupInfo", "Connections", "IndexType", + "BulkFileType", + "LocalBulkWriter", + "RemoteBulkWriter", + "bulk_import", + "get_import_progress", + "list_import_jobs", ] diff --git a/pymilvus/bulk_writer/__init__.py b/pymilvus/bulk_writer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymilvus/bulk_writer/buffer.py b/pymilvus/bulk_writer/buffer.py new file mode 100644 index 000000000..5db288fec --- /dev/null +++ b/pymilvus/bulk_writer/buffer.py @@ -0,0 +1,148 @@ +# Copyright (C) 2019-2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +import json +import logging +from pathlib import Path + +import numpy as np + +from pymilvus.exceptions import MilvusException +from pymilvus.orm.schema import CollectionSchema + +from .constants import ( + DYNAMIC_FIELD_NAME, + BulkFileType, +) + +logger = logging.getLogger("bulk_buffer") +logger.setLevel(logging.DEBUG) + + +class Buffer: + def __init__( + self, + schema: CollectionSchema, + file_type: BulkFileType = BulkFileType.NPY, + ): + self._buffer = {} + self._file_type = file_type + for field in schema.fields: + self._buffer[field.name] = [] + + if len(self._buffer) == 0: + self._throw("Illegal collection schema: fields list is empty") + + # dynamic field, internal name is '$meta' + if schema.enable_dynamic_field: + self._buffer[DYNAMIC_FIELD_NAME] = [] + + @property + def row_count(self) -> int: + if len(self._buffer) == 0: + return 0 + + for k in self._buffer: + return len(self._buffer[k]) + return None + + def _throw(self, msg: str): + logger.error(msg) + raise MilvusException(message=msg) + + def append_row(self, row: dict): + dynamic_values = {} + if DYNAMIC_FIELD_NAME in row and not isinstance(row[DYNAMIC_FIELD_NAME], dict): + self._throw(f"Dynamic field '{DYNAMIC_FIELD_NAME}' value should be JSON format") + + for k in row: + if k == DYNAMIC_FIELD_NAME: + dynamic_values.update(row[k]) + continue + + if k not in self._buffer: + dynamic_values[k] = row[k] + else: + self._buffer[k].append(row[k]) + + if DYNAMIC_FIELD_NAME in self._buffer: + self._buffer[DYNAMIC_FIELD_NAME].append(json.dumps(dynamic_values)) + + def persist(self, local_path: str) -> list: + # verify row count of fields are equal + row_count = -1 + for k in self._buffer: + if row_count < 0: + row_count = len(self._buffer[k]) + elif row_count != len(self._buffer[k]): + self._throw( + "Column `{}` row count {} doesn't equal to the first column row count {}".format( + k, len(self._buffer[k]), row_count + ) + ) + + # output files + if self._file_type == BulkFileType.NPY: + return self._persist_npy(local_path) + if self._file_type == BulkFileType.JSON_RB: + return self._persist_json_rows(local_path) + + self._throw(f"Unsupported file tpye: {self._file_type}") + return [] + + def _persist_npy(self, local_path: str): + Path(local_path).mkdir(exist_ok=True) + + file_list = [] + for k in self._buffer: + full_file_name = Path(local_path).joinpath(k + ".npy") + file_list.append(full_file_name) + try: + np.save(full_file_name, self._buffer[k]) + except Exception as e: + self._throw(f"Failed to persist column-based file {full_file_name}, error: {e}") + + logger.info(f"Successfully persist column-based file {full_file_name}") + + if len(file_list) != len(self._buffer): + logger.error("Some of fields were not persisted successfully, abort the files") + for f in file_list: + Path(f).unlink() + Path(local_path).rmdir() + file_list.clear() + self._throw("Some of fields were not persisted successfully, abort the files") + + return file_list + + def _persist_json_rows(self, local_path: str): + rows = [] + row_count = len(next(iter(self._buffer.values()))) + row_index = 0 + while row_index < row_count: + row = {} + for k, v in self._buffer.items(): + row[k] = v[row_index] + rows.append(row) + row_index = row_index + 1 + + data = { + "rows": rows, + } + file_path = Path(local_path + ".json") + try: + with file_path.open("w") as json_file: + json.dump(data, json_file, indent=2) + except Exception as e: + self._throw(f"Failed to persist row-based file {file_path}, error: {e}") + + logger.info(f"Successfully persist row-based file {file_path}") + return [file_path] diff --git a/pymilvus/bulk_writer/bulk_import.py b/pymilvus/bulk_writer/bulk_import.py new file mode 100644 index 000000000..736fbf76c --- /dev/null +++ b/pymilvus/bulk_writer/bulk_import.py @@ -0,0 +1,156 @@ +# Copyright (C) 2019-2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +import json +import logging + +import requests + +from pymilvus.exceptions import MilvusException + +logger = logging.getLogger("bulk_import") +logger.setLevel(logging.DEBUG) + + +def _http_headers(api_key: str): + return { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_0) AppleWebKit/535.11 (KHTML, like Gecko) " + "Chrome/17.0.963.56 Safari/535.11", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", + "Accept-Encodin": "gzip,deflate,sdch", + "Accept-Languag": "en-US,en;q=0.5", + "Authorization": f"Bearer {api_key}", + } + + +def _throw(msg: str): + logger.error(msg) + raise MilvusException(message=msg) + + +def _handle_response(url: str, res: json): + inner_code = res["code"] + if inner_code != 200: + inner_message = res["message"] + _throw(f"Failed to request url: {url}, code: {inner_code}, message: {inner_message}") + + +def _post_request(url: str, api_key: str, params: {}, timeout: int = 20, **kwargs): + try: + resp = requests.post( + url=url, headers=_http_headers(api_key), json=params, timeout=timeout, **kwargs + ) + if resp.status_code != 200: + _throw(f"Failed to post url: {url}, status code: {resp.status_code}") + else: + return resp + except Exception as err: + _throw(f"Failed to post url: {url}, error: {err}") + + +def _get_request(url: str, api_key: str, params: {}, timeout: int = 20, **kwargs): + try: + resp = requests.get( + url=url, headers=_http_headers(api_key), params=params, timeout=timeout, **kwargs + ) + if resp.status_code != 200: + _throw(f"Failed to get url: {url}, status code: {resp.status_code}") + else: + return resp + except Exception as err: + _throw(f"Failed to get url: {url}, error: {err}") + + +## bulkinsert RESTful api wrapper +def bulk_import( + url: str, + api_key: str, + object_url: str, + access_key: str, + secret_key: str, + cluster_id: str, + collection_name: str, + **kwargs, +): + """call bulkinsert restful interface to import files + + Args: + url (str): url of the server + object_url (str): data files url + access_key (str): access key to access the object storage + secret_key (str): secret key to access the object storage + cluster_id (str): id of a milvus instance(for cloud) + collection_name (str): name of the target collection + + Returns: + json: response of the restful interface + """ + request_url = f"https://{url}/v1/vector/collections/import" + params = { + "objectUrl": object_url, + "accessKey": access_key, + "secretKey": secret_key, + "clusterId": cluster_id, + "collectionName": collection_name, + } + + resp = _post_request(url=request_url, api_key=api_key, params=params, **kwargs) + _handle_response(url, resp.json()) + return resp + + +def get_import_progress(url: str, api_key: str, job_id: str, cluster_id: str, **kwargs): + """get job progress + + Args: + url (str): url of the server + job_id (str): a job id + cluster_id (str): id of a milvus instance(for cloud) + + Returns: + json: response of the restful interface + """ + request_url = f"https://{url}/v1/vector/collections/import/get" + params = { + "jobId": job_id, + "clusterId": cluster_id, + } + + resp = _get_request(url=request_url, api_key=api_key, params=params, **kwargs) + _handle_response(url, resp.json()) + return resp + + +def list_import_jobs( + url: str, api_key: str, cluster_id: str, page_size: int, current_page: int, **kwargs +): + """list jobs in a cluster + + Args: + url (str): url of the server + cluster_id (str): id of a milvus instance(for cloud) + page_size (int): pagination size + current_page (int): pagination + + Returns: + json: response of the restful interface + """ + request_url = f"https://{url}/v1/vector/collections/import/list" + params = { + "clusterId": cluster_id, + "pageSize": page_size, + "currentPage": current_page, + } + + resp = _get_request(url=request_url, api_key=api_key, params=params, **kwargs) + _handle_response(url, resp.json()) + return resp diff --git a/pymilvus/bulk_writer/bulk_writer.py b/pymilvus/bulk_writer/bulk_writer.py new file mode 100644 index 000000000..fd888c835 --- /dev/null +++ b/pymilvus/bulk_writer/bulk_writer.py @@ -0,0 +1,129 @@ +# Copyright (C) 2019-2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +import logging +from threading import Lock + +from pymilvus.client.types import DataType +from pymilvus.exceptions import MilvusException +from pymilvus.orm.schema import CollectionSchema + +from .buffer import ( + Buffer, +) +from .constants import ( + TYPE_SIZE, + TYPE_VALIDATOR, + BulkFileType, +) + +logger = logging.getLogger("bulk_writer") +logger.setLevel(logging.DEBUG) + + +class BulkWriter: + def __init__( + self, + schema: CollectionSchema, + segment_size: int, + file_type: BulkFileType = BulkFileType.NPY, + ): + self._schema = schema + self._buffer_size = 0 + self._buffer_row_count = 0 + self._segment_size = segment_size + self._file_type = file_type + self._buffer_lock = Lock() + + if len(self._schema.fields) == 0: + self._throw("collection schema fields list is empty") + + if self._schema.primary_field is None: + self._throw("primary field is null") + + self._buffer = None + self._new_buffer() + + @property + def buffer_size(self): + return self._buffer_size + + @property + def buffer_row_count(self): + return self._buffer_row_count + + @property + def segment_size(self): + return self._segment_size + + def _new_buffer(self): + old_buffer = self._buffer + with self._buffer_lock: + self._buffer = Buffer(self._schema, self._file_type) + return old_buffer + + def append(self, row: dict, **kwargs): + self._verify_row(row) + with self._buffer_lock: + self._buffer.append_row(row) + + def commit(self, **kwargs): + self._buffer_size = 0 + self._buffer_row_count = 0 + + @property + def data_path(self): + return "" + + def _throw(self, msg: str): + logger.error(msg) + raise MilvusException(message=msg) + + def _verify_row(self, row: dict): + if not isinstance(row, dict): + self._throw("The input row must be a dict object") + + row_size = 0 + for field in self._schema.fields: + if field.name not in row: + self._throw(f"The field '{field.name}' is missed in the row") + + if field.is_parimary and field.auto_id: + self._throw(f"The primary key field '{field.name}' is auto-id, no need to provide") + + dtype = DataType(field.dtype) + validator = TYPE_VALIDATOR[dtype.name] + if dtype in {DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR}: + dim = field.params["dim"] + if not validator(row[field.name], dim): + self._throw(f"Illegal vector data for vector field '{dtype.name}'") + + vec_size = ( + len(row[field.name]) * 4 + if dtype == DataType.FLOAT_VECTOR + else len(row[field.name]) / 8 + ) + row_size = row_size + vec_size + elif dtype == DataType.VARCHAR: + max_len = field.params["max_length"] + if not validator(row[field.name], max_len): + self._throw(f"Illegal varchar value for field '{dtype.name}'") + + row_size = row_size + len(row[field.name]) + else: + if not validator(row[field.name]): + self._throw(f"Illegal scalar value for field '{dtype.name}'") + + row_size = row_size + TYPE_SIZE[dtype.name] + + self._buffer_size = self._buffer_size + row_size + self._buffer_row_count = self._buffer_row_count + 1 diff --git a/pymilvus/bulk_writer/constants.py b/pymilvus/bulk_writer/constants.py new file mode 100644 index 000000000..beda9da69 --- /dev/null +++ b/pymilvus/bulk_writer/constants.py @@ -0,0 +1,52 @@ +# Copyright (C) 2019-2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +from enum import IntEnum + +from pymilvus.client.types import ( + DataType, +) + +MB = 1024 * 1024 +GB = 1024 * MB + +DYNAMIC_FIELD_NAME = "$meta" +DEFAULT_BUCKET_NAME = "a-bucket" + +TYPE_SIZE = { + DataType.BOOL.name: 1, + DataType.INT8.name: 8, + DataType.INT16.name: 8, + DataType.INT32.name: 8, + DataType.INT64.name: 8, + DataType.FLOAT.name: 8, + DataType.DOUBLE.name: 8, +} + +TYPE_VALIDATOR = { + DataType.BOOL.name: lambda x: isinstance(x, bool), + DataType.INT8.name: lambda x: isinstance(x, int) and -128 <= x <= 127, + DataType.INT16.name: lambda x: isinstance(x, int) and -32768 <= x <= 32767, + DataType.INT32.name: lambda x: isinstance(x, int) and -2147483648 <= x <= 2147483647, + DataType.INT64.name: lambda x: isinstance(x, int), + DataType.FLOAT.name: lambda x: isinstance(x, float), + DataType.DOUBLE.name: lambda x: isinstance(x, float), + DataType.VARCHAR.name: lambda x, max_len: isinstance(x, str) and len(x) <= max_len, + DataType.JSON.name: lambda x: isinstance(x, dict), + DataType.FLOAT_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) == dim, + DataType.BINARY_VECTOR.name: lambda x, dim: isinstance(x, bytes) and len(x) * 8 == dim, +} + + +class BulkFileType(IntEnum): + NPY = 1 + JSON_RB = 2 diff --git a/pymilvus/bulk_writer/local_bulk_writer.py b/pymilvus/bulk_writer/local_bulk_writer.py new file mode 100644 index 000000000..cdcb993da --- /dev/null +++ b/pymilvus/bulk_writer/local_bulk_writer.py @@ -0,0 +1,95 @@ +# Copyright (C) 2019-2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +import logging +import threading +import time +import uuid +from pathlib import Path +from threading import Thread +from typing import Callable, Optional + +from pymilvus.orm.schema import CollectionSchema + +from .bulk_writer import BulkWriter +from .constants import ( + MB, + BulkFileType, +) + +logger = logging.getLogger("local_bulk_writer") +logger.setLevel(logging.DEBUG) + + +class LocalBulkWriter(BulkWriter): + def __init__( + self, + schema: CollectionSchema, + local_path: str, + segment_size: int = 512 * MB, + file_type: BulkFileType = BulkFileType.NPY, + ): + super().__init__(schema, segment_size, file_type) + self._local_path = local_path + self._make_dir() + self._flush_count = 0 + self._working_thread = {} + + def __del__(self): + if len(self._working_thread) > 0: + for k, th in self._working_thread.items(): + logger.info(f"Wait thread '{k}' to finish") + th.join() + + def _make_dir(self): + Path(self._local_path).mkdir(exist_ok=True) + uidir = Path(self._local_path).joinpath(str(uuid.uuid4())) + Path(uidir).mkdir(exist_ok=True) + self._local_path = uidir + logger.info(f"Local buffer writer initialized, target path: {uidir}") + + def append(self, row: dict, **kwargs): + super().append(row, **kwargs) + + if super().buffer_size > super().segment_size: + self.commit(_async=True) + + def commit(self, **kwargs): + while len(self._working_thread) > 0: + logger.info("Previous flush action is not finished, waiting...") + time.sleep(0.5) + + logger.info( + f"Prepare to flush buffer, row_count: {super().buffer_row_count}, size: {super().buffer_size}" + ) + _async = kwargs.get("_async", False) + call_back = kwargs.get("call_back", None) + x = Thread(target=self._flush, args=(call_back,)) + x.start() + if not _async: + x.join() + super().commit() # reset the buffer size + + def _flush(self, call_back: Optional[Callable] = None): + self._working_thread[threading.current_thread().name] = threading.current_thread() + self._flush_count = self._flush_count + 1 + target_path = Path.joinpath(self._local_path, str(self._flush_count)) + + old_buffer = super()._new_buffer() + file_list = old_buffer.persist(str(target_path)) + if call_back: + call_back(file_list) + del self._working_thread[threading.current_thread().name] + + @property + def data_path(self): + return self._local_path diff --git a/pymilvus/bulk_writer/remote_bulk_writer.py b/pymilvus/bulk_writer/remote_bulk_writer.py new file mode 100644 index 000000000..e460f519d --- /dev/null +++ b/pymilvus/bulk_writer/remote_bulk_writer.py @@ -0,0 +1,155 @@ +# Copyright (C) 2019-2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +import logging +from pathlib import Path +from typing import Any, Optional + +from minio import Minio +from minio.error import S3Error + +from pymilvus.orm.schema import CollectionSchema + +from .constants import ( + DEFAULT_BUCKET_NAME, + MB, +) +from .local_bulk_writer import LocalBulkWriter + +logger = logging.getLogger("remote_bulk_writer") +logger.setLevel(logging.DEBUG) + + +class RemoteBulkWriter(LocalBulkWriter): + class ConnectParam: + def __init__( + self, + bucket_name: str = DEFAULT_BUCKET_NAME, + endpoint: Optional[str] = None, + access_key: Optional[str] = None, + secret_key: Optional[str] = None, + secure: bool = False, + session_token: Optional[str] = None, + region: Optional[str] = None, + http_client: Any = None, + credentials: Any = None, + ): + self._bucket_name = bucket_name + self._endpoint = endpoint + self._access_key = access_key + self._secret_key = secret_key + self._secure = (secure,) + self._session_token = (session_token,) + self._region = (region,) + self._http_client = (http_client,) # urllib3.poolmanager.PoolManager + self._credentials = (credentials,) # minio.credentials.Provider + + def __init__( + self, + schema: CollectionSchema, + remote_path: str, + local_path: str, + connect_param: ConnectParam, + segment_size: int = 512 * MB, + ): + super().__init__(schema, local_path, segment_size) + uid = Path(super().data_path).name + self._remote_path = Path("/").joinpath(remote_path).joinpath(uid) + self._connect_param = connect_param + self._client = None + self._get_client() + logger.info(f"Remote buffer writer initialized, target path: {self._remote_path}") + + def _get_client(self): + try: + if self._client is None: + + def arg_parse(arg: Any): + return arg[0] if isinstance(arg, tuple) else arg + + self._client = Minio( + endpoint=arg_parse(self._connect_param._endpoint), + access_key=arg_parse(self._connect_param._access_key), + secret_key=arg_parse(self._connect_param._secret_key), + secure=arg_parse(self._connect_param._secure), + session_token=arg_parse(self._connect_param._session_token), + region=arg_parse(self._connect_param._region), + http_client=arg_parse(self._connect_param._http_client), + credentials=arg_parse(self._connect_param._credentials), + ) + else: + return self._client + except Exception as err: + logger.error(f"Failed to connect MinIO/S3, error: {err}") + raise + + def append(self, row: dict, **kwargs): + super().append(row, **kwargs) + + def commit(self, **kwargs): + super().commit(call_back=self._upload) + + def _remote_exists(self, file: str) -> bool: + try: + minio_client = self._get_client() + minio_client.stat_object(bucket_name=self._connect_param._bucket_name, object_name=file) + except S3Error as err: + if err.code == "NoSuchKey": + return False + self._throw(f"Failed to stat MinIO/S3 object, error: {err}") + return True + + def _local_rm(self, file: str): + try: + Path.unlink(file) + except Exception: + logger.warning(f"Failed to delete local file: {file}") + + def _upload(self, file_list: list): + remote_files = [] + try: + logger.info("Prepare to upload files") + minio_client = self._get_client() + found = minio_client.bucket_exists(self._connect_param._bucket_name) + if not found: + self._throw(f"MinIO bucket '{self._connect_param._bucket_name}' doesn't exist") + + for file_path in file_list: + ext = Path(file_path).suffix + if ext not in {".json", ".npy"}: + continue + + relative_file_path = str(file_path).replace(str(super().data_path), "") + minio_file_path = str( + Path.joinpath(self._remote_path, relative_file_path.lstrip("/")) + ) + + if not self._remote_exists(minio_file_path): + minio_client.fput_object( + bucket_name=self._connect_param._bucket_name, + object_name=minio_file_path, + file_path=file_path, + ) + logger.info(f"Upload file '{file_path}' to '{minio_file_path}'") + else: + logger.info(f"Remote file '{minio_file_path}' already exists") + remote_files.append(minio_file_path) + self._local_rm(file_path) + except Exception as e: + self._throw(f"Failed to call MinIO/S3 api, error: {e}") + + logger.info(f"Successfully upload files: {file_list}") + return remote_files + + @property + def data_path(self): + return self._remote_path diff --git a/pyproject.toml b/pyproject.toml index 069b7aaf9..26cd51a21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,8 @@ dependencies=[ "ujson>=2.0.0", "pandas>=1.2.4", "numpy<1.25.0;python_version<='3.8'", + "requests", + "minio", ] classifiers=[ diff --git a/requirements.txt b/requirements.txt index 172ca465e..38574fbc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,5 @@ pytest-timeout==1.3.4 pandas>=1.1.5 ruff black +requests +minio