Commit

[data] Add a parameter to allow overriding LanceDB scanner options (#46975)

The default `batch_size` for the LanceDatasource is too big for some use
cases. This PR adds a new `scanner_options` argument to allow overriding
the LanceDB scanner options.

---------

Signed-off-by: Hao Chen <[email protected]>
raulchen authored Aug 6, 2024
1 parent 45250ca commit d169386
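
For context, a minimal usage sketch of the new argument (the dataset path is illustrative; keys in `scanner_options` are forwarded to `LanceDataset.scanner()`):

```python
import ray

# Hypothetical path to an existing Lance dataset; adjust as needed.
path = "/tmp/example.lance"

# Override the scanner's batch size instead of relying on the default,
# which this PR notes can be too large for some workloads.
ds = ray.data.read_lance(path, scanner_options={"batch_size": 100})
print(ds.count())
```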
Showing 3 changed files with 29 additions and 12 deletions.
24 changes: 14 additions & 10 deletions python/ray/data/_internal/datasource/lance_datasource.py
@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

import numpy as np

@@ -23,14 +23,18 @@ def __init__(
columns: Optional[List[str]] = None,
filter: Optional[str] = None,
storage_options: Optional[Dict[str, str]] = None,
scanner_options: Optional[Dict[str, Any]] = None,
):
_check_import(self, module="lance", package="pylance")

import lance

self.uri = uri
self.columns = columns
self.filter = filter
self.scanner_options = scanner_options or {}
if columns is not None:
self.scanner_options["columns"] = columns
if filter is not None:
self.scanner_options["filter"] = filter
self.storage_options = storage_options
self.lance_ds = lance.dataset(uri=uri, storage_options=storage_options)

@@ -54,14 +58,11 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
size_bytes=None,
exec_stats=None,
)
columns = self.columns
row_filter = self.filter
scanner_options = self.scanner_options
lance_ds = self.lance_ds

read_task = ReadTask(
lambda f=fragment_ids: _read_fragments(
f, lance_ds, columns, row_filter
),
lambda f=fragment_ids: _read_fragments(f, lance_ds, scanner_options),
metadata,
)
read_tasks.append(read_task)
@@ -74,7 +75,9 @@ def estimate_inmemory_data_size(self) -> Optional[int]:


def _read_fragments(
fragment_ids, lance_ds, columns, row_filter
fragment_ids,
lance_ds,
scanner_options,
) -> Iterator["pyarrow.Table"]:
"""Read Lance fragments in batches.
@@ -84,6 +87,7 @@ def _read_fragments(
import pyarrow

fragments = [lance_ds.get_fragment(id) for id in fragment_ids]
scanner = lance_ds.scanner(columns, filter=row_filter, fragments=fragments)
scanner_options["fragments"] = fragments
scanner = lance_ds.scanner(**scanner_options)
for batch in scanner.to_reader():
yield pyarrow.Table.from_batches([batch])
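
For readability, here is the rewritten helper from the hunk above as a self-contained sketch; the type hints and comments are added for illustration, while the calls themselves mirror the diff:

```python
from typing import Any, Dict, Iterator, List

import pyarrow


def _read_fragments(
    fragment_ids: List[int],
    lance_ds,  # an opened lance.LanceDataset
    scanner_options: Dict[str, Any],
) -> Iterator["pyarrow.Table"]:
    """Read the given Lance fragments in batches."""
    # Resolve the fragment objects assigned to this read task.
    fragments = [lance_ds.get_fragment(id) for id in fragment_ids]
    # Restrict the scan to those fragments, then forward every user-supplied
    # option (batch_size, columns, filter, ...) to LanceDataset.scanner().
    scanner_options["fragments"] = fragments
    scanner = lance_ds.scanner(**scanner_options)
    for batch in scanner.to_reader():
        yield pyarrow.Table.from_batches([batch])
```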
6 changes: 6 additions & 0 deletions python/ray/data/read_api.py
@@ -3007,6 +3007,7 @@ def read_lance(
columns: Optional[List[str]] = None,
filter: Optional[str] = None,
storage_options: Optional[Dict[str, str]] = None,
scanner_options: Optional[Dict[str, Any]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
concurrency: Optional[int] = None,
override_num_blocks: Optional[int] = None,
@@ -3033,6 +3034,10 @@
connection. This is used to store connection parameters like credentials,
endpoint, etc. For more information, see `Object Store Configuration <https\
://lancedb.github.io/lance/read_and_write.html#object-store-configuration>`_.
scanner_options: Additional options to configure the `LanceDataset.scanner()`
method, such as `batch_size`. For more information,
see `LanceDB API doc <https://lancedb.github.io\
/lance/api/python/lance.html#lance.dataset.LanceDataset.scanner>`_
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
@@ -3051,6 +3056,7 @@
columns=columns,
filter=filter,
storage_options=storage_options,
scanner_options=scanner_options,
)

return read_datasource(
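To illustrate how the new argument composes with the existing ones (a sketch; the URI and column names are hypothetical), note that per the `__init__` change above, explicit `columns` and `filter` are copied into the scanner-options dict, so they take precedence over identically named keys supplied via `scanner_options`:

```python
import ray

# Hypothetical dataset URI; replace with a real Lance dataset location.
uri = "/data/table.lance"

ds = ray.data.read_lance(
    uri,
    columns=["one", "two"],   # overrides any scanner_options["columns"]
    filter="two > 1",         # overrides any scanner_options["filter"]
    scanner_options={"batch_size": 256},
)
```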
11 changes: 9 additions & 2 deletions python/ray/data/tests/test_lance.py
@@ -28,7 +28,11 @@
),
],
)
def test_lance_read_basic(fs, data_path):
@pytest.mark.parametrize(
"batch_size",
[None, 100],
)
def test_lance_read_basic(fs, data_path, batch_size):
# NOTE: Lance only works with PyArrow 12 or above.
pyarrow_version = _get_pyarrow_version()
if pyarrow_version is not None:
@@ -51,7 +55,10 @@ def test_lance_read_basic(fs, data_path):
)
ds_lance.merge(df2, "one")

ds = ray.data.read_lance(path)
if batch_size is None:
ds = ray.data.read_lance(path)
else:
ds = ray.data.read_lance(path, scanner_options={"batch_size": batch_size})

# Test metadata-only ops.
assert ds.count() == 6
