Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add str.head and str.tail #14425

Merged
merged 13 commits into from
Apr 13, 2024
23 changes: 23 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,29 @@ pub trait StringNameSpaceImpl: AsString {

Ok(substring::substring(ca, offset.i64()?, length.u64()?))
}

/// Slice the first `n` values of the string.
///
/// Determines a substring starting at the beginning of the string up to offset `n` of each
/// element in `array`. `n` can be negative, in which case the slice ends `n` characters from
/// the end of the string.
fn str_head(&self, n: &Series) -> PolarsResult<StringChunked> {
let ca = self.as_string();
let n = n.strict_cast(&DataType::Int64)?;

Ok(substring::head(ca, n.i64()?))
}

/// Slice the last `n` values of the string.
///
/// Determines a substring starting at offset `n` of each element in `array`. `n` can be
/// negative, in which case the slice begins `n` characters from the end of the string.
fn str_tail(&self, n: &Series) -> PolarsResult<StringChunked> {
let ca = self.as_string();
let n = n.strict_cast(&DataType::Int64)?;

Ok(substring::tail(ca, n.i64()?))
}
}

impl StringNameSpaceImpl for StringChunked {}
75 changes: 75 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,49 @@
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked};

fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
if let (Some(str_val), Some(mut n)) = (opt_str_val, opt_n) {
let str_len = str_val.len() as i64;
if n >= str_len {
Some(str_val)
} else if (n == 0) | (str_len == 0) | (n <= -str_len) {
Some("")
} else {
if n < 0 {
// If `n` is negative, it counts from the end of the string.
n += str_len; // adding negative value
}
Some(&str_val[0..n as usize])
}
} else {
None
}
}

fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
if let (Some(str_val), Some(mut n)) = (opt_str_val, opt_n) {
let str_len = str_val.len() as i64;
if n >= str_len {
Some(str_val)
} else if (n == 0) | (str_len == 0) | (n <= -str_len) {
Some("")
} else {
// We re-assign `n` to be the start of the slice.
// The end of the slice is always the end of the string.
if n < 0 {
// If `n` is negative, we count from the beginning.
n = -n;
} else {
// If `n` is positive, we count from the end.
n = str_len - n;
}
Some(&str_val[n as usize..str_len as usize])
}
} else {
None
}
}

fn substring_ternary(
opt_str_val: Option<&str>,
opt_offset: Option<i64>,
Expand Down Expand Up @@ -115,3 +158,35 @@ pub(super) fn substring(
_ => ternary_elementwise(ca, offset, length, substring_ternary),
}
}

pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
match (ca.len(), n.len()) {
(_, 1) => {
// SAFETY: index `0` is in bound.
stinodego marked this conversation as resolved.
Show resolved Hide resolved
let n = unsafe { n.get_unchecked(0) };
unary_elementwise(ca, |str_val| head_binary(str_val, n)).with_name(ca.name())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be changed, string head/tail must be defined in terms of codepoints, not bytes! Otherwise you get illegal UTF-8 and general nonsense. Please change this and add a test-case that tests this, for example:

import polars as pl

df = pl.DataFrame({"s": ["你好世界"]})
head = pl.DataFrame({"s": ["你好"]})
tail = pl.DataFrame({"s": ["世界"]})
assert_frame_equal(df.select(pl.col.s.str.head(2)), head)
assert_frame_equal(df.select(pl.col.s.str.tail(2)), tail)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @orlp, may have to make a few changes.

Copy link
Contributor Author

@mcrumiller mcrumiller Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI I am noticing that str.slice does not properly index codepoints with negative indexes. Using your example:

s = pl.Series(["你好世界"])
tail = "界"
s.str.slice(-1)  # should be equivalent to "tail"
# shape: (1,)
# Series: '' [str]
# [
#         ""
# ]

I'll see if I can address this as a separate issue once I have finished with this one. Edit: opened #15136.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@orlp the new impl respects code points instead of bytes. I added some specific code point tests using your example. Let me know if anything looks off to you!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good now.

},
(1, _) => {
// SAFETY: index `0` is in bound.
let str_val = unsafe { ca.get_unchecked(0) };
unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name())
},
_ => binary_elementwise(ca, n, head_binary),
}
}

pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
match (ca.len(), n.len()) {
(_, 1) => {
// SAFETY: index `0` is in bound.
let n = unsafe { n.get_unchecked(0) };
unary_elementwise(ca, |str_val| tail_binary(str_val, n)).with_name(ca.name())
},
(1, _) => {
// SAFETY: index `0` is in bound.
let str_val = unsafe { ca.get_unchecked(0) };
unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name())
},
_ => binary_elementwise(ca, n, tail_binary),
}
}
40 changes: 36 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub enum StringFunction {
fill_char: char,
},
Slice,
Head,
Tail,
#[cfg(feature = "string_encoding")]
HexEncode,
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -166,7 +168,7 @@ impl StringFunction {
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => mapper.with_dtype(DataType::Binary),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
| StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -210,6 +212,8 @@ impl Display for StringFunction {
ToInteger { .. } => "to_integer",
#[cfg(feature = "regex")]
Find { .. } => "find",
Head { .. } => "head",
Tail { .. } => "tail",
#[cfg(feature = "extract_jsonpath")]
JsonDecode { .. } => "json_decode",
LenBytes => "len_bytes",
Expand Down Expand Up @@ -345,6 +349,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "string_to_integer")]
ToInteger(base, strict) => map!(strings::to_integer, base, strict),
Slice => map_as_slice!(strings::str_slice),
Head => map_as_slice!(strings::str_head),
Tail => map_as_slice!(strings::str_tail),
#[cfg(feature = "string_encoding")]
HexEncode => map!(strings::hex_encode),
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -892,24 +898,50 @@ pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult<Se
let ca = s.str()?;
ca.to_integer(base, strict).map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {

fn _ensure_lengths(s: &[Series]) -> bool {
// Calculate the post-broadcast length and ensure everything is consistent.
let len = s
.iter()
.map(|series| series.len())
.filter(|l| *l != 1)
.max()
.unwrap_or(1);
s.iter()
.all(|series| series.len() == 1 || series.len() == len)
}

pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
polars_ensure!(
s.iter().all(|series| series.len() == 1 || series.len() == len),
ComputeError: "all series in `str_slice` should have equal or unit length"
_ensure_lengths(s),
ComputeError: "all series in `str_slice` should have equal or unit length",
);
let ca = s[0].str()?;
let offset = &s[1];
let length = &s[2];
Ok(ca.str_slice(offset, length)?.into_series())
}

pub(super) fn str_head(s: &[Series]) -> PolarsResult<Series> {
polars_ensure!(
_ensure_lengths(s),
ComputeError: "all series in `str_head` should have equal or unit length",
);
let ca = s[0].str()?;
let n = &s[1];
Ok(ca.str_head(n)?.into_series())
}

pub(super) fn str_tail(s: &[Series]) -> PolarsResult<Series> {
polars_ensure!(
_ensure_lengths(s),
ComputeError: "all series in `str_tail` should have equal or unit length",
);
let ca = s[0].str()?;
let n = &s[1];
Ok(ca.str_tail(n)?.into_series())
}

#[cfg(feature = "string_encoding")]
pub(super) fn hex_encode(s: &Series) -> PolarsResult<Series> {
Ok(s.str()?.hex_encode().into_series())
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,26 @@ impl StringNameSpace {
)
}

/// Take the first `n` characters of the string values.
pub fn head(self, n: Expr) -> Expr {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we can't just define these in terms of a slice operation - that would save a lot of code bloat. But that might not work with negative indices.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @stinodego -- this was my initial intent, when I said I would piggyback on @reswqa's implementation of str.slice. I realized soon after that the negative indexing for head requires calculation of the string length to determine the end of the slice.

Here are the operations and their slice equivalents:

s.str.head(3)   # s.str.slice(3, None)
s.str.head(-3)  # no equivalent: must know string length for start offset

s.str.tail(3)   # s.str.slice(-3, None)
s.str.tail(-3)  # s.str.slice(3, None)

So for tail, we could do it because slice can run to the end of the string by itself, but for head we have no recourse. I suppose this would save us a little bit of bloat but it does make the code a bit asymmetric, but on the other hand the tail implementation is a bit more performant than slice because we have one fewer parameters, and so we have more fast paths. So it's a tradeoff here. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you can use slice in combination with len_chars, but clearly it will be a bit more efficient to have a dedicated implementation like in this PR. I'll leave it to Ritchie to be the judge here.

self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Head),
&[n],
false,
false,
)
}

/// Take the last `n` characters of the string values.
pub fn tail(self, n: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Tail),
&[n],
false,
false,
)
}

pub fn explode(self) -> Expr {
self.0
.apply_private(FunctionExpr::StringExpr(StringFunction::Explode))
Expand Down
76 changes: 76 additions & 0 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -2180,6 +2180,82 @@ def slice(
length = parse_as_expression(length)
return wrap_expr(self._pyexpr.str_slice(offset, length))

def head(self, n: int | IntoExprColumn = 10) -> Expr:
"""
Return the first n characters of each string in a Utf8 Series.

Parameters
----------
n
Length of the slice. Negative indexing supported.

Returns
-------
Expr
Expression of data type :class:`Utf8`.

Notes
-----
A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
when working with ASCII text, and a maximum of 4 bytes otherwise.

Examples
--------
>>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]})
>>> df.with_columns(pl.col("s").str.head(3).alias("s_head3"))
shape: (4, 2)
┌─────────────┬─────────┐
│ s ┆ s_head3 │
│ --- ┆ --- │
│ str ┆ str │
╞═════════════╪═════════╡
│ pear ┆ pea │
│ null ┆ null │
│ papaya ┆ pap │
│ dragonfruit ┆ dra │
└─────────────┴─────────┘
"""
n = parse_as_expression(n)
return wrap_expr(self._pyexpr.str_head(n))

def tail(self, n: int | IntoExprColumn = 10) -> Expr:
"""
Return the last n characters of each string in a Utf8 Series.

Parameters
----------
n
Length of the slice. Negative indexing is supported.

Returns
-------
Expr
Expression of data type :class:`Utf8`.

Notes
-----
A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
when working with ASCII text, and a maximum of 4 bytes otherwise.

Examples
--------
>>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]})
>>> df.with_columns(pl.col("s").str.tail(3).alias("s_tail3"))
shape: (4, 2)
┌─────────────┬─────────┐
│ s ┆ s_tail3 │
│ --- ┆ --- │
│ str ┆ str │
╞═════════════╪═════════╡
│ pear ┆ ear │
│ null ┆ null │
│ papaya ┆ aya │
│ dragonfruit ┆ uit │
└─────────────┴─────────┘
"""
n = parse_as_expression(n)
return wrap_expr(self._pyexpr.str_tail(n))

def explode(self) -> Expr:
"""
Returns a column with a separate row for every string character.
Expand Down
66 changes: 66 additions & 0 deletions py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,72 @@ def slice(
]
"""

def head(self, n: int | IntoExprColumn = 10) -> Series:
"""
Return the first n characters of each string in a Utf8 Series.

Parameters
----------
n
Length of the slice

Returns
-------
Series
Series of data type :class:`Struct` with fields of data type :class:`Utf8`.
mcrumiller marked this conversation as resolved.
Show resolved Hide resolved

Notes
-----
A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
when working with ASCII text, and a maximum of 4 bytes otherwise.
mcrumiller marked this conversation as resolved.
Show resolved Hide resolved

Examples
--------
>>> s = pl.Series("s", ["pear", None, "papaya", "dragonfruit"])
>>> s.str.head(3)
shape: (4,)
Series: 's' [str]
[
"pea"
null
"pap"
"dra"
]
"""

def tail(self, n: int | IntoExprColumn = 10) -> Series:
mcrumiller marked this conversation as resolved.
Show resolved Hide resolved
"""
Return the last n characters of each string in a Utf8 Series.

Parameters
----------
n
Length of the slice

Returns
-------
Series
Series of data type :class:`Struct` with fields of data type :class:`Utf8`.

Notes
-----
A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
when working with ASCII text, and a maximum of 4 bytes otherwise.

Examples
--------
>>> s = pl.Series("s", ["pear", None, "papaya", "dragonfruit"])
>>> s.str.tail(3)
shape: (4,)
Series: 's' [str]
[
"ear"
null
"aya"
"uit"
]
"""

def explode(self) -> Series:
"""
Returns a column with a separate row for every string character.
Expand Down
Loading
Loading