Skip to content

Commit

Permalink
fix(rust, python): fix invalid access when groupby rolling produces empty sets (#10109)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 27, 2023
1 parent 4bce995 commit ef91c45
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
16 changes: 10 additions & 6 deletions polars/polars-time/src/groupby/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use polars_core::series::IsSorted;
use polars_core::utils::ensure_sorted_arg;
use polars_core::utils::flatten::flatten_par;
use polars_core::POOL;
use polars_utils::slice::SortedSlice;
use polars_utils::slice::{GetSaferUnchecked, SortedSlice};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use smartstring::alias::String as SmartString;
Expand Down Expand Up @@ -638,15 +638,19 @@ fn update_subgroups_idx(
sub_groups
.iter()
.map(|&[first, len]| {
let new_first = unsafe { *base_g.1.get_unchecked(first as usize) };
let new_first = if len == 0 {
// in case the group is empty
// keep the original first so that the
// groupby keys still point to the original group
base_g.0
} else {
unsafe { *base_g.1.get_unchecked_release(first as usize) }
};

let first = first as usize;
let len = len as usize;
let idx = (first..first + len)
.map(|i| {
debug_assert!(i < base_g.1.len());
unsafe { *base_g.1.get_unchecked(i) }
})
.map(|i| unsafe { *base_g.1.get_unchecked_release(i) })
.collect_trusted::<Vec<_>>();
(new_first, idx)
})
Expand Down
41 changes: 40 additions & 1 deletion py-polars/tests/unit/operations/test_groupby_rolling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from datetime import datetime
from datetime import date, datetime
from typing import TYPE_CHECKING, Any

import pytest
Expand Down Expand Up @@ -258,3 +258,42 @@ def test_groupby_rolling_dynamic_sortedness_check() -> None:
match=r"argument in operation 'groupby_rolling' is not explicitly sorted",
):
df.groupby_rolling("idx", period="2i").agg(pl.col("idx").alias("idx1"))


def test_groupby_rolling_empty_groups_9973() -> None:
    """Regression test: ``groupby_rolling`` must cope with empty windows.

    Three ids ("A", "B", "C") each have one row on 2001-01-01 and one on
    2001-01-02. The rolling aggregation is requested with ``period="2d"``,
    ``offset="1d"`` and ``closed="left"``; per the expected frame, the first
    row of every id collects only the value of the following day, while the
    second row of every id collects nothing and must yield an empty list
    instead of triggering an invalid access.
    """
    first_day = date(2001, 1, 1)
    second_day = date(2001, 1, 2)

    # The key columns are identical in the input and in the expected output;
    # only the "value" column differs.
    ids = ["A", "A", "B", "B", "C", "C"]
    dates = [first_day, second_day] * 3

    frame = pl.DataFrame(
        {
            "id": ids,
            "date": dates,
            "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        }
    ).sort(by=["id", "date"])

    want = pl.DataFrame(
        {
            "id": ids,
            "date": dates,
            # Every second row of an id has an empty window.
            "value": [[2.0], [], [4.0], [], [6.0], []],
        }
    )

    got = frame.groupby_rolling(
        index_column="date",
        by="id",
        period="2d",
        offset="1d",
        closed="left",
        check_sorted=True,
    ).agg(pl.col("value"))

    assert_frame_equal(got, want)

0 comments on commit ef91c45

Please sign in to comment.