Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add str.head and str.tail #14425

Merged
merged 13 commits into from
Apr 13, 2024
23 changes: 23 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,29 @@ pub trait StringNameSpaceImpl: AsString {

Ok(substring::substring(ca, offset.i64()?, length.u64()?))
}

/// Slice the first `n` values of the string.
///
/// Determines a substring starting at the beginning of the string up to offset `n` of each
/// element in `array`. `n` can be negative, in which case the slice ends `n` characters from
/// the end of the string.
fn str_head(&self, n: &Series) -> PolarsResult<StringChunked> {
let ca = self.as_string();
let n = n.strict_cast(&DataType::Int64)?;

Ok(substring::head(ca, n.i64()?))
}

/// Slice the last `n` values of the string.
///
/// Determines a substring starting at offset `n` of each element in `array`. `n` can be
/// negative, in which case the slice begins `n` characters from the start of the string.
fn str_tail(&self, n: &Series) -> PolarsResult<StringChunked> {
let ca = self.as_string();
let n = n.strict_cast(&DataType::Int64)?;

Ok(substring::tail(ca, n.i64()?))
}
}

impl StringNameSpaceImpl for StringChunked {}
112 changes: 105 additions & 7 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,72 @@
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked};

fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {
// `max_len` is guaranteed to be at least the total number of characters.
let max_len = str_val.len();
if n == 0 {
Some("")
} else {
let end_idx = if n > 0 {
if n as usize >= max_len {
return opt_str_val;
}
// End after the nth codepoint.
str_val
.char_indices()
.nth(n as usize)
.map(|(idx, _)| idx)
.unwrap_or(max_len)
} else {
// End after the nth codepoint from the end.
str_val
.char_indices()
.rev()
.nth((-n - 1) as usize)
.map(|(idx, _)| idx)
.unwrap_or(0)
};
Some(&str_val[..end_idx])
}
} else {
None
}
}

fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {
// `max_len` is guaranteed to be at least the total number of characters.
let max_len = str_val.len();
if n == 0 {
Some("")
} else {
let start_idx = if n > 0 {
if n as usize >= max_len {
return opt_str_val;
}
// Start from nth codepoint from the end
str_val
.char_indices()
.rev()
.nth((n - 1) as usize)
.map(|(idx, _)| idx)
.unwrap_or(0)
} else {
// Start after the nth codepoint
str_val
.char_indices()
.nth((-n) as usize)
.map(|(idx, _)| idx)
.unwrap_or(max_len)
};
Some(&str_val[start_idx..])
}
} else {
None
}
}

fn substring_ternary(
opt_str_val: Option<&str>,
opt_offset: Option<i64>,
Expand Down Expand Up @@ -57,30 +123,30 @@ pub(super) fn substring(
) -> StringChunked {
match (ca.len(), offset.len(), length.len()) {
(1, 1, _) => {
// SAFETY: index `0` is in bound.
// SAFETY: `ca` was verified to have least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
// SAFETY: index `0` is in bound.
// SAFETY: `offset` was verified to have at least 1 element.
let offset = unsafe { offset.get_unchecked(0) };
unary_elementwise(length, |length| substring_ternary(str_val, offset, length))
.with_name(ca.name())
},
(_, 1, 1) => {
// SAFETY: index `0` is in bound.
// SAFETY: `offset` was verified to have at least 1 element.
let offset = unsafe { offset.get_unchecked(0) };
// SAFETY: index `0` is in bound.
// SAFETY: `length` was verified to have at least 1 element.
let length = unsafe { length.get_unchecked(0) };
unary_elementwise(ca, |str_val| substring_ternary(str_val, offset, length))
},
(1, _, 1) => {
// SAFETY: index `0` is in bound.
// SAFETY: `ca` was verified to have at least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
// SAFETY: index `0` is in bound.
// SAFETY: `length` was verified to have at least 1 element.
let length = unsafe { length.get_unchecked(0) };
unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length))
.with_name(ca.name())
},
(1, len_b, len_c) if len_b == len_c => {
// SAFETY: index `0` is in bound.
// SAFETY: `ca` was verified to have at least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
binary_elementwise(offset, length, |offset, length| {
substring_ternary(str_val, offset, length)
Expand Down Expand Up @@ -115,3 +181,35 @@ pub(super) fn substring(
_ => ternary_elementwise(ca, offset, length, substring_ternary),
}
}

pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
match (ca.len(), n.len()) {
(_, 1) => {
// SAFETY: `n` was verified to have at least 1 element.
let n = unsafe { n.get_unchecked(0) };
unary_elementwise(ca, |str_val| head_binary(str_val, n)).with_name(ca.name())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be changed, string head/tail must be defined in terms of codepoints, not bytes! Otherwise you get illegal UTF-8 and general nonsense. Please change this and add a test-case that tests this, for example:

import polars as pl

df = pl.DataFrame({"s": ["你好世界"]})
head = pl.DataFrame({"s": ["你好"]})
tail = pl.DataFrame({"s": ["世界"]})
assert_frame_equal(df.select(pl.col.s.str.head(2)), head)
assert_frame_equal(df.select(pl.col.s.str.tail(2)), tail)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @orlp, may have to make a few changes.

Copy link
Contributor Author

@mcrumiller mcrumiller Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI I am noticing that str.slice does not properly index codepoints with negative indexes. Using your example:

s = pl.Series(["你好世界"])
tail = "界"
s.str.slice(-1)  # should be equivalent to "tail"
# shape: (1,)
# Series: '' [str]
# [
#         ""
# ]

I'll see if I can address this as a separate issue once I have finished with this one. Edit: opened #15136.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@orlp the new impl respects code points instead of bytes. I added some specific code point tests using your example. Let me know if anything looks off to you!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good now.

},
(1, _) => {
// SAFETY: `ca` was verified to have at least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name())
},
_ => binary_elementwise(ca, n, head_binary),
}
}

pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
match (ca.len(), n.len()) {
(_, 1) => {
// SAFETY: `n` was verified to have at least 1 element.
let n = unsafe { n.get_unchecked(0) };
unary_elementwise(ca, |str_val| tail_binary(str_val, n)).with_name(ca.name())
},
(1, _) => {
// SAFETY: `ca` was verified to have at least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name())
},
_ => binary_elementwise(ca, n, tail_binary),
}
}
40 changes: 36 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub enum StringFunction {
fill_char: char,
},
Slice,
Head,
Tail,
#[cfg(feature = "string_encoding")]
HexEncode,
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -166,7 +168,7 @@ impl StringFunction {
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => mapper.with_dtype(DataType::Binary),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
| StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -210,6 +212,8 @@ impl Display for StringFunction {
ToInteger { .. } => "to_integer",
#[cfg(feature = "regex")]
Find { .. } => "find",
Head { .. } => "head",
Tail { .. } => "tail",
#[cfg(feature = "extract_jsonpath")]
JsonDecode { .. } => "json_decode",
LenBytes => "len_bytes",
Expand Down Expand Up @@ -345,6 +349,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "string_to_integer")]
ToInteger(base, strict) => map!(strings::to_integer, base, strict),
Slice => map_as_slice!(strings::str_slice),
Head => map_as_slice!(strings::str_head),
Tail => map_as_slice!(strings::str_tail),
#[cfg(feature = "string_encoding")]
HexEncode => map!(strings::hex_encode),
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -892,24 +898,50 @@ pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult<Se
let ca = s.str()?;
ca.to_integer(base, strict).map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {

fn _ensure_lengths(s: &[Series]) -> bool {
// Calculate the post-broadcast length and ensure everything is consistent.
let len = s
.iter()
.map(|series| series.len())
.filter(|l| *l != 1)
.max()
.unwrap_or(1);
s.iter()
.all(|series| series.len() == 1 || series.len() == len)
}

pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
polars_ensure!(
s.iter().all(|series| series.len() == 1 || series.len() == len),
ComputeError: "all series in `str_slice` should have equal or unit length"
_ensure_lengths(s),
ComputeError: "all series in `str_slice` should have equal or unit length",
);
let ca = s[0].str()?;
let offset = &s[1];
let length = &s[2];
Ok(ca.str_slice(offset, length)?.into_series())
}

pub(super) fn str_head(s: &[Series]) -> PolarsResult<Series> {
polars_ensure!(
_ensure_lengths(s),
ComputeError: "all series in `str_head` should have equal or unit length",
);
let ca = s[0].str()?;
let n = &s[1];
Ok(ca.str_head(n)?.into_series())
}

pub(super) fn str_tail(s: &[Series]) -> PolarsResult<Series> {
polars_ensure!(
_ensure_lengths(s),
ComputeError: "all series in `str_tail` should have equal or unit length",
);
let ca = s[0].str()?;
let n = &s[1];
Ok(ca.str_tail(n)?.into_series())
}

#[cfg(feature = "string_encoding")]
pub(super) fn hex_encode(s: &Series) -> PolarsResult<Series> {
Ok(s.str()?.hex_encode().into_series())
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,26 @@ impl StringNameSpace {
)
}

/// Take the first `n` characters of the string values.
pub fn head(self, n: Expr) -> Expr {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we can't just define these in terms of a slice operation - that would save a lot of code bloat. But that might not work with negative indices.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @stinodego -- this was my initial intent, when I said I would piggyback on @reswqa's implementation of str.slice. I realized soon after that the negative indexing for head requires calculation of the string length to determine the end of the slice.

Here are the operations and their slice equivalents:

s.str.head(3)   # s.str.slice(3, None)
s.str.head(-3)  # no equivalent: must know string length for start offset

s.str.tail(3)   # s.str.slice(-3, None)
s.str.tail(-3)  # s.str.slice(3, None)

So for tail, we could do it because slice can run to the end of the string by itself, but for head we have no recourse. I suppose this would save us a little bit of bloat but it does make the code a bit asymmetric, but on the other hand the tail implementation is a bit more performant than slice because we have one fewer parameters, and so we have more fast paths. So it's a tradeoff here. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you can use slice in combination with len_chars, but clearly it will be a bit more efficient to have a dedicated implementation like in this PR. I'll leave it to Ritchie to be the judge here.

self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Head),
&[n],
false,
false,
)
}

/// Take the last `n` characters of the string values.
pub fn tail(self, n: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Tail),
&[n],
false,
false,
)
}

pub fn explode(self) -> Expr {
self.0
.apply_private(FunctionExpr::StringExpr(StringFunction::Explode))
Expand Down
4 changes: 3 additions & 1 deletion py-polars/docs/source/reference/expressions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.extract_all
Expr.str.extract_groups
Expr.str.find
Expr.str.head
Expr.str.json_decode
Expr.str.json_extract
Expr.str.json_path_match
Expand All @@ -33,6 +34,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.n_chars
Expr.str.pad_end
Expr.str.pad_start
Expr.str.parse_int
Expr.str.replace
Expr.str.replace_all
Expr.str.replace_many
Expand All @@ -51,6 +53,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.strip_prefix
Expr.str.strip_suffix
Expr.str.strptime
Expr.str.tail
Expr.str.to_date
Expr.str.to_datetime
Expr.str.to_decimal
Expand All @@ -60,4 +63,3 @@ The following methods are available under the `expr.str` attribute.
Expr.str.to_time
Expr.str.to_uppercase
Expr.str.zfill
Expr.str.parse_int
4 changes: 3 additions & 1 deletion py-polars/docs/source/reference/series/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.extract_all
Series.str.extract_groups
Series.str.find
Series.str.head
Series.str.json_decode
Series.str.json_extract
Series.str.json_path_match
Expand All @@ -33,6 +34,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.n_chars
Series.str.pad_end
Series.str.pad_start
Series.str.parse_int
Series.str.replace
Series.str.replace_all
Series.str.replace_many
Expand All @@ -51,6 +53,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.strip_prefix
Series.str.strip_suffix
Series.str.strptime
Series.str.tail
Series.str.to_date
Series.str.to_datetime
Series.str.to_decimal
Expand All @@ -60,4 +63,3 @@ The following methods are available under the `Series.str` attribute.
Series.str.to_titlecase
Series.str.to_uppercase
Series.str.zfill
Series.str.parse_int
Loading
Loading