Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Extrapolate_flat parameter in interpolate_by #18355

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 147 additions & 55 deletions crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ fn interpolate_impl_by_sorted<T, F, I>(
chunked_arr: &ChunkedArray<T>,
by: &ChunkedArray<F>,
interpolation_branch: I,
extrapolate_flat: bool,
) -> PolarsResult<ChunkedArray<T>>
where
T: PolarsNumericType,
Expand All @@ -104,11 +105,19 @@ where
let first = chunked_arr.first_non_null().unwrap();
let last = chunked_arr.last_non_null().unwrap() + 1;

// Get the first and last non-null values; these are the fill values used when extrapolating flat
let lowest_value = chunked_arr.get(first).unwrap();
let highest_value = chunked_arr.get(last - 1).unwrap();

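// Fill the leading `first` slots: with the first non-null value when extrapolating flat,
// otherwise with zero placeholders that are marked as null in the validity mask later.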
let mut out = Vec::with_capacity(chunked_arr.len());
let mut iter = chunked_arr.iter().enumerate().skip(first);
for _ in 0..first {
out.push(Zero::zero());
if extrapolate_flat {
out.push(lowest_value)
} else {
out.push(Zero::zero());
}
}

// The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first
Expand Down Expand Up @@ -141,13 +150,20 @@ where
let mut validity = MutableBitmap::with_capacity(chunked_arr.len());
validity.extend_constant(chunked_arr.len(), true);

for i in 0..first {
unsafe { validity.set_unchecked(i, false) };
// If we're not extrapolating, mark the entries before the first non-null value as null
if !extrapolate_flat {
for i in 0..first {
unsafe { validity.set_unchecked(i, false) };
}
}

for i in last..chunked_arr.len() {
unsafe { validity.set_unchecked(i, false) }
out.push(Zero::zero());
if extrapolate_flat {
out.push(highest_value);
} else {
unsafe { validity.set_unchecked(i, false) }
out.push(Zero::zero());
}
}

let array = PrimitiveArray::new(
Expand All @@ -166,6 +182,7 @@ fn interpolate_impl_by<T, F, I>(
ca: &ChunkedArray<T>,
by: &ChunkedArray<F>,
interpolation_branch: I,
extrapolate_flat: bool,
) -> PolarsResult<ChunkedArray<T>>
where
T: PolarsNumericType,
Expand Down Expand Up @@ -193,6 +210,10 @@ where
let first = ca_sorted.first_non_null().unwrap();
let last = ca_sorted.last_non_null().unwrap() + 1;

// Get the first and last non-null values of `ca_sorted`; these are the fill values used when extrapolating flat
let lowest_value = ca_sorted.get(first).unwrap();
let highest_value = ca_sorted.get(last - 1).unwrap();

let mut out = zeroed_vec(ca_sorted.len());
let mut iter = ca_sorted.iter().enumerate().skip(first);

Expand Down Expand Up @@ -241,14 +262,22 @@ where
for i in 0..first {
unsafe {
let out_idx = sorting_indices.get_unchecked(i);
validity.set_unchecked(*out_idx as usize, false);
if extrapolate_flat {
*out.get_unchecked_mut(*out_idx as usize) = lowest_value;
} else {
validity.set_unchecked(*out_idx as usize, false);
}
}
}

for i in last..ca_sorted.len() {
unsafe {
let out_idx = sorting_indices.get_unchecked(i);
validity.set_unchecked(*out_idx as usize, false);
if extrapolate_flat {
*out.get_unchecked_mut(*out_idx as usize) = highest_value;
} else {
validity.set_unchecked(*out_idx as usize, false);
}
}
}

Expand All @@ -263,77 +292,140 @@ where
}
}

pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResult<Series> {
pub fn interpolate_by(
s: &Series,
by: &Series,
by_is_sorted: bool,
extrapolate_flat: bool,
) -> PolarsResult<Series> {
polars_ensure!(s.len() == by.len(), InvalidOperation: "`by` column must be the same length as Series ({}), got {}", s.len(), by.len());

fn func<T, F>(
ca: &ChunkedArray<T>,
by: &ChunkedArray<F>,
is_sorted: bool,
extrapolate_flat: bool,
) -> PolarsResult<Series>
where
T: PolarsNumericType,
F: PolarsNumericType,
ChunkedArray<T>: IntoSeries,
{
if is_sorted {
interpolate_impl_by_sorted(ca, by, |y_start, y_end, x, out| unsafe {
signed_interp_by_sorted(y_start, y_end, x, out)
})
interpolate_impl_by_sorted(
ca,
by,
|y_start, y_end, x, out| unsafe { signed_interp_by_sorted(y_start, y_end, x, out) },
extrapolate_flat,
)
.map(|x| x.into_series())
} else {
interpolate_impl_by(ca, by, |y_start, y_end, x, out, sorting_indices| unsafe {
signed_interp_by(y_start, y_end, x, out, sorting_indices)
})
interpolate_impl_by(
ca,
by,
|y_start, y_end, x, out, sorting_indices| unsafe {
signed_interp_by(y_start, y_end, x, out, sorting_indices)
},
extrapolate_flat,
)
.map(|x| x.into_series())
}
}

match (s.dtype(), by.dtype()) {
(DataType::Float64, DataType::Float64) => {
func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Float32) => {
func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float64) => {
func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float32) => {
func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Int64) => {
func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Int32) => {
func(s.f64().unwrap(), by.i32().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::UInt64) => {
func(s.f64().unwrap(), by.u64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::UInt32) => {
func(s.f64().unwrap(), by.u32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Int64) => {
func(s.f32().unwrap(), by.i64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Int32) => {
func(s.f32().unwrap(), by.i32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::UInt64) => {
func(s.f32().unwrap(), by.u64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::UInt32) => {
func(s.f32().unwrap(), by.u32().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Float64) => func(
s.f64().unwrap(),
by.f64().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float64, DataType::Float32) => func(
s.f64().unwrap(),
by.f32().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float32, DataType::Float64) => func(
s.f32().unwrap(),
by.f64().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float32, DataType::Float32) => func(
s.f32().unwrap(),
by.f32().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float64, DataType::Int64) => func(
s.f64().unwrap(),
by.i64().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float64, DataType::Int32) => func(
s.f64().unwrap(),
by.i32().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float64, DataType::UInt64) => func(
s.f64().unwrap(),
by.u64().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float64, DataType::UInt32) => func(
s.f64().unwrap(),
by.u32().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float32, DataType::Int64) => func(
s.f32().unwrap(),
by.i64().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float32, DataType::Int32) => func(
s.f32().unwrap(),
by.i32().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float32, DataType::UInt64) => func(
s.f32().unwrap(),
by.u64().unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::Float32, DataType::UInt32) => func(
s.f32().unwrap(),
by.u32().unwrap(),
by_is_sorted,
extrapolate_flat,
),
#[cfg(feature = "dtype-date")]
(_, DataType::Date) => interpolate_by(s, &by.cast(&DataType::Int32).unwrap(), by_is_sorted),
(_, DataType::Date) => interpolate_by(
s,
&by.cast(&DataType::Int32).unwrap(),
by_is_sorted,
extrapolate_flat,
),
#[cfg(feature = "dtype-datetime")]
(_, DataType::Datetime(_, _)) => {
interpolate_by(s, &by.cast(&DataType::Int64).unwrap(), by_is_sorted)
},
(_, DataType::Datetime(_, _)) => interpolate_by(
s,
&by.cast(&DataType::Int64).unwrap(),
by_is_sorted,
extrapolate_flat,
),
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
interpolate_by(&s.cast(&DataType::Float64).unwrap(), by, by_is_sorted)
interpolate_by(
&s.cast(&DataType::Float64).unwrap(),
by,
by_is_sorted,
extrapolate_flat,
)
},
_ => {
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ pub(super) fn interpolate(s: &Series, method: InterpolationMethod) -> PolarsResu
}

#[cfg(feature = "interpolate_by")]
pub(super) fn interpolate_by(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn interpolate_by(s: &[Series], extrapolate_flat: bool) -> PolarsResult<Series> {
let by = &s[1];
let by_is_sorted = by.is_sorted(Default::default())?;
polars_ops::prelude::interpolate_by(&s[0], by, by_is_sorted)
polars_ops::prelude::interpolate_by(&s[0], by, by_is_sorted, extrapolate_flat)
}

pub(super) fn to_physical(s: &Series) -> PolarsResult<Series> {
Expand Down
13 changes: 8 additions & 5 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ pub enum FunctionExpr {
#[cfg(feature = "interpolate")]
Interpolate(InterpolationMethod),
#[cfg(feature = "interpolate_by")]
InterpolateBy,
InterpolateBy {
extrapolate_flat: bool,
},
#[cfg(feature = "log")]
Entropy {
base: f64,
Expand Down Expand Up @@ -395,7 +397,7 @@ impl Hash for FunctionExpr {
#[cfg(feature = "interpolate")]
Interpolate(f) => f.hash(state),
#[cfg(feature = "interpolate_by")]
InterpolateBy => {},
InterpolateBy { extrapolate_flat } => extrapolate_flat.hash(state),
#[cfg(feature = "ffi_plugin")]
FfiPlugin {
lib,
Expand Down Expand Up @@ -687,7 +689,7 @@ impl Display for FunctionExpr {
#[cfg(feature = "interpolate")]
Interpolate(_) => "interpolate",
#[cfg(feature = "interpolate_by")]
InterpolateBy => "interpolate_by",
InterpolateBy { .. } => "interpolate_by",
#[cfg(feature = "log")]
Entropy { .. } => "entropy",
#[cfg(feature = "log")]
Expand Down Expand Up @@ -1030,8 +1032,9 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
map!(dispatch::interpolate, method)
},
#[cfg(feature = "interpolate_by")]
InterpolateBy => {
map_as_slice!(dispatch::interpolate_by)
InterpolateBy { extrapolate_flat } => {
map_as_slice!(dispatch::interpolate_by, extrapolate_flat)
//map_as_slice!(dispatch::interpolate_by, extrapolate_flat)
Comment on lines +1036 to +1037
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commented out

},
#[cfg(feature = "log")]
Entropy { base, normalize } => map!(log::entropy, base, normalize),
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ impl FunctionExpr {
InterpolationMethod::Nearest => mapper.with_same_dtype(),
},
#[cfg(feature = "interpolate_by")]
InterpolateBy => mapper.map_numeric_to_float_dtype(),
InterpolateBy {
extrapolate_flat: _,
} => mapper.map_numeric_to_float_dtype(),
ShrinkType => {
// we return the smallest type this can return
// this might not be correct once the actual data
Expand Down
9 changes: 7 additions & 2 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1299,8 +1299,13 @@ impl Expr {

#[cfg(feature = "interpolate_by")]
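/// Fill null values using interpolation based on the `by` column.
///
/// If `extrapolate_flat` is true, nulls before the first and after the last non-null value
/// are filled with the first and last non-null value respectively instead of remaining null.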
pub fn interpolate_by(self, by: Expr) -> Expr {
self.apply_many_private(FunctionExpr::InterpolateBy, &[by], false, false)
pub fn interpolate_by(self, by: Expr, extrapolate_flat: bool) -> Expr {
self.apply_many_private(
FunctionExpr::InterpolateBy { extrapolate_flat },
&[by],
false,
false,
)
}

#[cfg(feature = "rolling_window")]
Expand Down
7 changes: 5 additions & 2 deletions crates/polars-python/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,8 +734,11 @@ impl PyExpr {
fn interpolate(&self, method: Wrap<InterpolationMethod>) -> Self {
self.inner.clone().interpolate(method.0).into()
}
fn interpolate_by(&self, by: PyExpr) -> Self {
self.inner.clone().interpolate_by(by.inner).into()
fn interpolate_by(&self, by: PyExpr, extrapolate_flat: bool) -> Self {
self.inner
.clone()
.interpolate_by(by.inner, extrapolate_flat)
.into()
}

fn lower_bound(&self) -> Self {
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
},
)
.to_object(py),
FunctionExpr::InterpolateBy => ("interpolate_by",).to_object(py),
FunctionExpr::InterpolateBy { extrapolate_flat } => {
("interpolate_by", extrapolate_flat).to_object(py)
},
FunctionExpr::Entropy { base, normalize } => {
("entropy", base, normalize).to_object(py)
},
Expand Down
Loading