diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f34484ea6..596e041846 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Ensure scatter `mode` is deterministic from `px` [[#4429](https://github.com/plotly/plotly.py/pull/4429)] - Fix issue with creating dendrogram in subplots [[#4411](https://github.com/plotly/plotly.py/pull/4411)], - Fix issue with px.line not accepting "spline" line shape [[#2812](https://github.com/plotly/plotly.py/issues/2812)] +- Fix KeyError when using column of `pd.Categorical` dtype with unobserved categories [[#4437](https://github.com/plotly/plotly.py/pull/4437)] ## [5.18.0] - 2023-10-25 diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b80355c727..037123a37e 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2042,7 +2042,9 @@ def get_groups_and_orders(args, grouper): groups = {tuple(single_group_name): df} else: required_grouper = [g for g in grouper if g != one_group] - grouped = df.groupby(required_grouper, sort=False) # skip one_group groupers + grouped = df.groupby( + required_grouper, sort=False, observed=True + ) # skip one_group groupers group_indices = grouped.indices sorted_group_names = [ g if len(required_grouper) != 1 else (g,) for g in group_indices diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_colors.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_colors.py index 85e1dc00bf..c57b1b3898 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_colors.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_colors.py @@ -53,3 +53,11 @@ def test_r_colorscales(): assert scale.replace("_r", "") in scale_names else: assert scale + "_r" in scale_names + + +def test_color_categorical_dtype(): + df = px.data.tips() + df["day"] = df["day"].astype("category") + px.scatter( + df[df.day != df.day.cat.categories[0]], x="total_bill", y="tip", color="day" + )