Skip to content

Commit

Permalink
test: adds test_when_stress, test_when_condition_parity
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Jul 6, 2024
1 parent b110f65 commit f33a9cd
Showing 1 changed file with 185 additions and 0 deletions.
185 changes: 185 additions & 0 deletions tests/vegalite/v5/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import operator
import os
import pathlib
import re
import tempfile

import jsonschema
Expand Down Expand Up @@ -78,6 +79,100 @@ def basic_chart():
return alt.Chart(data).mark_bar().encode(x="a", y="b")


@pytest.fixture()
def cars():
return pd.DataFrame(
{
"Name": [
"chevrolet chevelle malibu",
"buick skylark 320",
"plymouth satellite",
"amc rebel sst",
"ford torino",
"ford galaxie 500",
"chevrolet impala",
"plymouth fury iii",
"pontiac catalina",
"amc ambassador dpl",
],
"Miles_per_Gallon": [
18.0,
15.0,
18.0,
16.0,
17.0,
15.0,
14.0,
14.0,
14.0,
15.0,
],
"Cylinders": [8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
"Displacement": [
307.0,
350.0,
318.0,
304.0,
302.0,
429.0,
454.0,
440.0,
455.0,
390.0,
],
"Horsepower": [
130.0,
165.0,
150.0,
150.0,
140.0,
198.0,
220.0,
215.0,
225.0,
190.0,
],
"Weight_in_lbs": [
3504,
3693,
3436,
3433,
3449,
4341,
4354,
4312,
4425,
3850,
],
"Acceleration": [12.0, 11.5, 11.0, 12.0, 10.5, 10.0, 9.0, 8.5, 10.0, 8.5],
"Year": [
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
pd.Timestamp("1970-01-01 00:00:00"),
],
"Origin": [
"USA",
"USA",
"USA",
"USA",
"USA",
"USA",
"USA",
"USA",
"USA",
"USA",
],
}
)


def test_chart_data_types():
def Chart(data):
return alt.Chart(data).mark_point().encode(x="x:Q", y="y:Q")
Expand Down Expand Up @@ -461,6 +556,96 @@ def test_when_expressions_inside_parameters() -> None:
chart.to_dict()


def test_when_stress():
# Triggering structural errors
brush = alt.selection_interval()
select_x = alt.selection_interval(encodings=["x"])
when = alt.when(brush)
reveal_msg = re.compile(r"Only one field.+Shorthand 'max\(\)'", flags=re.DOTALL)
with pytest.raises(TypeError, match=reveal_msg):
when.then("count()").otherwise("max()")

chain_mixed_msg = re.compile(
r"Chained.+mixed.+conflict.+\{'field': 'field_1', 'type': 'quantitative'\}.+otherwise",
flags=re.DOTALL,
)
with pytest.raises(TypeError, match=chain_mixed_msg):
when.then({"field": "field_1", "type": "quantitative"}).when(
select_x, field_2=99
)

with pytest.raises(TypeError, match=chain_mixed_msg):
when.then("field_1:Q").when(Genre="pop")

chain_otherwise_msg = re.compile(
r"Chained.+mixed.+field.+AggregatedFieldDef.+'this_field_here'",
flags=re.DOTALL,
)
with pytest.raises(TypeError, match=chain_otherwise_msg):
when.then(5).when(
alt.selection_point(fields=["b"]) | brush, empty=False, b=63812
).then("min(foo):Q").otherwise(
alt.AggregatedFieldDef(
"argmax", field="field_9", **{"as": "this_field_here"}
)
)


@pytest.mark.parametrize(
("channel", "then", "otherwise"),
[
("color", alt.ColorValue("red"), alt.ColorValue("blue")),
("opacity", alt.value(0.5), alt.value(1.0)),
("text", alt.TextValue("foo"), alt.value("bar")),
("color", alt.Color("col1:N"), alt.value("blue")),
("opacity", "col1:N", alt.value(0.5)),
("text", alt.value("abc"), alt.Text("Name:N")),
("size", alt.value(20), "Name:N"),
("size", "count()", alt.value(0)),
],
)
@pytest.mark.parametrize(
"when",
[
alt.selection_interval(),
alt.selection_point(),
alt.datum.Displacement > alt.value(350),
alt.selection_point(name="select", on="click"),
alt.selection_point(fields=["Horsepower"]),
],
)
@pytest.mark.parametrize("empty", [alt.Undefined, True, False])
def test_when_condition_parity(
cars, channel: str, when, empty: alt.Optional[bool], then, otherwise
):
params = [when] if isinstance(when, alt.Parameter) else ()
kwds = {"x": "Cylinders:N", "y": "Origin:N"}

input_condition = alt.condition(when, then, otherwise, empty=empty)
chart_condition = (
alt.Chart(cars)
.mark_rect()
.encode(**kwds, **{channel: input_condition})
.add_params(*params)
.to_dict()
)

input_when = alt.when(when, empty=empty).then(then).otherwise(otherwise)
chart_when = (
alt.Chart(cars)
.mark_rect()
.encode(**kwds, **{channel: input_when})
.add_params(*params)
.to_dict()
)

if isinstance(input_when["condition"], list):
input_when["condition"] = input_when["condition"][0]
assert input_condition == input_when
else:
assert chart_condition == chart_when


def test_selection_to_dict():
brush = alt.selection_interval()

Expand Down

0 comments on commit f33a9cd

Please sign in to comment.