Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): fix invalid access when groupby rolling produces empty sets #10109

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions polars/polars-time/src/groupby/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use polars_core::series::IsSorted;
use polars_core::utils::ensure_sorted_arg;
use polars_core::utils::flatten::flatten_par;
use polars_core::POOL;
use polars_utils::slice::SortedSlice;
use polars_utils::slice::{GetSaferUnchecked, SortedSlice};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use smartstring::alias::String as SmartString;
Expand Down Expand Up @@ -638,15 +638,19 @@ fn update_subgroups_idx(
sub_groups
.iter()
.map(|&[first, len]| {
let new_first = unsafe { *base_g.1.get_unchecked(first as usize) };
let new_first = if len == 0 {
// in case the group is empty
// keep the original first so that the
// groupby keys still point to the original group
base_g.0
} else {
unsafe { *base_g.1.get_unchecked_release(first as usize) }
};

let first = first as usize;
let len = len as usize;
let idx = (first..first + len)
.map(|i| {
debug_assert!(i < base_g.1.len());
unsafe { *base_g.1.get_unchecked(i) }
})
.map(|i| unsafe { *base_g.1.get_unchecked_release(i) })
.collect_trusted::<Vec<_>>();
(new_first, idx)
})
Expand Down
41 changes: 40 additions & 1 deletion py-polars/tests/unit/operations/test_groupby_rolling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from datetime import datetime
from datetime import date, datetime
from typing import TYPE_CHECKING, Any

import pytest
Expand Down Expand Up @@ -258,3 +258,42 @@ def test_groupby_rolling_dynamic_sortedness_check() -> None:
match=r"argument in operation 'groupby_rolling' is not explicitly sorted",
):
df.groupby_rolling("idx", period="2i").agg(pl.col("idx").alias("idx1"))


def test_groupby_rolling_empty_groups_9973() -> None:
    """Regression test for issue 9973: rolling windows that end up empty.

    Each id has two consecutive days. With the period/offset/closed settings
    used below, the expected frame shows that the first row of every id
    collects only the second day's value, while the second row of every id
    collects nothing and must come back as an empty list.
    """
    first_day = date(2001, 1, 1)
    second_day = date(2001, 1, 2)

    # Three ids, two consecutive days per id.
    ids = ["A", "A", "B", "B", "C", "C"]
    dates = [first_day, second_day, first_day, second_day, first_day, second_day]

    frame = pl.DataFrame(
        {
            "id": ids,
            "date": dates,
            "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        }
    ).sort(by=["id", "date"])

    # Keys and dates are unchanged; only the aggregated values differ.
    # The empty lists are the windows that contain no rows.
    expected = pl.DataFrame(
        {
            "id": list(ids),
            "date": list(dates),
            "value": [[2.0], [], [4.0], [], [6.0], []],
        }
    )

    rolling = frame.groupby_rolling(
        index_column="date",
        by="id",
        period="2d",
        offset="1d",
        closed="left",
        check_sorted=True,
    )
    result = rolling.agg(pl.col("value"))

    assert_frame_equal(result, expected)
Loading