From fd881fd50c3515a594b70a06a0cb8f8fd7554925 Mon Sep 17 00:00:00 2001 From: ritchie Date: Thu, 27 Jul 2023 09:46:09 +0200 Subject: [PATCH] fix(rust, python): fix invalid access when groupby rolling produces empty sets --- polars/polars-time/src/groupby/dynamic.rs | 16 +++++--- .../unit/operations/test_groupby_rolling.py | 41 ++++++++++++++++++- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/polars/polars-time/src/groupby/dynamic.rs b/polars/polars-time/src/groupby/dynamic.rs index b8592c0cff21..27ad69450446 100644 --- a/polars/polars-time/src/groupby/dynamic.rs +++ b/polars/polars-time/src/groupby/dynamic.rs @@ -7,7 +7,7 @@ use polars_core::series::IsSorted; use polars_core::utils::ensure_sorted_arg; use polars_core::utils::flatten::flatten_par; use polars_core::POOL; -use polars_utils::slice::SortedSlice; +use polars_utils::slice::{GetSaferUnchecked, SortedSlice}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; @@ -638,15 +638,19 @@ fn update_subgroups_idx( sub_groups .iter() .map(|&[first, len]| { - let new_first = unsafe { *base_g.1.get_unchecked(first as usize) }; + let new_first = if len == 0 { + // in case the group is empty + // keep the original first so that the + // groupby keys still point to the original group + base_g.0 + } else { + unsafe { *base_g.1.get_unchecked_release(first as usize) } + }; let first = first as usize; let len = len as usize; let idx = (first..first + len) - .map(|i| { - debug_assert!(i < base_g.1.len()); - unsafe { *base_g.1.get_unchecked(i) } - }) + .map(|i| unsafe { *base_g.1.get_unchecked_release(i) }) .collect_trusted::>(); (new_first, idx) }) diff --git a/py-polars/tests/unit/operations/test_groupby_rolling.py b/py-polars/tests/unit/operations/test_groupby_rolling.py index f628ec27cfcd..d06bdbed75ec 100644 --- a/py-polars/tests/unit/operations/test_groupby_rolling.py +++ b/py-polars/tests/unit/operations/test_groupby_rolling.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import date, datetime from typing import TYPE_CHECKING, Any import pytest @@ -258,3 +258,42 @@ def test_groupby_rolling_dynamic_sortedness_check() -> None: match=r"argument in operation 'groupby_rolling' is not explicitly sorted", ): df.groupby_rolling("idx", period="2i").agg(pl.col("idx").alias("idx1")) + + +def test_groupby_rolling_empty_groups_9973() -> None: + dt1 = date(2001, 1, 1) + dt2 = date(2001, 1, 2) + + data = pl.DataFrame( + { + "id": ["A", "A", "B", "B", "C", "C"], + "date": [dt1, dt2, dt1, dt2, dt1, dt2], + "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + } + ).sort(by=["id", "date"]) + + expected = pl.DataFrame( + { + "id": ["A", "A", "B", "B", "C", "C"], + "date": [ + date(2001, 1, 1), + date(2001, 1, 2), + date(2001, 1, 1), + date(2001, 1, 2), + date(2001, 1, 1), + date(2001, 1, 2), + ], + "value": [[2.0], [], [4.0], [], [6.0], []], + } + ) + + out = data.groupby_rolling( + index_column="date", + by="id", + period="2d", + offset="1d", + closed="left", + check_sorted=True, + ).agg(pl.col("value")) + + assert_frame_equal(out, expected)