Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Infer reshape dims when determining schema #18923

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod any_value;
mod dtype;
mod field;
mod into_scalar;
mod reshape;
#[cfg(feature = "object")]
mod static_array_collect;
mod time_unit;
Expand Down Expand Up @@ -41,6 +42,7 @@ use polars_utils::abs_diff::AbsDiff;
use polars_utils::float::IsFloat;
use polars_utils::min_max::MinMax;
use polars_utils::nulls::IsNull;
pub use reshape::*;
#[cfg(feature = "serde")]
use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor};
#[cfg(any(feature = "serde", feature = "serde-lazy"))]
Expand Down
118 changes: 118 additions & 0 deletions crates/polars-core/src/datatypes/reshape.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::fmt;
use std::hash::Hash;
use std::num::NonZeroU64;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct Dimension(NonZeroU64);

/// A dimension in a reshape.
///
/// Any dimension smaller than 0 is seen as an `infer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ReshapeDimension {
Infer,
Specified(Dimension),
}

impl fmt::Debug for Dimension {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}

impl fmt::Display for ReshapeDimension {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Infer => f.write_str("inferred"),
Self::Specified(v) => v.get().fmt(f),
}
}
}

impl Hash for ReshapeDimension {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.to_repr().hash(state)
}
}

impl Dimension {
#[inline]
pub const fn new(v: u64) -> Self {
assert!(v <= i64::MAX as u64);

// SAFETY: Bounds check done before
let dim = unsafe { NonZeroU64::new_unchecked(v.wrapping_add(1)) };
Self(dim)
}

#[inline]
pub const fn get(self) -> u64 {
self.0.get() - 1
}
}

impl ReshapeDimension {
#[inline]
pub const fn new(v: i64) -> Self {
if v < 0 {
Self::Infer
} else {
// SAFETY: We have bounds checked for -1
let dim = unsafe { NonZeroU64::new_unchecked((v as u64).wrapping_add(1)) };
Self::Specified(Dimension(dim))
}
}

#[inline]
fn to_repr(self) -> u64 {
match self {
Self::Infer => 0,
Self::Specified(dim) => dim.0.get(),
}
}

#[inline]
pub const fn get(self) -> Option<u64> {
match self {
ReshapeDimension::Infer => None,
ReshapeDimension::Specified(dim) => Some(dim.get()),
}
}

#[inline]
pub const fn get_or_infer(self, inferred: u64) -> u64 {
match self {
ReshapeDimension::Infer => inferred,
ReshapeDimension::Specified(dim) => dim.get(),
}
}

#[inline]
pub fn get_or_infer_with(self, f: impl Fn() -> u64) -> u64 {
match self {
ReshapeDimension::Infer => f(),
ReshapeDimension::Specified(dim) => dim.get(),
}
}

pub const fn new_dimension(dimension: u64) -> ReshapeDimension {
Self::Specified(Dimension::new(dimension))
}
}

impl TryFrom<i64> for Dimension {
type Error = ();

#[inline]
fn try_from(value: i64) -> Result<Self, Self::Error> {
let ReshapeDimension::Specified(v) = ReshapeDimension::new(value) else {
return Err(());
};

Ok(v)
}
}
5 changes: 3 additions & 2 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use polars_utils::pl_str::PlSmallStr;
use self::gather::check_bounds_ca;
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::metadata::{MetadataFlags, MetadataTrait};
use crate::datatypes::ReshapeDimension;
use crate::prelude::*;
use crate::series::{BitRepr, IsSorted, SeriesPhysIter};
use crate::utils::{slice_offsets, Container};
Expand Down Expand Up @@ -730,15 +731,15 @@ impl Column {
self.as_materialized_series().unique().map(Column::from)
}

pub fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult<Self> {
pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Self> {
// @scalar-opt
self.as_materialized_series()
.reshape_list(dimensions)
.map(Self::from)
}

#[cfg(feature = "dtype-array")]
pub fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult<Self> {
pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Self> {
// @scalar-opt
self.as_materialized_series()
.reshape_array(dimensions)
Expand Down
10 changes: 6 additions & 4 deletions crates/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,18 @@ impl NumOpsDispatchInner for BooleanType {
}

#[cfg(feature = "dtype-array")]
fn array_shape(dt: &DataType, infer: bool) -> Vec<i64> {
fn inner(dt: &DataType, buf: &mut Vec<i64>) {
fn array_shape(dt: &DataType, infer: bool) -> Vec<ReshapeDimension> {
fn inner(dt: &DataType, buf: &mut Vec<ReshapeDimension>) {
if let DataType::Array(_, size) = dt {
buf.push(*size as i64)
buf.push(ReshapeDimension::Specified(
Dimension::try_from(*size as i64).unwrap(),
))
}
}

let mut buf = vec![];
if infer {
buf.push(-1)
buf.push(ReshapeDimension::Infer)
}
inner(dt, &mut buf);
buf
Expand Down
118 changes: 61 additions & 57 deletions crates/polars-core/src/series/ops/reshape.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
use std::borrow::Cow;
#[cfg(feature = "dtype-array")]
use std::cmp::Ordering;
#[cfg(feature = "dtype-array")]
use std::collections::VecDeque;

use arrow::array::*;
use arrow::legacy::kernels::list::array_to_unit_list;
use arrow::offset::Offsets;
use polars_error::{polars_bail, polars_ensure, PolarsResult};
#[cfg(feature = "dtype-array")]
use polars_utils::format_tuple;

use crate::chunked_array::builder::get_list_builder;
Expand Down Expand Up @@ -90,70 +85,70 @@ impl Series {
}

#[cfg(feature = "dtype-array")]
pub fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult<Series> {
pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
polars_ensure!(
!dimensions.is_empty(),
InvalidOperation: "at least one dimension must be specified"
);

let mut dims = dimensions.iter().copied().collect::<VecDeque<_>>();

let leaf_array = self.get_leaf_array();
let size = leaf_array.len();

let mut total_dim_size = 1;
let mut infer_dim_index: Option<usize> = None;
for (index, &dim) in dims.iter().enumerate() {
match dim.cmp(&0) {
Ordering::Greater => total_dim_size *= dim as usize,
Ordering::Equal => {
let mut num_infers = 0;
for (index, &dim) in dimensions.iter().enumerate() {
match dim {
ReshapeDimension::Infer => {
polars_ensure!(
index == 0,
InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}",
format_tuple!(dims)
num_infers == 0,
InvalidOperation: "can only specify one inferred dimension"
);
total_dim_size = 0;
// We can early exit here, as empty arrays will error with multiple dimensions,
// and non-empty arrays will error when the first dimension is zero.
break;
num_infers += 1;
},
Ordering::Less => {
polars_ensure!(
infer_dim_index.is_none(),
InvalidOperation: "can only specify one unknown dimension"
);
infer_dim_index = Some(index);
ReshapeDimension::Specified(dim) => {
let dim = dim.get();

if dim > 0 {
total_dim_size *= dim as usize
} else {
polars_ensure!(
index == 0,
InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}",
format_tuple!(dimensions)
);
total_dim_size = 0;
// We can early exit here, as empty arrays will error with multiple dimensions,
// and non-empty arrays will error when the first dimension is zero.
break;
}
},
}
}

if size == 0 {
if dims.len() > 1 || (infer_dim_index.is_none() && total_dim_size != 0) {
polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dims))
if dimensions.len() > 1 || (num_infers == 0 && total_dim_size != 0) {
polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dimensions))
}
} else if total_dim_size == 0 {
polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dims))
polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dimensions))
} else {
polars_ensure!(
size % total_dim_size == 0,
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dims)
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
);
}

// Infer dimension
if let Some(index) = infer_dim_index {
let inferred_dim = size / total_dim_size;
let item = dims.get_mut(index).unwrap();
*item = i64::try_from(inferred_dim).unwrap();
}

let leaf_array = leaf_array.rechunk();
let mut prev_dtype = leaf_array.dtype().clone();
let mut prev_array = leaf_array.chunks()[0].clone();

// We pop the outer dimension as that is the height of the series.
let _ = dims.pop_front();
while let Some(dim) = dims.pop_back() {
for idx in (1..dimensions.len()).rev() {
// Infer dimension if needed
let dim = dimensions[idx].get_or_infer_with(|| {
debug_assert!(num_infers > 0);
(size / total_dim_size) as u64
});
prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);

prev_array = FixedSizeListArray::new(
Expand All @@ -172,7 +167,7 @@ impl Series {
})
}

pub fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult<Series> {
pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
polars_ensure!(
!dimensions.is_empty(),
InvalidOperation: "at least one dimension must be specified"
Expand All @@ -187,38 +182,43 @@ impl Series {

let s_ref = s.as_ref();

let dimensions = dimensions.to_vec();
// let dimensions = dimensions.to_vec();

match dimensions.len() {
1 => {
polars_ensure!(
dimensions[0] as usize == s_ref.len() || dimensions[0] == -1_i64,
dimensions[0].get().map_or(true, |dim| dim as usize == s_ref.len()),
InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
);
Ok(s_ref.clone())
},
2 => {
let mut rows = dimensions[0];
let mut cols = dimensions[1];
let rows = dimensions[0];
let cols = dimensions[1];

if s_ref.len() == 0_usize {
if (rows == -1 || rows == 0) && (cols == -1 || cols == 0 || cols == 1) {
if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {
let s = reshape_fast_path(s.name().clone(), s_ref);
return Ok(s);
} else {
polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {:?}", dimensions,)
polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))
}
}

use ReshapeDimension as RD;
// Infer dimension.
if rows == -1 && cols >= 1 {
rows = s_ref.len() as i64 / cols
} else if cols == -1 && rows >= 1 {
cols = s_ref.len() as i64 / rows
} else if rows == -1 && cols == -1 {
rows = s_ref.len() as i64;
cols = 1_i64;
}

let (rows, cols) = match (rows, cols) {
(RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {
(s_ref.len() as u64 / cols.get(), cols.get())
},
(RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {
(rows.get(), s_ref.len() as u64 / rows.get())
},
(RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),
(RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),
_ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),
};

// Fast path, we can create a unit list so we only allocate offsets.
if rows as usize == s_ref.len() && cols == 1 {
Expand All @@ -234,9 +234,9 @@ impl Series {
let mut builder =
get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone())?;

let mut offset = 0i64;
let mut offset = 0u64;
for _ in 0..rows {
let row = s_ref.slice(offset, cols as usize);
let row = s_ref.slice(offset as i64, cols as usize);
builder.append_series(&row).unwrap();
offset += cols;
}
Expand Down Expand Up @@ -279,7 +279,11 @@ mod test {
(&[-1, 2], 2),
(&[2, -1], 2),
] {
let out = s.reshape_list(dims)?;
let dims = dims
.iter()
.map(|&v| ReshapeDimension::new(v))
.collect::<Vec<_>>();
let out = s.reshape_list(&dims)?;
assert_eq!(out.len(), list_len);
assert!(matches!(out.dtype(), DataType::List(_)));
assert_eq!(out.explode()?.len(), 4);
Expand Down
Loading
Loading