Skip to content

Commit

Permalink
fix: pivot was producing incorrect results when (single) index was Struct (#14308)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Feb 12, 2024
1 parent 451f293 commit 649c33a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
50 changes: 50 additions & 0 deletions crates/polars-ops/src/frame/pivot/positioning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,51 @@ where
(row_locations, idx as usize, row_index)
}

/// Assigns every row of `index_agg_physical` to a group keyed by its binary value.
///
/// Rows are visited in order; the first time a value is seen it receives the next
/// free group number (starting at 0) and its row number is remembered.
///
/// NOTE(review): the visible caller passes the `rows_encode()` output of the struct
/// index column as `index_agg_physical` and the un-encoded column as `index_agg`;
/// the two are assumed to have the same length and row order.
///
/// Returns a tuple of:
/// * the group number of each row, in row order,
/// * the number of distinct groups found,
/// * when `count == 0`, `Some` holding a single series: `index_agg` gathered at the
///   first row of each group and renamed to `index[0]`; otherwise `None`.
fn compute_row_index_struct(
    index: &[String],
    index_agg: &Series,
    index_agg_physical: &BinaryOffsetChunked,
    count: usize,
) -> (Vec<IdxSize>, usize, Option<Vec<Series>>) {
    let n_rows = index_agg_physical.len();

    let mut group_of_key =
        PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default());
    let mut row_locations = Vec::with_capacity(n_rows);
    let mut first_rows = Vec::with_capacity(n_rows);
    let mut n_groups: IdxSize = 0;

    // Walk all chunks as one stream of values, pairing each value with its
    // global row number.
    let first_row_number: IdxSize = 0;
    let keys = index_agg_physical
        .downcast_iter()
        .flat_map(|arr| arr.iter());
    for (key, row_number) in keys.zip(first_row_number..) {
        let group = *group_of_key.entry(key).or_insert_with(|| {
            // SAFETY: we pre-allocated `n_rows` slots and push at most once per row.
            unsafe { first_rows.push_unchecked(row_number) };
            let assigned = n_groups;
            n_groups += 1;
            assigned
        });

        // SAFETY: we pre-allocated `n_rows` slots and push exactly once per row.
        unsafe { row_locations.push_unchecked(group) };
    }

    // The index column is only materialized when `count` is 0.
    let row_index = (count == 0).then(|| {
        // SAFETY: `first_rows` is filled with elements between
        // 0 and `index_agg.len() - 1`.
        let mut s = unsafe { index_agg.take_slice_unchecked(&first_rows) };
        s.rename(&index[0]);
        vec![s]
    });

    (row_locations, n_groups as usize, row_index)
}

// TODO! Also create a specialized version for numerics.
pub(super) fn compute_row_idx(
pivot_df: &DataFrame,
Expand All @@ -353,6 +398,11 @@ pub(super) fn compute_row_idx(
let ca = index_agg_physical.bool().unwrap();
compute_row_index(index, ca, count, index_s.dtype())
},
Struct(_) => {
let ca = index_agg_physical.struct_().unwrap();
let ca = ca.rows_encode()?;
compute_row_index_struct(index, &index_agg, &ca, count)
},
String => {
let ca = index_agg_physical.str().unwrap();
compute_row_index(index, ca, count, index_s.dtype())
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/operations/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,20 @@ def test_pivot_struct_13120() -> None:
assert result == expected


def test_pivot_index_struct_14101() -> None:
    """Pivot with a single Struct column as index yields one row per distinct struct.

    NOTE(review): the numeric suffix presumably refers to issue 14101 - confirm.
    """
    frame = pl.DataFrame(
        {
            "a": [1, 2, 1],
            "b": [{"a": 1}, {"a": 1}, {"a": 2}],
            "c": ["x", "y", "y"],
            "d": [1, 1, 3],
        }
    )

    # Column "d" is deliberately left out of the pivot arguments.
    pivoted = frame.pivot(index="b", values="a", columns="c")

    assert_frame_equal(
        pivoted,
        pl.DataFrame(
            {
                "b": [{"a": 1}, {"a": 2}],
                "x": [1, None],
                "y": [2, 1],
            }
        ),
    )


def test_pivot_name_already_exists() -> None:
# This should be extremely rare...but still, good to check it
df = pl.DataFrame(
Expand Down

0 comments on commit 649c33a

Please sign in to comment.