Add schema option
janheinrichmerker committed Jul 3, 2024
Parent: bd87f3a · Commit: c17d413
Showing 2 changed files with 43 additions and 4 deletions.
31 changes: 31 additions & 0 deletions examples/read_schema.py
@@ -0,0 +1,31 @@
+from os import environ
+from pyarrow import field, schema, struct, int32
+from ray import init
+from ray.data import read_datasource
+from ray_elasticsearch import ElasticsearchDatasource
+
+init()
+source = ElasticsearchDatasource(
+    index=environ["ELASTICSEARCH_INDEX"],
+    client_kwargs=dict(
+        hosts=environ["ELASTICSEARCH_HOST"],
+        http_auth=(
+            environ["ELASTICSEARCH_USERNAME"],
+            environ["ELASTICSEARCH_PASSWORD"],
+        ),
+    ),
+    schema=schema([
+        field(
+            name="_source",
+            type=struct([
+                field(name="id", type=int32(), nullable=False)
+            ]),
+            nullable=False,
+        )
+    ])
+)
+print(f"Num rows: {source.num_rows()}")
+res = read_datasource(source)\
+    .map(lambda x: x["_source"])\
+    .sum("id")
+print(f"Read complete. Sum: {res}")
16 changes: 12 additions & 4 deletions ray_elasticsearch/__init__.py
@@ -12,7 +12,7 @@
 from typing_extensions import TypeAlias
 
 from pandas import DataFrame
-from pyarrow import Table
+from pyarrow import Schema, Table
 from ray.data import Datasource, ReadTask, Datasink
 from ray.data._internal.execution.interfaces import TaskContext
 from ray.data.block import BlockMetadata, Block
@@ -48,6 +48,7 @@ class ElasticsearchDatasource(Datasource):
     _keep_alive: str
     _chunk_size: int
     _client_kwargs: dict[str, Any]
+    _schema: Optional[Union[type, Schema]]
 
     def __init__(
         self,
@@ -56,20 +57,22 @@ def __init__(
         keep_alive: str = "5m",
         chunk_size: int = 1000,
         client_kwargs: dict[str, Any] = {},
+        schema: Optional[Union[type, Schema]] = None,
     ) -> None:
         super().__init__()
         self._index = index
         self._query = query
         self._keep_alive = keep_alive
         self._chunk_size = chunk_size
         self._client_kwargs = client_kwargs
+        self._schema = schema
 
     @property
     def _elasticsearch(self) -> Elasticsearch:
         return Elasticsearch(**self._client_kwargs)
 
-    def schema(self) -> None:
-        return None
+    def schema(self) -> Optional[Union[type, Schema]]:
+        return self._schema
 
     @cached_property
     def _num_rows(self) -> int:
@@ -102,11 +105,12 @@ def _get_read_task(
     slice_max: int,
     chunk_size: int,
     client_kwargs: dict[str, Any],
+    schema: Optional[Union[type, Schema]],
 ) -> ReadTask:
     metadata = BlockMetadata(
         num_rows=None,
         size_bytes=None,
-        schema=None,
+        schema=schema,
         input_files=None,
         exec_stats=None,
     )
@@ -151,6 +155,7 @@ def get_read_tasks(self, parallelism: int) -> list[ReadTask]:
                 slice_max=parallelism,
                 chunk_size=self._chunk_size,
                 client_kwargs=self._client_kwargs,
+                schema=self._schema,
             )
             for i in range(parallelism)
         ]
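
Note, not from the diff: with the schema stored on the datasource and attached to each read task's BlockMetadata, Ray can report a schema without first scanning the index. A hedged sketch of the consumer side, reusing the source object from examples/read_schema.py:

from ray.data import read_datasource

ds = read_datasource(source)
# Should reflect the declared schema rather than one inferred by reading data.
print(ds.schema())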
@@ -290,6 +295,7 @@ def __init__(
         keep_alive: str = "5m",
         chunk_size: int = 1000,
         client_kwargs: dict[str, Any] = {},
+        schema: Optional[Union[type, Schema]] = None,
     ) -> None:
         super().__init__(
             index=(
@@ -304,6 +310,8 @@
             keep_alive=keep_alive,
             chunk_size=chunk_size,
             client_kwargs=client_kwargs,
+            # TODO: Infer schema from document type if not given.
+            schema=schema,
         )
 
 class ElasticsearchDslDatasink(ElasticsearchDatasink):
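
The TODO above points at inferring the schema from an elasticsearch-dsl document type, which this commit does not implement yet. Until it lands, a DSL user would pass the schema by hand. An illustrative sketch, assuming (as the class name suggests) that ElasticsearchDslDatasource accepts a Document subclass as its index; the Example document is hypothetical:

from elasticsearch_dsl import Document, Integer
from pyarrow import field, int32, schema, struct
from ray_elasticsearch import ElasticsearchDslDatasource

class Example(Document):  # hypothetical document type, not from the commit
    id = Integer()

source = ElasticsearchDslDatasource(
    index=Example,  # assumption: the index name is derived from the document type
    schema=schema([
        field(
            name="_source",
            type=struct([field(name="id", type=int32(), nullable=False)]),
            nullable=False,
        ),
    ]),
)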
