Skip to content

Commit

Permalink
[FEAT] Implement str.substr Expression (#2269)
Browse files Browse the repository at this point in the history
Adds the `Expression.str.substr` implementation.
Resolves issue: #1934
  • Loading branch information
danila-b authored Jun 3, 2024
1 parent 55b0bc4 commit 9a9e52e
Show file tree
Hide file tree
Showing 13 changed files with 429 additions and 5 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ class PyExpr:
def utf8_repeat(self, n: PyExpr) -> PyExpr: ...
def utf8_like(self, pattern: PyExpr) -> PyExpr: ...
def utf8_ilike(self, pattern: PyExpr) -> PyExpr: ...
# Native binding behind `Expression.str.substr`: both `start` and `length` are passed as expressions.
# NOTE(review): the Python wrapper converts an omitted `length` into an expression via
# `Expression._to_expression(None)` before calling this — presumably a null literal; confirm.
def utf8_substr(self, start: PyExpr, length: PyExpr) -> PyExpr: ...
def image_decode(self, raise_error_on_failure: bool) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
Expand Down Expand Up @@ -1161,6 +1162,7 @@ class PySeries:
def utf8_repeat(self, n: PySeries) -> PySeries: ...
def utf8_like(self, pattern: PySeries) -> PySeries: ...
def utf8_ilike(self, pattern: PySeries) -> PySeries: ...
# `length` is required by the native binding (it is declared as `length: &Self` on the Rust side);
# callers that want "until the end of the string" pass a null-typed Series instead of omitting it.
def utf8_substr(self, start: PySeries, length: PySeries) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
Expand Down
16 changes: 16 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,22 @@ def ilike(self, pattern: str | Expression) -> Expression:
pattern_expr = Expression._to_expression(pattern)
return Expression._from_pyexpr(self._expr.utf8_ilike(pattern_expr._expr))

def substr(self, start: int | Expression, length: int | Expression | None = None) -> Expression:
    """Extract a substring from each string, beginning at ``start`` and spanning ``length``.

    .. NOTE::
        When ``length`` is omitted, the substring runs from ``start`` to the end of the string.

    Example:
        >>> col("x").str.substr(2, 2)

    Args:
        start: Index at which the substring begins; an ``int`` or an Expression.
        length: Optional length of the substring; an ``int`` or an Expression.

    Returns:
        Expression: A String expression representing the extracted substring.
    """
    # Both arguments go through the same literal-or-expression conversion, `start` first.
    operands = [Expression._to_expression(arg) for arg in (start, length)]
    substr_pyexpr = self._expr.utf8_substr(operands[0]._expr, operands[1]._expr)
    return Expression._from_pyexpr(substr_pyexpr)


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
11 changes: 11 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,17 @@ def ilike(self, pattern: Series) -> Series:
assert self._series is not None and pattern._series is not None
return Series._from_pyseries(self._series.utf8_ilike(pattern._series))

def substr(self, start: Series, length: Series | None = None) -> Series:
    """Extract substrings from this Series, beginning at ``start`` and spanning ``length``.

    Args:
        start: Series holding the start index for each substring.
        length: Optional Series holding the substring lengths. When omitted, a
            single-element null Series is passed to the native call in its place.

    Returns:
        Series: the extracted substrings.

    Raises:
        ValueError: if ``start`` or a provided ``length`` is not a Series.
    """
    if not isinstance(start, Series):
        raise ValueError(f"expected another Series but got {type(start)}")

    if length is None:
        # Placeholder standing in for "no length given".
        length = Series.from_arrow(pa.array([None]))
    elif not isinstance(length, Series):
        raise ValueError(f"expected another Series but got {type(length)}")

    assert self._series is not None and start._series is not None
    native_result = self._series.utf8_substr(start._series, length._series)
    return Series._from_pyseries(native_result)


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.repeat
Expression.str.like
Expression.str.ilike
Expression.str.substr

.. _api-expressions-temporal:

Expand Down
191 changes: 189 additions & 2 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
borrow::Cow,
iter::{Repeat, Take},
iter::{self, Repeat, Take},
};

use crate::{
Expand All @@ -12,7 +12,6 @@ use crate::{
DataType, Series,
};
use arrow2::array::Array;

use common_error::{DaftError, DaftResult};
use itertools::Itertools;
use num_traits::NumCast;
Expand Down Expand Up @@ -271,6 +270,75 @@ fn replace_on_literal<'a>(
Ok(Utf8Array::from((name, Box::new(arrow_result?))))
}

/// Returns the slice of `s` that begins at the `start`-th character (zero-based) and spans
/// `len` characters, or runs to the end of `s` when `len` is `None`.
///
/// Positions are counted in `char`s, not bytes, so the returned slice always falls on
/// UTF-8 character boundaries.
///
/// Returns `None` when `start` is at or beyond the number of characters in `s`, or when
/// `len` is `Some(0)`. A `len` that reaches past the end of `s` is clamped to the end.
fn substring(s: &str, start: usize, len: Option<usize>) -> Option<&str> {
    let mut positions = s.char_indices();

    // Advance to the first requested character; a too-short string yields `None`.
    let (begin, _) = positions.nth(start)?;

    let end = match len {
        None => s.len(),
        Some(0) => return None,
        // `positions` now sits just past the start character, so skipping `count - 1`
        // further characters lands on the first character that must be excluded.
        Some(count) => positions.nth(count - 1).map_or(s.len(), |(idx, _)| idx),
    };

    Some(&s[begin..end])
}

/// Zips the string iterator with per-row `start` and `length` values and builds the
/// resulting `Utf8Array` named `name`.
///
/// Each `start` / `length` item is a `Result` so that a failed integer-to-`usize`
/// conversion can be reported lazily, on the row where it occurs.
///
/// A row is null when the string or its `start` is null, or when `substring` returns
/// `None`. A null `length` means "to the end of the string".
///
/// # Errors
/// Returns `DaftError::ComputeError` if any `start` or `length` item is an `Err`.
fn substr_compute_result<I, U, E, R>(
    name: &str,
    iter: BroadcastedStrIter,
    start: I,
    length: U,
) -> DaftResult<Utf8Array>
where
    I: Iterator<Item = Result<Option<usize>, E>>,
    U: Iterator<Item = Result<Option<usize>, R>>,
{
    let arrow_result = iter
        .zip(start)
        .zip(length)
        .map(|((val, s), l)| {
            // `E` and `R` are unconstrained, so the underlying error cannot be displayed;
            // report which operand failed instead.
            let s = s.map_err(|_| {
                DaftError::ComputeError(
                    "Error in substr: failed to cast start as usize".to_string(),
                )
            })?;
            let l = l.map_err(|_| {
                DaftError::ComputeError(
                    "Error in substr: failed to cast length as usize".to_string(),
                )
            })?;

            match (val, s) {
                (Some(val), Some(s)) => Ok(substring(val, s, l)),
                _ => Ok(None),
            }
        })
        .collect::<DaftResult<arrow2::array::Utf8Array<i64>>>()?;

    Ok(Utf8Array::from((name, Box::new(arrow_result))))
}

#[derive(Debug, Clone, Copy)]
pub enum PadPlacement {
Left,
Expand Down Expand Up @@ -879,6 +947,125 @@ impl Utf8Array {
Ok(result)
}

/// Extracts a substring from every string in this array.
///
/// `start` gives the zero-based character index at which each substring begins and
/// `length` the number of characters to take; both are counted in characters by the
/// `substring` helper. When `length` is `None`, or a given length value is null, the
/// substring runs to the end of the string.
///
/// `start` and `length` may each hold a single value, which is repeated for every row,
/// or one value per output row.
///
/// # Errors
/// * `DaftError::ValueError` if `parse_inputs` rejects `self` / `start`, or if `length`
///   has neither one element nor `expected_size` elements.
/// * `DaftError::ComputeError` if a `start` or `length` value cannot be cast to `usize`.
pub fn substr<I, J>(
    &self,
    start: &DataArray<I>,
    length: Option<&DataArray<J>>,
) -> DaftResult<Utf8Array>
where
    I: DaftIntegerType,
    <I as DaftNumericType>::Native: Ord,
    J: DaftIntegerType,
    <J as DaftNumericType>::Native: Ord,
{
    let name = self.name();
    // Only `start` goes through `parse_inputs`; `length` has a different element type
    // and is validated by hand below.
    let (is_full_null, expected_size) = parse_inputs(self, &[start])
        .map_err(|e| DaftError::ValueError(format!("Error in substr: {e}")))?;

    if is_full_null {
        return Ok(Utf8Array::full_null(name, &DataType::Utf8, expected_size));
    }

    let self_iter = create_broadcasted_str_iter(self, expected_size);

    // Exactly one of (`length_repeat`, `length_iter`) is `Some`: a single broadcast value
    // or a per-row iterator. Their concrete iterator types differ, hence the pair.
    let (length_repeat, length_iter) = match length {
        Some(length) => {
            if length.len() != 1 && length.len() != expected_size {
                return Err(DaftError::ValueError(
                    "Inputs have invalid lengths: length".to_string(),
                ));
            }

            match length.len() {
                1 => {
                    let length_repeat: Result<Option<usize>, ()> = if length.null_count() == 1 {
                        Ok(None)
                    } else {
                        // The single element was just shown to be non-null.
                        let val = length.get(0).unwrap();
                        let val: usize = NumCast::from(val).ok_or_else(|| {
                            DaftError::ComputeError(format!(
                                "Error in substr: failed to cast length as usize {val}"
                            ))
                        })?;

                        Ok(Some(val))
                    };

                    let length_repeat = iter::repeat(length_repeat).take(expected_size);
                    (Some(length_repeat), None)
                }
                _ => {
                    let length_iter = length.as_arrow().iter().map(|l| match l {
                        Some(l) => {
                            let l: usize = NumCast::from(*l).ok_or_else(|| {
                                DaftError::ComputeError(format!(
                                    "Error in substr: failed to cast length as usize {l}"
                                ))
                            })?;
                            let result: Result<Option<usize>, DaftError> = Ok(Some(l));
                            result
                        }
                        None => Ok(None),
                    });
                    (None, Some(length_iter))
                }
            }
        }
        None => {
            // No length supplied: every row takes the rest of the string.
            let none_value_iter = iter::repeat(Ok(None)).take(expected_size);
            (Some(none_value_iter), None)
        }
    };

    // Same broadcast-or-per-row split for `start`.
    let (start_repeat, start_iter) = match start.len() {
        1 => {
            // NOTE(review): relies on `parse_inputs` having reported `is_full_null` for a
            // single null `start`, otherwise this `unwrap` would panic — confirm.
            let start_repeat = start.get(0).unwrap();
            let start_repeat: usize = NumCast::from(start_repeat).ok_or_else(|| {
                DaftError::ComputeError(format!(
                    "Error in substr: failed to cast start as usize {start_repeat}"
                ))
            })?;
            let start_repeat: Result<Option<usize>, ()> = Ok(Some(start_repeat));
            let start_repeat = iter::repeat(start_repeat).take(expected_size);
            (Some(start_repeat), None)
        }
        _ => {
            let start_iter = start.as_arrow().iter().map(|s| match s {
                Some(s) => {
                    let s: usize = NumCast::from(*s).ok_or_else(|| {
                        DaftError::ComputeError(format!(
                            "Error in substr: failed to cast start as usize {s}"
                        ))
                    })?;
                    let result: Result<Option<usize>, DaftError> = Ok(Some(s));
                    result
                }
                None => Ok(None),
            });
            (None, Some(start_iter))
        }
    };

    // Dispatch on which representation each operand ended up with.
    match (start_iter, start_repeat, length_iter, length_repeat) {
        (Some(start_iter), None, Some(length_iter), None) => {
            substr_compute_result(name, self_iter, start_iter, length_iter)
        }
        (Some(start_iter), None, None, Some(length_repeat)) => {
            substr_compute_result(name, self_iter, start_iter, length_repeat)
        }
        (None, Some(start_repeat), Some(length_iter), None) => {
            substr_compute_result(name, self_iter, start_repeat, length_iter)
        }
        (None, Some(start_repeat), None, Some(length_repeat)) => {
            substr_compute_result(name, self_iter, start_repeat, length_repeat)
        }

        _ => Err(DaftError::ComputeError(
            "Start and length parameters are empty".to_string(),
        )),
    }
}

pub fn pad<I>(
&self,
length: &DataArray<I>,
Expand Down
7 changes: 7 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,13 @@ impl PySeries {
Ok(self.series.utf8_ilike(&pattern.series)?.into())
}

/// Python-facing wrapper around `Series::utf8_substr`.
///
/// Forwards the wrapped `start` and `length` series and converts the resulting series
/// back into a `PySeries`; any error is propagated through `?` into the `PyResult`.
pub fn utf8_substr(&self, start: &Self, length: &Self) -> PyResult<Self> {
    let substrings = self.series.utf8_substr(&start.series, &length.series)?;
    Ok(substrings.into())
}

/// Python-facing wrapper around `Series::is_nan`; the result is converted back into a
/// `PySeries` and any error is propagated through `?` into the `PyResult`.
pub fn is_nan(&self) -> PyResult<Self> {
    let result = self.series.is_nan()?;
    Ok(result.into())
}
Expand Down
33 changes: 30 additions & 3 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::array::ops::PadPlacement;
use crate::series::Series;
use common_error::{DaftError, DaftResult};

use crate::series::array_impl::IntoSeries;
use crate::series::Series;
use crate::{datatypes::*, with_match_integer_daft_types};
use common_error::{DaftError, DaftResult};

impl Series {
fn with_utf8_array(&self, f: impl Fn(&Utf8Array) -> DaftResult<Series>) -> DaftResult<Series> {
Expand Down Expand Up @@ -212,4 +211,32 @@ impl Series {
pattern.with_utf8_array(|pattern_arr| Ok(arr.ilike(pattern_arr)?.into_series()))
})
}

/// Extracts substrings from this Utf8 series, selecting the concrete integer array
/// types of `start` and `length` before delegating to `Utf8Array::substr`.
///
/// * `start` must have an integer or null data type.
/// * `length` must have an integer or null data type; a null-typed `length` is
///   forwarded as `None`.
///
/// # Errors
/// Returns `DaftError::TypeError` when `start` or `length` has any other data type, and
/// propagates any error from the downcasts or from `Utf8Array::substr`.
pub fn utf8_substr(&self, start: &Series, length: &Series) -> DaftResult<Series> {
self.with_utf8_array(|arr| {
if start.data_type().is_integer() {
// Outer dispatch: bind `$T` to the type matching `start`'s data type.
with_match_integer_daft_types!(start.data_type(), |$T| {
if length.data_type().is_integer() {
// Inner dispatch: bind `$U` to the type matching `length`'s data type.
with_match_integer_daft_types!(length.data_type(), |$U| {
Ok(arr.substr(start.downcast::<<$T as DaftDataType>::ArrayType>()?, Some(length.downcast::<<$U as DaftDataType>::ArrayType>()?))?.into_series())
})
} else if length.data_type().is_null() {
// No usable length: pass `None`. `Int8Type` only pins the otherwise
// unconstrained generic parameter of `substr`; no length array is read.
Ok(arr.substr(start.downcast::<<$T as DaftDataType>::ArrayType>()?, None::<&DataArray<Int8Type>>)?.into_series())
} else {
Err(DaftError::TypeError(format!(
"Substr not implemented for length type {}",
length.data_type()
)))
}
})
} else if start.data_type().is_null() {
// NOTE(review): a null-typed `start` returns the input series unchanged rather than
// an all-null result — confirm this is the intended semantics.
Ok(self.clone())
} else {
Err(DaftError::TypeError(format!(
"Substr not implemented for start type {}",
start.data_type()
)))
}
})
}
}
12 changes: 12 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod rpad;
mod rstrip;
mod split;
mod startswith;
mod substr;
mod upper;

use capitalize::CapitalizeEvaluator;
Expand All @@ -44,6 +45,7 @@ use rstrip::RstripEvaluator;
use serde::{Deserialize, Serialize};
use split::SplitEvaluator;
use startswith::StartswithEvaluator;
use substr::SubstrEvaluator;
use upper::UpperEvaluator;

use crate::{functions::utf8::match_::MatchEvaluator, Expr, ExprRef};
Expand Down Expand Up @@ -75,6 +77,7 @@ pub enum Utf8Expr {
Repeat,
Like,
Ilike,
Substr,
}

impl Utf8Expr {
Expand Down Expand Up @@ -105,6 +108,7 @@ impl Utf8Expr {
Repeat => &RepeatEvaluator {},
Like => &LikeEvaluator {},
Ilike => &IlikeEvaluator {},
Substr => &SubstrEvaluator {},
}
}
}
Expand Down Expand Up @@ -292,3 +296,11 @@ pub fn ilike(data: ExprRef, pattern: ExprRef) -> ExprRef {
}
.into()
}

/// Builds the `Utf8Expr::Substr` function expression.
///
/// The inputs are stored in the order `[data, start, length]`.
pub fn substr(data: ExprRef, start: ExprRef, length: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::Utf8(Utf8Expr::Substr);
    let inputs = vec![data, start, length];
    Expr::Function { func, inputs }.into()
}
Loading

0 comments on commit 9a9e52e

Please sign in to comment.