Skip to content

Commit

Permalink
fix: Raise on invalid shape of shape 1, empty combination (#19113)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 6, 2024
1 parent f7de80c commit 8c306dd
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 26 deletions.
25 changes: 14 additions & 11 deletions crates/polars-mem-engine/src/executors/projection_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ pub(super) fn check_expand_literals(
}
}
}

// If all series are the same length it is ok. If not we can broadcast Series of length one.
if !all_equal_len && should_broadcast {
selected_columns = selected_columns
Expand All @@ -300,32 +301,34 @@ pub(super) fn check_expand_literals(
Ok(match series.len() {
0 if df_height == 1 => series,
1 => {
if has_empty {
polars_ensure!(df_height == 1,
ComputeError: "Series length {} doesn't match the DataFrame height of {}",
series.len(), df_height
);
series.slice(0, 0)
} else if df_height == 1 {
if !has_empty && df_height == 1 {
series
} else {
if has_empty {
polars_ensure!(df_height == 1,
ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}",
series.len(), df_height
);

}

if verify_scalar {
polars_ensure!(phys.is_scalar(),
InvalidOperation: "Series: {}, length {} doesn't match the DataFrame height of {}\n\n\
ShapeMismatch: "Series: {}, length {} doesn't match the DataFrame height of {}\n\n\
If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').",
series.name(), series.len(), df_height
series.name(), series.len(), df_height *(!has_empty as usize)
);

}
series.new_from_index(0, df_height)
series.new_from_index(0, df_height * (!has_empty as usize) )
}
},
len if len == df_height => {
series
},
_ => {
polars_bail!(
ComputeError: "Series length {} doesn't match the DataFrame height of {}",
ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}",
series.len(), df_height
)
}
Expand Down
12 changes: 0 additions & 12 deletions py-polars/tests/unit/constructors/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,6 @@ def test_df_init_nested_mixed_types() -> None:
assert df.to_dicts() == [{"key": [{"value": 1.0}, {"value": 1.0}]}]


def test_unit_and_empty_construction_15896() -> None:
# This is still incorrect.
# We should raise, but currently for len 1 dfs,
# we cannot tell if they come from a literal or expression.
assert "shape: (0, 2)" in str(
pl.DataFrame({"A": [0]}).select(
C="A",
A=pl.int_range("A"), # creates empty series
)
)


class CustomSchema(Mapping[str, Any]):
"""Dummy schema object for testing compatibility with Mapping."""

Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/dataframe/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

import polars as pl


# TODO: drop the `may_fail_auto_streaming` marker once the streaming engine raises as well.
@pytest.mark.may_fail_auto_streaming
def test_raise_invalid_shape_19108() -> None:
    """Selecting a length-0 column next to a length-1 column must raise `ShapeError`.

    Regression test for issue 19108: `foo.head(0)` yields an empty column and
    `bar.head(1)` yields a single-row column, which cannot be combined into
    one frame.
    """
    frame = pl.DataFrame({"foo": [1, 2], "bar": [3, 4]})
    with pytest.raises(pl.exceptions.ShapeError):
        frame.select(
            pl.col("foo").head(0),
            pl.col("bar").head(1),
        )
3 changes: 1 addition & 2 deletions py-polars/tests/unit/lazyframe/test_with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal


Expand All @@ -19,7 +18,7 @@ def test_with_context() -> None:

with pytest.deprecated_call():
context = df_a.with_context(df_b.lazy())
with pytest.raises(ComputeError):
with pytest.raises(pl.exceptions.ShapeError):
context.select("a", "c").collect()


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_invalid_broadcast() -> None:
"group": [0, 1],
}
)
with pytest.raises(pl.exceptions.InvalidOperationError):
with pytest.raises(pl.exceptions.ShapeError):
df.select(pl.col("group").filter(pl.col("group") == 0), "a")


Expand Down

0 comments on commit 8c306dd

Please sign in to comment.