diff --git a/altair/vegalite/v3/api.py b/altair/vegalite/v3/api.py index 2b372f357..ceef45ef9 100644 --- a/altair/vegalite/v3/api.py +++ b/altair/vegalite/v3/api.py @@ -1882,8 +1882,7 @@ def add_selection(self, *selections): if not selections or not self.layer: return self copy = self.copy() - copy.layer = [chart.add_selection(*selections) - for chart in copy.layer] + copy.layer[0] = copy.layer[0].add_selection(*selections) return copy diff --git a/altair/vegalite/v3/tests/test_api.py b/altair/vegalite/v3/tests/test_api.py index caaa9fc81..8400b3c6e 100644 --- a/altair/vegalite/v3/tests/test_api.py +++ b/altair/vegalite/v3/tests/test_api.py @@ -535,7 +535,15 @@ def test_facet_add_selections(): assert chart1.to_dict() == chart2.to_dict() -@pytest.mark.parametrize('charttype', [alt.layer, alt.concat, alt.hconcat, alt.vconcat]) +def test_layer_add_selection(): + base = alt.Chart('data.csv').mark_point() + selection = alt.selection_single() + chart1 = alt.layer(base.add_selection(selection), base) + chart2 = alt.layer(base, base).add_selection(selection) + assert chart1.to_dict() == chart2.to_dict() + + +@pytest.mark.parametrize('charttype', [alt.concat, alt.hconcat, alt.vconcat]) def test_compound_add_selections(charttype): base = alt.Chart('data.csv').mark_point() selection = alt.selection_single()