diff --git a/tests/table/test_broadcasts.py b/tests/table/test_broadcasts.py index 29133e3932..196cc64dd4 100644 --- a/tests/table/test_broadcasts.py +++ b/tests/table/test_broadcasts.py @@ -2,12 +2,22 @@ import pytest +import daft from daft.expressions import col, lit from daft.table import Table -@pytest.mark.parametrize("data", [1, "a", True, b"Y", 0.5, None, object()]) +@pytest.mark.parametrize("data", [1, "a", True, b"Y", 0.5, None, [1, 2, 3], object()]) def test_broadcast(data): table = Table.from_pydict({"x": [1, 2, 3]}) new_table = table.eval_expression_list([col("x"), lit(data)]) assert new_table.to_pydict() == {"x": [1, 2, 3], "literal": [data for _ in range(3)]} + + +def test_broadcast_fixed_size_list(): + data = [1, 2, 3] + table = Table.from_pydict({"x": [1, 2, 3]}) + new_table = table.eval_expression_list( + [col("x"), lit(data).cast(daft.DataType.fixed_size_list("foo", daft.DataType.int64(), 3))] + ) + assert new_table.to_pydict() == {"x": [1, 2, 3], "literal": [data for _ in range(3)]} diff --git a/tests/table/test_take.py b/tests/table/test_take.py index 7ba637db6e..03ec2b1dcf 100644 --- a/tests/table/test_take.py +++ b/tests/table/test_take.py @@ -154,3 +154,43 @@ def test_table_take_pyobject() -> None: assert taken.column_names() == ["objs"] assert taken.to_pydict()["objs"] == [objects[3], objects[2], objects[2], objects[2], objects[3]] + + +@pytest.mark.parametrize("idx_dtype", daft_int_types) +def test_table_take_fixed_size_list(idx_dtype) -> None: + pa_table = pa.Table.from_pydict( + { + "a": pa.array([[1, 2], [3, None], None, [None, None]], type=pa.list_(pa.int64(), 2)), + "b": pa.array([[4, 5], [6, None], None, [None, None]], type=pa.list_(pa.int64(), 2)), + } + ) + daft_table = Table.from_arrow(pa_table) + assert len(daft_table) == 4 + assert daft_table.column_names() == ["a", "b"] + + indices = Series.from_pylist([0, 1]).cast(idx_dtype) + + taken = daft_table.take(indices) + assert len(taken) == 2 + assert taken.column_names() == ["a", "b"] + + assert taken.to_pydict() == {"a": [[1, 2], [3, None]], "b": [[4, 5], [6, None]]} + + indices = Series.from_pylist([3, 2]).cast(idx_dtype) + + taken = daft_table.take(indices) + assert len(taken) == 2 + assert taken.column_names() == ["a", "b"] + + assert taken.to_pydict() == {"a": [[None, None], None], "b": [[None, None], None]} + + indices = Series.from_pylist([3, 2, 2, 2, 3]).cast(idx_dtype) + + taken = daft_table.take(indices) + assert len(taken) == 5 + assert taken.column_names() == ["a", "b"] + + assert taken.to_pydict() == { + "a": [[None, None], None, None, None, [None, None]], + "b": [[None, None], None, None, None, [None, None]], + }