Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): add head/tail under string namespace #10339

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,52 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {

unsafe { Ok(Utf8Chunked::from_chunks(ca.name(), chunks)) }
}

/// Return the first `n` characters of every string.
///
/// A negative `n` returns all but the last `abs(n)` characters; a string with
/// `abs(n)` characters or fewer becomes empty. Null values stay null.
///
/// Characters are counted as Unicode scalar values, never as bytes, so a slice
/// can not end in the middle of a multi-byte character.
fn str_head(&self, n: i64) -> PolarsResult<Utf8Chunked> {
    let ca = self.as_utf8();

    let chunks = if n < 0 {
        // `unsigned_abs` is well defined for `i64::MIN`, where `abs` would overflow.
        // `n < 0` guarantees `n_drop >= 1`.
        let n_drop = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
        ca.downcast_iter()
            .map(|arr| {
                // The cut-off depends on the length of each individual value, so every
                // value is sliced on its own.
                let out: polars_arrow::export::arrow::array::Utf8Array<i64> = arr
                    .iter()
                    .map(|opt_s| {
                        opt_s.map(|s| {
                            // Byte offset of the first dropped character, found by
                            // walking characters from the back. This always lands on
                            // a char boundary; if the value has fewer than `n_drop`
                            // characters nothing is kept.
                            let end = s
                                .char_indices()
                                .rev()
                                .nth(n_drop - 1)
                                .map_or(0, |(idx, _)| idx);
                            &s[..end]
                        })
                    })
                    .collect();
                Box::new(out) as Box<dyn polars_arrow::export::arrow::array::Array>
            })
            .collect::<Vec<_>>()
    } else {
        // `n >= 0` here, so the cast to `u64` is lossless.
        ca.downcast_iter()
            .map(|arr| substring(arr, 0, &Some(n as u64)))
            .collect::<arrow::error::Result<_>>()?
    };

    // SAFETY: the negative branch builds utf8 arrays directly; the other branch
    // relies on `substring` returning the same array type it was given.
    unsafe { Ok(Utf8Chunked::from_chunks(ca.name(), chunks)) }
}

/// Return the last `n` characters of every string.
///
/// `n == 0` yields empty strings. For any other `n` the value `-n` is handed to
/// `substring` as the start offset, with the slice running to the end of the string.
fn str_tail(&self, n: i64) -> PolarsResult<Utf8Chunked> {
    let ca = self.as_utf8();

    let (start, length): (i64, Option<u64>) = if n == 0 {
        // `-0` is `0`, which as a start offset would select the *whole* string.
        // An explicit zero length gives the empty tail instead.
        (0, Some(0))
    } else {
        // `-i64::MIN` overflows; saturate to the largest representable offset.
        (n.checked_neg().unwrap_or(i64::MAX), None)
    };

    let chunks = ca
        .downcast_iter()
        .map(|arr| substring(arr, start, &length))
        .collect::<arrow::error::Result<_>>()?;

    // SAFETY: relies on `substring` returning the same (utf8) array type it was given.
    unsafe { Ok(Utf8Chunked::from_chunks(ca.name(), chunks)) }
}
}

// The impl body is empty: `Utf8Chunked` uses the trait's provided methods as-is.
impl Utf8NameSpaceImpl for Utf8Chunked {}
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "string_from_radix")]
FromRadix(radix, strict) => map!(strings::from_radix, radix, strict),
Slice(start, length) => map!(strings::str_slice, start, length),
Head(n) => map!(strings::str_head, n),
Tail(n) => map!(strings::str_tail, n),
Explode => map!(strings::explode),
#[cfg(feature = "dtype-decimal")]
ToDecimal(infer_len) => map!(strings::to_decimal, infer_len),
Expand Down
25 changes: 22 additions & 3 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ pub enum StringFunction {
},
RStrip(Option<String>),
Slice(i64, Option<u64>),
Head(i64),
Tail(i64),
StartsWith,
Strip(Option<String>),
#[cfg(feature = "temporal")]
Expand Down Expand Up @@ -111,9 +113,14 @@ impl StringFunction {
Titlecase => mapper.with_same_dtype(),
#[cfg(feature = "dtype-decimal")]
ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => {
mapper.with_same_dtype()
}
Uppercase
| Lowercase
| Strip(_)
| LStrip(_)
| RStrip(_)
| Slice(_, _)
| Head(_)
| Tail(_) => mapper.with_same_dtype(),
#[cfg(feature = "string_justify")]
Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(),
}
Expand Down Expand Up @@ -152,6 +159,8 @@ impl Display for StringFunction {
#[cfg(feature = "regex")]
StringFunction::Replace { .. } => "replace",
StringFunction::Slice(_, _) => "str_slice",
StringFunction::Head(_) => "str_head",
StringFunction::Tail(_) => "str_tail",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::Strip(_) => "strip",
#[cfg(feature = "temporal")]
Expand Down Expand Up @@ -724,6 +733,16 @@ pub(super) fn str_slice(s: &Series, start: i64, length: Option<u64>) -> PolarsRe
ca.str_slice(start, length).map(|ca| ca.into_series())
}

/// Apply `str_head` to a `Series`; fails if the series is not of utf8 type.
pub(super) fn str_head(s: &Series, n: i64) -> PolarsResult<Series> {
    let head = s.utf8()?.str_head(n)?;
    Ok(head.into_series())
}

/// Apply `str_tail` to a `Series`; fails if the series is not of utf8 type.
pub(super) fn str_tail(s: &Series, n: i64) -> PolarsResult<Series> {
    let tail = s.utf8()?.str_tail(n)?;
    Ok(tail.into_series())
}

pub(super) fn explode(s: &Series) -> PolarsResult<Series> {
let ca = s.utf8()?;
ca.explode()
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,18 @@ impl StringNameSpace {
)))
}

/// Build an expression that keeps the first `n` characters of each string.
pub fn str_head(self, n: i64) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Head(n));
    self.0.map_private(function)
}

/// Build an expression that keeps the last `n` characters of each string.
pub fn str_tail(self, n: i64) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Tail(n));
    self.0.map_private(function)
}

pub fn explode(self) -> Expr {
self.0
.apply_private(FunctionExpr::StringExpr(StringFunction::Explode))
Expand Down
66 changes: 66 additions & 0 deletions py-polars/debug/launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import re
import sys
import time
from pathlib import Path


def launch_debugging() -> None:
"""
Debug Rust files via Python.

Determine the process ID of the current debugging session, write it into the
"Rust LLDB" entry of ``.vscode/launch.json``, and execute the originally-requested
script.

Raises
------
RuntimeError
If no script argument was passed, if ``.vscode/launch.json`` cannot be located,
or if it contains no pid entry for the "Rust LLDB" configuration.
"""
if len(sys.argv) == 1:
raise RuntimeError(
"launch.py is not meant to be executed directly; please use the `Python: "
"Debug Rust` debugging configuration to run a python script that uses the "
"polars library."
)

# get the current process ID
pID = os.getpid()

# locate the VS Code launch configuration relative to this file and read it
launch_file = Path(__file__).parents[2] / ".vscode/launch.json"
if not launch_file.exists():
raise RuntimeError(f"Cannot locate {launch_file}")
with launch_file.open("r") as f:
launch_info = f.read()

# overwrite the pid found in launch.json with the pid for the current process:
# match the "Rust LLDB" definition with the pid immediately after it
pattern = re.compile('("Rust LLDB",\\s*"pid":\\s*")\\d+(")')
found = pattern.search(launch_info)
if not found:
raise RuntimeError(
"Cannot locate pid definition in launch.json for Rust LLDB configuration. "
"Please follow the instructions in CONTRIBUTING.md for creating the "
"launch configuration."
)

# keep both captured groups and replace only the digits between them
launch_info_with_new_pid = pattern.sub(rf"\g<1>{pID}\g<2>", launch_info)
with launch_file.open("w") as f:
f.write(launch_info_with_new_pid)

# print pID to the debug console. This auto-triggers the Rust LLDB configurations.
print(f"pID = {pID}")

# give the LLDB time to connect. We may have to play with this setting.
time.sleep(1)

# run the originally requested file
# update sys.argv so that when exec() is called, it's populated with the requested
# script name in sys.argv[0], and the remaining args after
sys.argv.pop(0)
with Path(sys.argv[0]).open() as fh:
script_contents = fh.read()

# path to the script to be executed; `fh` is rebound from the (now closed) file
# handle to the path, which is passed to compile() as the filename
fh = Path(sys.argv[0])
exec(compile(script_contents, fh, mode="exec"), {"__name__": "__main__"})


# only start the debugging workflow when this file is run as a script
if __name__ == "__main__":
launch_debugging()
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expressions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.extract
Expr.str.extract_all
Expr.str.extract_groups
Expr.str.head
Expr.str.json_extract
Expr.str.json_path_match
Expr.str.lengths
Expand All @@ -36,6 +37,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.starts_with
Expr.str.strip
Expr.str.strptime
Expr.str.tail
Expr.str.to_date
Expr.str.to_datetime
Expr.str.to_decimal
Expand Down
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/series/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.extract
Series.str.extract_all
Series.str.extract_groups
Series.str.head
Series.str.json_extract
Series.str.json_path_match
Series.str.lengths
Expand All @@ -36,6 +37,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.starts_with
Series.str.strip
Series.str.strptime
Series.str.tail
Series.str.to_date
Series.str.to_datetime
Series.str.to_decimal
Expand Down
83 changes: 81 additions & 2 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,14 +1681,19 @@ def slice(self, offset: int, length: int | None = None) -> Expr:
offset
Start index. Negative indexing is supported.
length
Length of the slice. If set to ``None`` (default), the slice is taken to the
end of the string.
Length in characters of the slice. If set to ``None`` (default), the slice
is taken to the end of the string.

Returns
-------
Expr
Expression of data type :class:`Utf8`.

Notes
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added this note to both the str and expr docstrings for slice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to be explicit in the definition as per @orlp 's comment here, but I wonder if we should remove the "non-surrogate" part, since not many people will recognize what that means.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without it it's technically not correct. You can also call it a Unicode Scalar Value like the Rust docs if you prefer.

-----
A "character" is a Unicode scalar value (any code point that is not a surrogate).
Encoded as UTF-8 it takes a single byte for ASCII text, and up to 4 bytes otherwise.

Examples
--------
>>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]})
Expand Down Expand Up @@ -1727,6 +1732,80 @@ def slice(self, offset: int, length: int | None = None) -> Expr:
"""
return wrap_expr(self._pyexpr.str_slice(offset, length))

def head(self, n: int) -> Expr:
"""
Return the first n characters of each string in a Utf8 Series.

Parameters
----------
n
Number of characters to return. If negative, return all characters except
the last ``abs(n)``.

Returns
-------
Expr
Expression of data type :class:`Utf8`.

Notes
-----
A "character" is a Unicode scalar value (any code point that is not a surrogate).
Encoded as UTF-8 it takes a single byte for ASCII text, and up to 4 bytes otherwise.

Examples
--------
>>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]})
>>> df.with_columns(pl.col("s").str.head(3).alias("s_head3"))
shape: (4, 2)
┌─────────────┬─────────┐
│ s ┆ s_head3 │
│ --- ┆ --- │
│ str ┆ str │
╞═════════════╪═════════╡
│ pear ┆ pea │
│ null ┆ null │
│ papaya ┆ pap │
│ dragonfruit ┆ dra │
└─────────────┴─────────┘
"""
return wrap_expr(self._pyexpr.str_head(n))

def tail(self, n: int) -> Expr:
"""
Return the last n characters of each string in a Utf8 Series.

Parameters
----------
n
Number of characters to return, counted from the end of each string.

Returns
-------
Expr
Expression of data type :class:`Utf8`.

Notes
-----
A "character" is a Unicode scalar value (any code point that is not a surrogate).
Encoded as UTF-8 it takes a single byte for ASCII text, and up to 4 bytes otherwise.

Examples
--------
>>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]})
>>> df.with_columns(pl.col("s").str.tail(3).alias("s_tail3"))
shape: (4, 2)
┌─────────────┬─────────┐
│ s ┆ s_tail3 │
│ --- ┆ --- │
│ str ┆ str │
╞═════════════╪═════════╡
│ pear ┆ ear │
│ null ┆ null │
│ papaya ┆ aya │
│ dragonfruit ┆ uit │
└─────────────┴─────────┘
"""
return wrap_expr(self._pyexpr.str_tail(n))

def explode(self) -> Expr:
"""
Returns a column with a separate row for every string character.
Expand Down
Loading
Loading