Skip to content

Commit

Permalink
Unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Aug 26, 2023
1 parent 7be2d7e commit 4b89cd3
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/table/test_broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,22 @@

import pytest

import daft
from daft.expressions import col, lit
from daft.table import Table


@pytest.mark.parametrize("data", [1, "a", True, b"Y", 0.5, None, object()])
@pytest.mark.parametrize("data", [1, "a", True, b"Y", 0.5, None, [1, 2, 3], object()])
def test_broadcast(data):
table = Table.from_pydict({"x": [1, 2, 3]})
new_table = table.eval_expression_list([col("x"), lit(data)])
assert new_table.to_pydict() == {"x": [1, 2, 3], "literal": [data for _ in range(3)]}


def test_broadcast_fixed_size_list():
data = [1, 2, 3]
table = Table.from_pydict({"x": [1, 2, 3]})
new_table = table.eval_expression_list(
[col("x"), lit(data).cast(daft.DataType.fixed_size_list("foo", daft.DataType.int64(), 3))]
)
assert new_table.to_pydict() == {"x": [1, 2, 3], "literal": [data for _ in range(3)]}
40 changes: 40 additions & 0 deletions tests/table/test_take.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,43 @@ def test_table_take_pyobject() -> None:
assert taken.column_names() == ["objs"]

assert taken.to_pydict()["objs"] == [objects[3], objects[2], objects[2], objects[2], objects[3]]


@pytest.mark.parametrize("idx_dtype", daft_int_types)
def test_table_take_fixed_size_list(idx_dtype) -> None:
pa_table = pa.Table.from_pydict(
{
"a": pa.array([[1, 2], [3, None], None, [None, None]], type=pa.list_(pa.int64(), 2)),
"b": pa.array([[4, 5], [6, None], None, [None, None]], type=pa.list_(pa.int64(), 2)),
}
)
daft_table = Table.from_arrow(pa_table)
assert len(daft_table) == 4
assert daft_table.column_names() == ["a", "b"]

indices = Series.from_pylist([0, 1]).cast(idx_dtype)

taken = daft_table.take(indices)
assert len(taken) == 2
assert taken.column_names() == ["a", "b"]

assert taken.to_pydict() == {"a": [[1, 2], [3, None]], "b": [[4, 5], [6, None]]}

indices = Series.from_pylist([3, 2]).cast(idx_dtype)

taken = daft_table.take(indices)
assert len(taken) == 2
assert taken.column_names() == ["a", "b"]

assert taken.to_pydict() == {"a": [[None, None], None], "b": [[None, None], None]}

indices = Series.from_pylist([3, 2, 2, 2, 3]).cast(idx_dtype)

taken = daft_table.take(indices)
assert len(taken) == 5
assert taken.column_names() == ["a", "b"]

assert taken.to_pydict() == {
"a": [[None, None], None, None, None, [None, None]],
"b": [[None, None], None, None, None, [None, None]],
}

0 comments on commit 4b89cd3

Please sign in to comment.