Skip to content

Commit

Permalink
reduce number of fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Dec 1, 2023
1 parent 0d6e299 commit 2329aff
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 31 deletions.
29 changes: 9 additions & 20 deletions tests/expressions/typing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,6 @@
ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2))


@pytest.fixture(
scope="module",
params=ALL_DATATYPES_BINARY_PAIRS,
ids=[f"{dt1}-{dt2}" for (dt1, _), (dt2, _) in ALL_DATATYPES_BINARY_PAIRS],
)
def binary_data_fixture(request) -> tuple[Series, Series]:
"""Returns binary permutation of Series' of all DataType pairs"""
(dt1, data1), (dt2, data2) = request.param
s1 = Series.from_arrow(data1, name="lhs")
assert s1.datatype() == dt1
s2 = Series.from_arrow(data2, name="rhs")
assert s2.datatype() == dt2
return (s1, s2)


ALL_TEMPORAL_DTYPES = [
(DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
*[
Expand Down Expand Up @@ -94,6 +79,8 @@ def binary_data_fixture(request) -> tuple[Series, Series]:
],
]

ALL_DTYPES += ALL_TEMPORAL_DTYPES

ALL_TEMPORAL_DATATYPES_BINARY_PAIRS = [
((dt1, data1), (dt2, data2))
for (dt1, data1), (dt2, data2) in itertools.product(ALL_TEMPORAL_DTYPES, repeat=2)
Expand All @@ -104,13 +91,15 @@ def binary_data_fixture(request) -> tuple[Series, Series]:
)
]

ALL_DATATYPES_BINARY_PAIRS += ALL_TEMPORAL_DATATYPES_BINARY_PAIRS


@pytest.fixture(
scope="module",
params=ALL_TEMPORAL_DATATYPES_BINARY_PAIRS,
ids=[f"{dt1}-{dt2}" for (dt1, _), (dt2, _) in ALL_TEMPORAL_DATATYPES_BINARY_PAIRS],
params=ALL_DATATYPES_BINARY_PAIRS,
ids=[f"{dt1}-{dt2}" for (dt1, _), (dt2, _) in ALL_DATATYPES_BINARY_PAIRS],
)
def binary_temporal_data_fixture(request) -> tuple[Series, Series]:
def binary_data_fixture(request) -> tuple[Series, Series]:
"""Returns binary permutation of Series' of all DataType pairs"""
(dt1, data1), (dt2, data2) = request.param
s1 = Series.from_arrow(data1, name="lhs")
Expand All @@ -122,8 +111,8 @@ def binary_temporal_data_fixture(request) -> tuple[Series, Series]:

@pytest.fixture(
scope="module",
params=ALL_DTYPES + ALL_TEMPORAL_DTYPES,
ids=[f"{dt}" for (dt, _) in ALL_DTYPES + ALL_TEMPORAL_DTYPES],
params=ALL_DTYPES,
ids=[f"{dt}" for (dt, _) in ALL_DTYPES],
)
def unary_data_fixture(request) -> Series:
"""Returns unary permutation of Series' of all DataType pairs"""
Expand Down
11 changes: 0 additions & 11 deletions tests/expressions/typing/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,3 @@ def test_comparable(binary_data_fixture, op):
run_kernel=lambda: op(lhs, rhs),
resolvable=comparable_type_validation(lhs.datatype(), rhs.datatype()),
)


@pytest.mark.parametrize("op", [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge])
def test_temporal_comparable(binary_temporal_data_fixture, op):
lhs, rhs = binary_temporal_data_fixture
assert_typing_resolve_vs_runtime_behavior(
data=binary_temporal_data_fixture,
expr=op(col(lhs.name()), col(rhs.name())),
run_kernel=lambda: op(lhs, rhs),
resolvable=comparable_type_validation(lhs.datatype(), rhs.datatype()),
)

0 comments on commit 2329aff

Please sign in to comment.