Skip to content
This repository has been archived by the owner on Jun 3, 2024. It is now read-only.

Commit

Permalink
fix groupby ordering when multi-grouping, should address #23
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaskruchten committed May 13, 2019
1 parent 9a64419 commit 2d327d9
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions plotly_express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,24 +756,27 @@ def infer_config(args, constructor, trace_patch):
return trace_specs, grouped_mappings, sizeref, color_range


def make_figure(args, constructor, trace_patch={}, layout_patch={}):
apply_default_cascade(args)
trace_specs, grouped_mappings, sizeref, color_range = infer_config(
args, constructor, trace_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)
def get_orderings(args, grouper, grouped):
"""
`orders` is the user-supplied ordering (with the remaining data-frame-supplied
ordering appended if the column is used for grouping)
`group_names` is the set of groups, ordered by the order above
"""
orders = {} if "category_orders" not in args else args["category_orders"].copy()
group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
group_names.append(group_name)
for col, val in zip(grouper, group_name):
if col not in orders:
orders[col] = []
if val not in orders[col]:
orders[col].append(val)
for col in grouper:
if col != one_group:
uniques = args["data_frame"][col].unique()
if col not in orders:
orders[col] = list(uniques)
else:
for val in uniques:
if val not in orders[col]:
orders[col].append(val)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
Expand All @@ -782,10 +785,23 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)

return orders, group_names


def make_figure(args, constructor, trace_patch={}, layout_patch={}):
apply_default_cascade(args)
trace_specs, grouped_mappings, sizeref, color_range = infer_config(
args, constructor, trace_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)

orders, sorted_group_names = get_orderings(args, grouper, grouped)

trace_names_by_frame = {}
frames = OrderedDict()
trendline_rows = []
for group_name in group_names:
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
mapping_labels = OrderedDict()
trace_name_labels = OrderedDict()
Expand Down

0 comments on commit 2d327d9

Please sign in to comment.