From 279540f04068f6286b23aa945acc1854007d788e Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 7 Feb 2024 19:08:58 +0800 Subject: [PATCH] feat: Implements `arr.shift` (#14298) --- .../src/chunked_array/array/iterator.rs | 25 ++++++++++ .../src/chunked_array/array/namespace.rs | 33 +++++++++++++ crates/polars-plan/src/dsl/array.rs | 10 ++++ .../src/dsl/function_expr/array.rs | 11 +++++ .../source/reference/expressions/array.rst | 1 + .../docs/source/reference/series/array.rst | 3 +- py-polars/polars/expr/array.py | 49 +++++++++++++++++++ py-polars/polars/series/array.py | 39 +++++++++++++++ py-polars/src/expr/array.rs | 4 ++ .../tests/unit/namespaces/array/test_array.py | 19 +++++++ 10 files changed, 193 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs index f31b2bbf86e8..37c8518bbce7 100644 --- a/crates/polars-core/src/chunked_array/array/iterator.rs +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -123,6 +123,31 @@ impl ArrayChunked { .collect_ca_with_dtype(self.name(), self.dtype().clone()) } + /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. + /// # Safety + // Return series of `F` must has the same dtype and number of elements as input series. + #[must_use] + pub unsafe fn zip_and_apply_amortized_same_type<'a, T, F>( + &'a self, + ca: &'a ChunkedArray, + mut f: F, + ) -> Self + where + T: PolarsDataType, + F: FnMut(Option>, Option>) -> Option, + { + if self.is_empty() { + return self.clone(); + } + self.amortized_iter() + .zip(ca.iter()) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v); + out.map(|s| to_arr(&s)) + }) + .collect_ca_with_dtype(self.name(), self.dtype().clone()) + } + /// Apply a closure `F` elementwise. #[must_use] pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs index d6279e790e35..49c30cd00e0a 100644 --- a/crates/polars-ops/src/chunked_array/array/namespace.rs +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -129,6 +129,39 @@ pub trait ArrayNameSpace: AsArray { let ca = self.as_array(); array_count_matches(ca, element) } + + fn array_shift(&self, n: &Series) -> PolarsResult { + let ca = self.as_array(); + let n_s = n.cast(&DataType::Int64)?; + let n = n_s.i64()?; + let out = match n.len() { + 1 => { + if let Some(n) = n.get(0) { + // SAFETY: Shift does not change the dtype and number of elements of sub-array. + unsafe { ca.apply_amortized_same_type(|s| s.as_ref().shift(n)) } + } else { + ArrayChunked::full_null_with_dtype( + ca.name(), + ca.len(), + &ca.inner_dtype(), + ca.width(), + ) + } + }, + _ => { + // SAFETY: Shift does not change the dtype and number of elements of sub-array. + unsafe { + ca.zip_and_apply_amortized_same_type(n, |opt_s, opt_periods| { + match (opt_s, opt_periods) { + (Some(s), Some(n)) => Some(s.as_ref().shift(n)), + _ => None, + } + }) + } + }, + }; + Ok(out.into_series()) + } } impl ArrayNameSpace for ArrayChunked {} diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index 36609bbbff3a..1e73613c7d04 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -177,4 +177,14 @@ impl ArrayNameSpace { ) .with_fmt("arr.to_struct") } + + /// Shift every sub-array. + pub fn shift(self, n: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ArrayExpr(ArrayFunction::Shift), + &[n], + false, + false, + ) + } } diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index 528faa86d2ca..77b8ac2f68e3 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -28,6 +28,7 @@ pub enum ArrayFunction { Contains, #[cfg(feature = "array_count")] CountMatches, + Shift, } impl ArrayFunction { @@ -52,6 +53,7 @@ impl ArrayFunction { Contains => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "array_count")] CountMatches => mapper.with_dtype(IDX_DTYPE), + Shift => mapper.with_same_dtype(), } } } @@ -90,6 +92,7 @@ impl Display for ArrayFunction { Contains => "contains", #[cfg(feature = "array_count")] CountMatches => "count_matches", + Shift => "shift", }; write!(f, "arr.{name}") } @@ -121,6 +124,7 @@ impl From for SpecialEq> { Contains => map_as_slice!(contains), #[cfg(feature = "array_count")] CountMatches => map_as_slice!(count_matches), + Shift => map_as_slice!(shift), } } } @@ -224,3 +228,10 @@ pub(super) fn count_matches(args: &[Series]) -> PolarsResult { let ca = s.array()?; ca.array_count_matches(element.get(0).unwrap()) } + +pub(super) fn shift(s: &[Series]) -> PolarsResult { + let ca = s[0].array()?; + let n = &s[1]; + + ca.array_shift(n) +} diff --git a/py-polars/docs/source/reference/expressions/array.rst b/py-polars/docs/source/reference/expressions/array.rst index 982a5346d269..dd3d7be45d98 100644 --- a/py-polars/docs/source/reference/expressions/array.rst +++ b/py-polars/docs/source/reference/expressions/array.rst @@ -31,3 +31,4 @@ The following methods are available under the `expr.arr` attribute. Expr.arr.contains Expr.arr.count_matches Expr.arr.to_struct + Expr.arr.shift diff --git a/py-polars/docs/source/reference/series/array.rst b/py-polars/docs/source/reference/series/array.rst index 12fb64f0e00d..13f2da759833 100644 --- a/py-polars/docs/source/reference/series/array.rst +++ b/py-polars/docs/source/reference/series/array.rst @@ -30,4 +30,5 @@ The following methods are available under the `Series.arr` attribute. Series.arr.explode Series.arr.contains Series.arr.count_matches - Series.arr.to_struct \ No newline at end of file + Series.arr.to_struct + Series.arr.shift \ No newline at end of file diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index 45b1b7aa5fac..b228b7b562b7 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -701,3 +701,52 @@ def to_struct( else: pyexpr = self._pyexpr.arr_to_struct(fields) return wrap_expr(pyexpr) + + def shift(self, n: int | IntoExprColumn = 1) -> Expr: + """ + Shift array values by the given number of indices. + + Parameters + ---------- + n + Number of indices to shift forward. If a negative value is passed, values + are shifted in the opposite direction instead. + + Notes + ----- + This method is similar to the `LAG` operation in SQL when the value for `n` + is positive. With a negative value for `n`, it is similar to `LEAD`. + + Examples + -------- + By default, array values are shifted forward by one index. + + >>> df = pl.DataFrame( + ... {"a": [[1, 2, 3], [4, 5, 6]]}, schema={"a": pl.Array(pl.Int64, 3)} + ... ) + >>> df.with_columns(shift=pl.col("a").arr.shift()) + shape: (2, 2) + ┌───────────────┬───────────────┐ + │ a ┆ shift │ + │ --- ┆ --- │ + │ array[i64, 3] ┆ array[i64, 3] │ + ╞═══════════════╪═══════════════╡ + │ [1, 2, 3] ┆ [null, 1, 2] │ + │ [4, 5, 6] ┆ [null, 4, 5] │ + └───────────────┴───────────────┘ + + Pass a negative value to shift in the opposite direction instead. + + >>> df.with_columns(shift=pl.col("a").arr.shift(-2)) + shape: (2, 2) + ┌───────────────┬─────────────────┐ + │ a ┆ shift │ + │ --- ┆ --- │ + │ array[i64, 3] ┆ array[i64, 3] │ + ╞═══════════════╪═════════════════╡ + │ [1, 2, 3] ┆ [3, null, null] │ + │ [4, 5, 6] ┆ [6, null, null] │ + └───────────────┴─────────────────┘ + """ + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.arr_shift(n)) diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 6249e82f798e..4a547485f962 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -562,3 +562,42 @@ def to_struct( """ s = wrap_s(self._s) return s.to_frame().select(F.col(s.name).arr.to_struct(fields)).to_series() + + def shift(self, n: int | IntoExprColumn = 1) -> Series: + """ + Shift array values by the given number of indices. + + Parameters + ---------- + n + Number of indices to shift forward. If a negative value is passed, values + are shifted in the opposite direction instead. + + Notes + ----- + This method is similar to the `LAG` operation in SQL when the value for `n` + is positive. With a negative value for `n`, it is similar to `LEAD`. + + Examples + -------- + By default, array values are shifted forward by one index. + + >>> s = pl.Series([[1, 2, 3], [4, 5, 6]], dtype=pl.Array(pl.Int64, 3)) + >>> s.arr.shift() + shape: (2,) + Series: '' [array[i64, 3]] + [ + [null, 1, 2] + [null, 4, 5] + ] + + Pass a negative value to shift in the opposite direction instead. + + >>> s.arr.shift(-2) + shape: (2,) + Series: '' [array[i64, 3]] + [ + [3, null, null] + [6, null, null] + ] + """ diff --git a/py-polars/src/expr/array.rs b/py-polars/src/expr/array.rs index 21e9501b1e9b..5b0cb2bf365b 100644 --- a/py-polars/src/expr/array.rs +++ b/py-polars/src/expr/array.rs @@ -112,4 +112,8 @@ impl PyExpr { Ok(self.inner.clone().arr().to_struct(name_gen).into()) } + + fn arr_shift(&self, n: PyExpr) -> Self { + self.inner.clone().arr().shift(n.inner).into() + } } diff --git a/py-polars/tests/unit/namespaces/array/test_array.py b/py-polars/tests/unit/namespaces/array/test_array.py index 8766368ce1de..4486b90eeddf 100644 --- a/py-polars/tests/unit/namespaces/array/test_array.py +++ b/py-polars/tests/unit/namespaces/array/test_array.py @@ -347,3 +347,22 @@ def test_array_to_struct() -> None: assert df.lazy().select(pl.col("a").arr.to_struct()).unnest( "a" ).sum().collect().columns == ["field_0", "field_1", "field_2"] + + +def test_array_shift() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], None, [4, 5, 6], [7, 8, 9]], "n": [None, 1, 1, -2]}, + schema={"a": pl.Array(pl.Int64, 3), "n": pl.Int64}, + ) + + out = df.select( + lit=pl.col("a").arr.shift(1), expr=pl.col("a").arr.shift(pl.col("n")) + ) + expected = pl.DataFrame( + { + "lit": [[None, 1, 2], None, [None, 4, 5], [None, 7, 8]], + "expr": [None, None, [None, 4, 5], [9, None, None]], + }, + schema={"lit": pl.Array(pl.Int64, 3), "expr": pl.Array(pl.Int64, 3)}, + ) + assert_frame_equal(out, expected)