Skip to content

Commit

Permalink
feat: progress tracking in write_dataset and write_fragments
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Oct 26, 2023
1 parent edf300e commit 031da35
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 115 deletions.
8 changes: 7 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import pyarrow as pa
import pyarrow.dataset
from lance.optimize import Compaction
from lance.progress import FragmentWriteProgress
from pyarrow import RecordBatch, Schema
from pyarrow._compute import Expression

Expand Down Expand Up @@ -1608,6 +1609,7 @@ def write_dataset(
max_rows_per_group: int = 1024,
max_bytes_per_file: int = 90 * 1024 * 1024 * 1024,
commit_lock: Optional[CommitLock] = None,
progress: Optional[FragmentWriteProgress] = None,
) -> LanceDataset:
"""Write a given data_obj to the given uri
Expand Down Expand Up @@ -1639,7 +1641,10 @@ def write_dataset(
commit_lock : CommitLock, optional
A custom commit lock. Only needed if your object store does not support
atomic commits. See the user guide for more details.
progress: FragmentWriteProgress, optional
*Experimental API*. Progress tracking for writing the fragment. Pass
a custom class that defines hooks to be called when each fragment is
starting to write and finishing writing.
"""
reader = _coerce_reader(data_obj, schema)
_validate_schema(reader.schema)
Expand All @@ -1650,6 +1655,7 @@ def write_dataset(
"max_rows_per_file": max_rows_per_file,
"max_rows_per_group": max_rows_per_group,
"max_bytes_per_file": max_bytes_per_file,
"progress": progress,
}

if commit_lock:
Expand Down
12 changes: 7 additions & 5 deletions python/python/lance/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def create(
max_rows_per_group: int, default 1024
The maximum number of rows per group in the data file.
progress: FragmentWriteProgress, optional
*Experimental API*. Progress tracking for writing the fragment.
*Experimental API*. Progress tracking for writing the fragment. Pass
a custom class that defines hooks to be called when each fragment is
starting to write and finishing writing.
See Also
--------
Expand Down Expand Up @@ -425,7 +427,9 @@ def write_fragments(
defaults to 90 GB, since we have a hard limit of 100 GB per file on
object stores.
progress : FragmentWriteProgress, optional
*Experimental API*. Progress tracking for writing the fragment.
*Experimental API*. Progress tracking for writing the fragment. Pass
a custom class that defines hooks to be called when each fragment is
starting to write and finishing writing.
Returns
-------
Expand All @@ -447,15 +451,13 @@ def write_fragments(

if isinstance(dataset_uri, Path):
dataset_uri = str(dataset_uri)
if progress is None:
progress = NoopFragmentWriteProgress()

fragments = _write_fragments(
dataset_uri,
reader,
progress,
max_rows_per_file=max_rows_per_file,
max_rows_per_group=max_rows_per_group,
max_bytes_per_file=max_bytes_per_file,
progress=progress,
)
return [FragmentMetadata(frag.json()) for frag in fragments]
19 changes: 19 additions & 0 deletions python/python/tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import pyarrow as pa
import pytest
import semver
from lance.fragment import FragmentMetadata
from lance.progress import FragmentWriteProgress

PYARROW_VERSION = semver.VersionInfo.parse(pa.__version__)

requires_pyarrow_12 = pytest.mark.skipif(
PYARROW_VERSION.major < 12, reason="requires arrow 12+"
)


class ProgressForTest(FragmentWriteProgress):
    """Counting progress tracker used by tests.

    Records how many times the ``begin`` and ``complete`` hooks fire so
    tests can assert one invocation per fragment written.
    """

    def __init__(self):
        super().__init__()
        # Invocation counters; tests compare these against the expected
        # number of fragments.
        self.begin_called = 0
        self.complete_called = 0

    def begin(
        self, fragment: FragmentMetadata, multipart_id: Optional[str] = None, **kwargs
    ):
        """Count a fragment-write start notification."""
        self.begin_called = self.begin_called + 1

    def complete(self, fragment: FragmentMetadata):
        """Count a fragment-write completion notification."""
        self.complete_called = self.complete_called + 1
10 changes: 10 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import pytest
from helper import ProgressForTest
from lance.commit import CommitConflictError

# Various valid inputs for write_dataset
Expand Down Expand Up @@ -882,3 +883,12 @@ def test_scan_with_row_ids(tmp_path: Path):

tbl2 = ds._take_rows(row_ids)
assert tbl2["a"] == tbl["a"]


def test_dataset_progress(tmp_path: Path):
    """write_dataset should fire the progress hooks once per fragment."""
    table = pa.table({"a": range(10)})
    tracker = ProgressForTest()
    # 10 rows with max_rows_per_file=5 forces exactly two fragments.
    dataset = lance.write_dataset(
        table, tmp_path, max_rows_per_file=5, progress=tracker
    )
    assert len(dataset.get_fragments()) == 2
    assert tracker.begin_called == 2
    assert tracker.complete_called == 2
29 changes: 11 additions & 18 deletions python/python/tests/test_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import json
import multiprocessing
from pathlib import Path
from typing import Optional

import pandas as pd
import pyarrow as pa
import pytest
from helper import ProgressForTest
from lance import FragmentMetadata, LanceDataset, LanceFragment, LanceOperation
from lance.fragment import write_fragments
from lance.progress import FileSystemFragmentWriteProgress, FragmentWriteProgress
from lance.progress import FileSystemFragmentWriteProgress


def test_write_fragment(tmp_path: Path):
Expand Down Expand Up @@ -68,26 +68,19 @@ def test_write_fragments(tmp_path: Path):
"a": pa.array(range(1024)),
}
)
progress = ProgressForTest()
fragments = write_fragments(
tab, tmp_path, max_rows_per_group=512, max_bytes_per_file=1024
tab,
tmp_path,
max_rows_per_group=512,
max_bytes_per_file=1024,
progress=progress,
)
assert len(fragments) == 2
assert all(isinstance(f, FragmentMetadata) for f in fragments)


class ProgressForTest(FragmentWriteProgress):
def __init__(self):
super().__init__()
self.begin_called = 0
self.complete_called = 0

def begin(
self, fragment: FragmentMetadata, multipart_id: Optional[str] = None, **kwargs
):
self.begin_called += 1

def complete(self, fragment: FragmentMetadata):
self.complete_called += 1
# progress hook was called for each fragment
assert progress.begin_called == 2
assert progress.complete_called == 2


def test_write_fragment_with_progress(tmp_path: Path):
Expand Down
58 changes: 58 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ use arrow::pyarrow::{ToPyArrow, *};
use arrow_array::{Float32Array, RecordBatch};
use arrow_data::ArrayData;
use arrow_schema::Schema as ArrowSchema;
use async_trait::async_trait;
use chrono::Duration;

use lance::arrow::as_fixed_size_list_array;
use lance::dataset::progress::WriteFragmentProgress;
use lance::dataset::{
fragment::FileFragment as LanceFileFragment, scanner::Scanner as LanceScanner,
transaction::Operation as LanceOperation, Dataset as LanceDataset, ReadParams, Version,
Expand All @@ -44,6 +47,7 @@ use pyo3::{
types::{IntoPyDict, PyBool, PyDict, PyFloat, PyInt, PyLong},
PyObject, PyResult,
};
use snafu::{location, Location};

use crate::fragment::{FileFragment, FragmentMetadata};
use crate::Scanner;
Expand Down Expand Up @@ -761,9 +765,63 @@ pub fn get_write_params(options: &PyDict) -> PyResult<Option<WriteParams>> {
if let Some(maybe_nbytes) = options.get_item("max_bytes_per_file") {
p.max_bytes_per_file = usize::extract(maybe_nbytes)?;
}
if let Some(progress) = options.get_item("progress") {
if !progress.is_none() {
p.progress = Arc::new(PyWriteProgress::new(progress.to_object(options.py())));
}
}

p.store_params = get_object_store_params(options);

Some(p)
};
Ok(params)
}

/// Bridge that lets a Python-side progress object receive fragment-write
/// notifications from the Rust write path (see the `WriteFragmentProgress`
/// impl below, which calls its `_do_begin` / `_do_complete` methods).
#[pyclass(name = "_FragmentWriteProgress", module = "_lib")]
#[derive(Debug)]
pub struct PyWriteProgress {
    /// The wrapped Python progress object; expected to expose `_do_begin`
    /// and `_do_complete` methods (the Python `FragmentWriteProgress` API).
    py_obj: PyObject,
}

impl PyWriteProgress {
    /// Wrap a Python progress object so it can be stored in `WriteParams`.
    fn new(obj: PyObject) -> Self {
        Self { py_obj: obj }
    }
}

#[async_trait]
impl WriteFragmentProgress for PyWriteProgress {
    /// Notify the Python object that writing of `fragment` has started,
    /// forwarding the fragment as JSON plus the `multipart_id` keyword.
    async fn begin(&self, fragment: &Fragment, multipart_id: &str) -> lance::Result<()> {
        let fragment_json = serde_json::to_string(fragment)?;

        Python::with_gil(|py| {
            let kwargs = PyDict::new(py);
            kwargs.set_item("multipart_id", multipart_id)?;
            self.py_obj
                .call_method(py, "_do_begin", (fragment_json,), Some(kwargs))
                .map(|_| ())
        })
        // Surface any Python-side failure as a Lance I/O error.
        .map_err(|e| lance::Error::IO {
            message: format!("Failed to call begin() on WriteFragmentProgress: {}", e),
            location: location!(),
        })
    }

    /// Notify the Python object that writing of `fragment` has finished.
    async fn complete(&self, fragment: &Fragment) -> lance::Result<()> {
        let fragment_json = serde_json::to_string(fragment)?;

        Python::with_gil(|py| {
            self.py_obj
                .call_method(py, "_do_complete", (fragment_json,), None)
                .map(|_| ())
        })
        // Surface any Python-side failure as a Lance I/O error.
        .map_err(|e| lance::Error::IO {
            message: format!("Failed to call complete() on WriteFragmentProgress: {}", e),
            location: location!(),
        })
    }
}
Loading

0 comments on commit 031da35

Please sign in to comment.