diff --git a/src/daft-core/src/array/ops/if_else.rs b/src/daft-core/src/array/ops/if_else.rs index 97fc6301a1..c218af40e5 100644 --- a/src/daft-core/src/array/ops/if_else.rs +++ b/src/daft-core/src/array/ops/if_else.rs @@ -23,9 +23,9 @@ fn generic_if_else( None => Ok(T::full_null(name, dtype, lhs_len).into_series()), Some(predicate_scalar_value) => { if predicate_scalar_value { - Ok(lhs.clone().into_series()) + Ok(lhs.clone().into_series().rename(name)) } else { - Ok(rhs.clone().into_series()) + Ok(rhs.clone().into_series().rename(name)) } } }; diff --git a/tests/table/test_if_else.py b/tests/table/test_if_else.py new file mode 100644 index 0000000000..f1eaa2c3a0 --- /dev/null +++ b/tests/table/test_if_else.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import pytest + +from daft.expressions import col +from daft.table.micropartition import MicroPartition + + +@pytest.mark.parametrize( + ["predicate", "if_true", "if_false", "expected"], + [ + # Single row + ([True], [1], [2], [1]), + ([False], [1], [2], [2]), + # Multiple rows + ([True, False, True], [1, 2, 3], [4, 5, 6], [1, 5, 3]), + ([False, False, False], [1, 2, 3], [4, 5, 6], [4, 5, 6]), + ], +) +def test_table_expr_if_else(predicate, if_true, if_false, expected) -> None: + daft_table = MicroPartition.from_pydict({"predicate": predicate, "if_true": if_true, "if_false": if_false}) + daft_table = daft_table.eval_expression_list([col("predicate").if_else(col("if_true"), col("if_false"))]) + pydict = daft_table.to_pydict() + + assert pydict["if_true"] == expected