Skip to content

Commit

Permalink
Accept Python dict from client "write_dataframe" and TableAdapter (#771)
Browse files Browse the repository at this point in the history
* Accept dict for client 'write_dataframe'

* Add test for writing dataframe from dict

* Better generic dict name in TableStructure

* Add support for dict to TableAdapter

* Simplify generated_minimal example

* Use newer TableAdapter name, rather than alias

* Update changelog

* Rename from_dict methods

* Remove commented ignore for Pandas warning from new test
  • Loading branch information
nmaytan authored Aug 6, 2024
1 parent 2e392d6 commit 17589bb
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 11 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ Write the date in place of the "Unreleased" in the case a new version is release

## Unreleased

### Added
- Add a `from_dict` method to `TableAdapter` that accepts a Python dictionary.

### Changed
- Make the `tiled.client` `write_dataframe()` method accept a Python dictionary in addition to a DataFrame.
- The `generated_minimal` example no longer requires pandas and instead uses a Python dict.

### Fixed
- A bug in `Context.__getstate__` caused pickling to fail if applied twice.

Expand Down
25 changes: 25 additions & 0 deletions tiled/_tests/test_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,31 @@ def test_write_dataframe_partitioned(tree):
assert result.specs == specs


def test_write_dataframe_dict(tree):
    """A plain dict of columns passed to write_dataframe should round-trip."""
    with Context.from_app(
        build_app(tree, validation_registry=validation_registry)
    ) as context:
        client = from_context(context)

        columns = {f"Column{i}": (1 + i) * numpy.ones(5) for i in range(5)}
        expected = pandas.DataFrame(columns)
        md = {"scan_id": 1, "method": "A"}
        spec_list = [Spec("SomeSpec")]

        with record_history() as history:
            client.write_dataframe(columns, metadata=md, specs=spec_list)
        # Exactly two requests are expected: one for metadata, one for data.
        assert len(history.requests) == 2

        found = client.search(Key("scan_id") == 1).values().first()
        # The stored table must match the DataFrame built from the same dict.
        pandas.testing.assert_frame_equal(found.read(), expected)
        assert found.metadata == md
        assert found.specs == spec_list


@pytest.mark.parametrize(
"coo",
[
Expand Down
32 changes: 32 additions & 0 deletions tiled/adapters/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,38 @@ def from_pandas(
ddf, metadata=metadata, specs=specs, access_policy=access_policy
)

@classmethod
def from_dict(
    cls,
    *args: Any,
    metadata: Optional[JSON] = None,
    specs: Optional[List[Spec]] = None,
    access_policy: Optional[AccessPolicy] = None,
    npartitions: int = 1,
    **kwargs: Any,
) -> "TableAdapter":
    """
    Build a TableAdapter from a Python dictionary of column data.

    Parameters
    ----------
    args :
        Positional arguments forwarded to ``dask.dataframe.from_dict``,
        typically the mapping of column name to column values.
    metadata :
        Optional metadata to attach to the adapter.
    specs :
        Optional list of specs; defaults to ``[Spec("dataframe")]``.
    access_policy :
        Optional access policy for the adapter.
    npartitions :
        Number of dask partitions to split the data into (default 1).
    kwargs :
        Additional keyword arguments forwarded to
        ``dask.dataframe.from_dict``.

    Returns
    -------
    TableAdapter
    """
    frame = dask.dataframe.from_dict(*args, npartitions=npartitions, **kwargs)
    chosen_specs = [Spec("dataframe")] if specs is None else specs
    return cls.from_dask_dataframe(
        frame,
        metadata=metadata,
        specs=chosen_specs,
        access_policy=access_policy,
    )

@classmethod
def from_dask_dataframe(
cls,
Expand Down
2 changes: 2 additions & 0 deletions tiled/client/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,8 @@ def write_dataframe(

if isinstance(dataframe, dask.dataframe.DataFrame):
structure = TableStructure.from_dask_dataframe(dataframe)
elif isinstance(dataframe, dict):
structure = TableStructure.from_dict(dataframe)
else:
structure = TableStructure.from_pandas(dataframe)
client = self.new(
Expand Down
17 changes: 7 additions & 10 deletions tiled/examples/generated_minimal.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
import numpy
import pandas
import xarray

from tiled.adapters.array import ArrayAdapter
from tiled.adapters.dataframe import DataFrameAdapter
from tiled.adapters.dataframe import TableAdapter
from tiled.adapters.mapping import MapAdapter
from tiled.adapters.xarray import DatasetAdapter

tree = MapAdapter(
{
"A": ArrayAdapter.from_array(numpy.ones((100, 100))),
"B": ArrayAdapter.from_array(numpy.ones((100, 100, 100))),
"C": DataFrameAdapter.from_pandas(
pandas.DataFrame(
{
"x": 1 * numpy.ones(100),
"y": 2 * numpy.ones(100),
"z": 3 * numpy.ones(100),
}
),
"C": TableAdapter.from_dict(
{
"x": 1 * numpy.ones(100),
"y": 2 * numpy.ones(100),
"z": 3 * numpy.ones(100),
},
npartitions=3,
),
"D": DatasetAdapter.from_dataset(
Expand Down
5 changes: 4 additions & 1 deletion tiled/serialization/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
def serialize_arrow(df, metadata, preserve_index=True):
import pyarrow

table = pyarrow.Table.from_pandas(df, preserve_index=preserve_index)
if isinstance(df, dict):
table = pyarrow.Table.from_pydict(df)
else:
table = pyarrow.Table.from_pandas(df, preserve_index=preserve_index)
sink = pyarrow.BufferOutputStream()
with pyarrow.ipc.new_file(sink, table.schema) as writer:
writer.write_table(table)
Expand Down
9 changes: 9 additions & 0 deletions tiled/structures/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def from_pandas(cls, df):
data_uri = B64_ENCODED_PREFIX + schema_b64
return cls(arrow_schema=data_uri, npartitions=1, columns=list(df.columns))

@classmethod
def from_dict(cls, d):
    """
    Construct a TableStructure from a dict of column name -> column values.

    The Arrow schema is inferred by pyarrow from the dict and embedded as
    a base64-encoded data URI; the table is treated as a single partition.
    """
    import pyarrow

    serialized = pyarrow.Table.from_pydict(d).schema.serialize()
    encoded = base64.b64encode(serialized).decode("utf-8")
    return cls(
        arrow_schema=B64_ENCODED_PREFIX + encoded,
        npartitions=1,
        columns=list(d),
    )

@property
def arrow_schema_decoded(self):
import pyarrow
Expand Down

0 comments on commit 17589bb

Please sign in to comment.