diff --git a/plotly_express/_core.py b/plotly_express/_core.py index e5a9748..9b956e9 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -756,24 +756,27 @@ def infer_config(args, constructor, trace_patch): return trace_specs, grouped_mappings, sizeref, color_range -def make_figure(args, constructor, trace_patch={}, layout_patch={}): - apply_default_cascade(args) - trace_specs, grouped_mappings, sizeref, color_range = infer_config( - args, constructor, trace_patch - ) - grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] - grouped = args["data_frame"].groupby(grouper, sort=False) +def get_orderings(args, grouper, grouped): + """ + `orders` is the user-supplied ordering (with the remaining data-frame-supplied + ordering appended if the column is used for grouping) + `group_names` is the set of groups, ordered by the order above + """ orders = {} if "category_orders" not in args else args["category_orders"].copy() group_names = [] for group_name in grouped.groups: if len(grouper) == 1: group_name = (group_name,) group_names.append(group_name) - for col, val in zip(grouper, group_name): - if col not in orders: - orders[col] = [] - if val not in orders[col]: - orders[col].append(val) + for col in grouper: + if col != one_group: + uniques = args["data_frame"][col].unique() + if col not in orders: + orders[col] = list(uniques) + else: + for val in uniques: + if val not in orders[col]: + orders[col].append(val) for i, col in reversed(list(enumerate(grouper))): if col != one_group: @@ -782,10 +785,23 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, ) + return orders, group_names + + +def make_figure(args, constructor, trace_patch={}, layout_patch={}): + apply_default_cascade(args) + trace_specs, grouped_mappings, sizeref, color_range = infer_config( + args, constructor, trace_patch + ) + grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] + grouped = args["data_frame"].groupby(grouper, sort=False) + + orders, sorted_group_names = get_orderings(args, grouper, grouped) + trace_names_by_frame = {} frames = OrderedDict() trendline_rows = [] - for group_name in group_names: + for group_name in sorted_group_names: group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) mapping_labels = OrderedDict() trace_name_labels = OrderedDict()