Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Hide polars.testing.* in pytest stack traces #14399

Merged
merged 1 commit into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions py-polars/polars/testing/asserts/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def assert_frame_equal(
...
AssertionError: values for column 'a' are different
"""
__tracebackhide__ = True

lazy = _assert_correct_input_type(left, right)
objects = "LazyFrames" if lazy else "DataFrames"

Expand Down Expand Up @@ -132,6 +134,8 @@ def assert_frame_equal(
def _assert_correct_input_type(
left: DataFrame | LazyFrame, right: DataFrame | LazyFrame
) -> bool:
__tracebackhide__ = True

if isinstance(left, DataFrame) and isinstance(right, DataFrame):
return False
elif isinstance(left, LazyFrame) and isinstance(right, LazyFrame):
Expand All @@ -153,6 +157,8 @@ def _assert_frame_schema_equal(
check_column_order: bool,
objects: str,
) -> None:
__tracebackhide__ = True

left_schema, right_schema = left.schema, right.schema

# Fast path for equal frames
Expand Down Expand Up @@ -253,6 +259,8 @@ def assert_frame_not_equal(
...
AssertionError: frames are equal
"""
__tracebackhide__ = True

try:
assert_frame_equal(
left=left,
Expand Down
12 changes: 12 additions & 0 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def assert_series_equal(
[left]: [1, 2, 3]
[right]: [1, 5, 3]
"""
__tracebackhide__ = True

if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr]
raise_assertion_error(
"inputs",
Expand Down Expand Up @@ -119,6 +121,8 @@ def _assert_series_values_equal(
atol: float,
categorical_as_str: bool,
) -> None:
__tracebackhide__ = True

"""Assert that the values in both Series are equal."""
# Handle categoricals
if categorical_as_str:
Expand Down Expand Up @@ -191,6 +195,8 @@ def _assert_series_nested_values_equal(
atol: float,
categorical_as_str: bool,
) -> None:
__tracebackhide__ = True

# compare nested lists element-wise
if _comparing_lists(left.dtype, right.dtype):
for s1, s2 in zip(left, right):
Expand Down Expand Up @@ -221,6 +227,7 @@ def _assert_series_nested_values_equal(


def _assert_series_null_values_match(left: Series, right: Series) -> None:
__tracebackhide__ = True
null_value_mismatch = left.is_null() != right.is_null()
if null_value_mismatch.any():
raise_assertion_error(
Expand All @@ -229,6 +236,7 @@ def _assert_series_null_values_match(left: Series, right: Series) -> None:


def _assert_series_nan_values_match(left: Series, right: Series) -> None:
__tracebackhide__ = True
if not _comparing_floats(left.dtype, right.dtype):
return
nan_value_mismatch = left.is_nan() != right.is_nan()
Expand Down Expand Up @@ -270,6 +278,8 @@ def _assert_series_values_within_tolerance(
rtol: float,
atol: float,
) -> None:
__tracebackhide__ = True

left_unequal, right_unequal = left.filter(unequal), right.filter(unequal)

difference = (left_unequal - right_unequal).abs()
Expand Down Expand Up @@ -339,6 +349,8 @@ def assert_series_not_equal(
...
AssertionError: Series are equal
"""
__tracebackhide__ = True

try:
assert_series_equal(
left=left,
Expand Down
64 changes: 64 additions & 0 deletions py-polars/tests/unit/testing/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from polars.testing import assert_frame_equal, assert_frame_not_equal

nan = float("nan")
pytest_plugins = ["pytester"]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -366,3 +367,66 @@ def test_assert_frame_not_equal() -> None:
df = pl.DataFrame({"a": [1, 2]})
with pytest.raises(AssertionError, match="frames are equal"):
assert_frame_not_equal(df, df)


def test_tracebackhide(testdir: pytest.Testdir) -> None:
testdir.makefile(
".py",
test_path="""\
import polars as pl
from polars.testing import assert_frame_equal, assert_frame_not_equal

def test_frame_equal_fail():
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 3]})
assert_frame_equal(df1, df2)

def test_frame_not_equal_fail():
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2]})
assert_frame_not_equal(df1, df2)

def test_frame_data_type_fail():
df1 = pl.DataFrame({"a": [1, 2]})
df2 = {"a": [1, 2]}
assert_frame_equal(df1, df2)

def test_frame_schema_fail():
df1 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int64})
df2 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int32})
assert_frame_equal(df1, df2)
""",
)
result = testdir.runpytest()
result.assert_outcomes(passed=0, failed=4)
stdout = "\n".join(result.outlines)

assert "polars/py-polars/polars/testing" not in stdout

# The above should catch any polars testing functions that appear in the
# stack trace. But we keep the following checks (for specific function
# names) just to double-check.

assert "def assert_frame_equal" not in stdout
assert "def assert_frame_not_equal" not in stdout
assert "def _assert_correct_input_type" not in stdout
assert "def _assert_frame_schema_equal" not in stdout

assert "def assert_series_equal" not in stdout
assert "def assert_series_not_equal" not in stdout
assert "def _assert_series_values_equal" not in stdout
assert "def _assert_series_nested_values_equal" not in stdout
assert "def _assert_series_null_values_match" not in stdout
assert "def _assert_series_nan_values_match" not in stdout
assert "def _assert_series_values_within_tolerance" not in stdout

# Make sure the tests are failing for the expected reason (e.g. not because
# an import is missing or something like that):

assert (
"AssertionError: DataFrames are different (value mismatch for column 'a')"
in stdout
)
assert "AssertionError: frames are equal" in stdout
assert "AssertionError: inputs are different (unexpected input types)" in stdout
assert "AssertionError: DataFrames are different (dtypes do not match)" in stdout
79 changes: 79 additions & 0 deletions py-polars/tests/unit/testing/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from polars.testing import assert_series_equal, assert_series_not_equal

nan = float("nan")
pytest_plugins = ["pytester"]


def test_compare_series_value_mismatch() -> None:
Expand Down Expand Up @@ -636,3 +637,81 @@ def test_assert_series_equal_w_large_integers_12328() -> None:
right = pl.Series([1577840521123543])
with pytest.raises(AssertionError):
assert_series_equal(left, right)


def test_tracebackhide(testdir: pytest.Testdir) -> None:
testdir.makefile(
".py",
test_path="""\
import polars as pl
from polars.testing import assert_series_equal, assert_series_not_equal

nan = float("nan")

def test_series_equal_fail():
s1 = pl.Series([1, 2])
s2 = pl.Series([1, 3])
assert_series_equal(s1, s2)

def test_series_not_equal_fail():
s1 = pl.Series([1, 2])
s2 = pl.Series([1, 2])
assert_series_not_equal(s1, s2)

def test_series_nested_fail():
s1 = pl.Series([[1, 2], [3, 4]])
s2 = pl.Series([[1, 2], [3, 5]])
assert_series_equal(s1, s2)

def test_series_null_fail():
s1 = pl.Series([1, 2])
s2 = pl.Series([1, None])
assert_series_equal(s1, s2)

def test_series_nan_fail():
s1 = pl.Series([1.0, 2.0])
s2 = pl.Series([1.0, nan])
assert_series_equal(s1, s2)

def test_series_float_tolerance_fail():
s1 = pl.Series([1.0, 2.0])
s2 = pl.Series([1.0, 2.1])
assert_series_equal(s1, s2)

def test_series_schema_fail():
s1 = pl.Series([1, 2], dtype=pl.Int64)
s2 = pl.Series([1, 2], dtype=pl.Int32)
assert_series_equal(s1, s2)

def test_series_data_type_fail():
s1 = pl.Series([1, 2])
s2 = [1, 2]
assert_series_equal(s1, s2)
""",
)
result = testdir.runpytest()
result.assert_outcomes(passed=0, failed=8)
stdout = "\n".join(result.outlines)

assert "polars/py-polars/polars/testing" not in stdout

# The above should catch any polars testing functions that appear in the
# stack trace. But we keep the following checks (for specific function
# names) just to double-check.

assert "def assert_series_equal" not in stdout
assert "def assert_series_not_equal" not in stdout
assert "def _assert_series_values_equal" not in stdout
assert "def _assert_series_nested_values_equal" not in stdout
assert "def _assert_series_null_values_match" not in stdout
assert "def _assert_series_nan_values_match" not in stdout
assert "def _assert_series_values_within_tolerance" not in stdout

# Make sure the tests are failing for the expected reason (e.g. not because
# an import is missing or something like that):

assert "AssertionError: Series are different (exact value mismatch)" in stdout
assert "AssertionError: Series are equal" in stdout
assert "AssertionError: Series are different (nan value mismatch)" in stdout
assert "AssertionError: Series are different (dtype mismatch)" in stdout
assert "AssertionError: inputs are different (unexpected input types)" in stdout
Loading