From da365be33b6b7fbe60e63aaf2e8b3dc0671ddbe8 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 28 Mar 2021 18:50:51 +0200 Subject: [PATCH 1/3] TEST: Add benchmarks for Result<_, ShapeError> methods --- benches/error-handling.rs | 110 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 benches/error-handling.rs diff --git a/benches/error-handling.rs b/benches/error-handling.rs new file mode 100644 index 000000000..0c2b1136c --- /dev/null +++ b/benches/error-handling.rs @@ -0,0 +1,110 @@ +#![feature(test)] +#![allow( + clippy::many_single_char_names, + clippy::deref_addrof, + clippy::unreadable_literal, + clippy::many_single_char_names +)] +extern crate test; +use test::Bencher; + +use ndarray::prelude::*; +use ndarray::ErrorKind; + +// Use ZST elements to remove allocation from the benchmarks + +#[derive(Copy, Clone, Debug)] +struct Zst; + +type A4 = Array4; + +#[bench] +fn from_elem(bench: &mut Bencher) { + bench.iter(|| { + A4::from_elem((1, 2, 3, 4), Zst) + }) +} + +#[bench] +fn from_shape_vec_ok(bench: &mut Bencher) { + bench.iter(|| { + let v: Vec = vec![Zst; 1 * 2 * 3 * 4]; + let x = A4::from_shape_vec((1, 2, 3, 4).strides((24, 12, 4, 1)), v); + debug_assert!(x.is_ok(), "problem with {:?}", x); + x + }) +} + +#[bench] +fn from_shape_vec_fail(bench: &mut Bencher) { + bench.iter(|| { + let v: Vec = vec![Zst; 1 * 2 * 3 * 4]; + let x = A4::from_shape_vec((1, 2, 3, 4).strides((4, 3, 2, 1)), v); + debug_assert!(x.is_err()); + x + }) +} + +#[bench] +fn into_shape_fail(bench: &mut Bencher) { + let a = A4::from_elem((1, 2, 3, 4), Zst); + let v = a.view(); + bench.iter(|| { + v.clone().into_shape((5, 3, 2, 1)) + }) +} + +#[bench] +fn into_shape_ok_c(bench: &mut Bencher) { + let a = A4::from_elem((1, 2, 3, 4), Zst); + let v = a.view(); + bench.iter(|| { + v.clone().into_shape((4, 3, 2, 1)) + }) +} + +#[bench] +fn into_shape_ok_f(bench: &mut Bencher) { + let a = A4::from_elem((1, 2, 3, 4).f(), Zst); + let v = a.view(); + bench.iter(|| { + v.clone().into_shape((4, 3, 2, 1)) + }) +} + +#[bench] +fn stack_ok(bench: &mut Bencher) { + let a = Array::from_elem((15, 15), Zst); + let rows = a.rows().into_iter().collect::>(); + bench.iter(|| { + let res = ndarray::stack(Axis(1), &rows); + debug_assert!(res.is_ok(), "err {:?}", res); + res + }); +} + +#[bench] +fn stack_err_axis(bench: &mut Bencher) { + let a = Array::from_elem((15, 15), Zst); + let rows = a.rows().into_iter().collect::>(); + bench.iter(|| { + let res = ndarray::stack(Axis(2), &rows); + debug_assert!(res.is_err()); + res + }); +} + +#[bench] +fn stack_err_shape(bench: &mut Bencher) { + let a = Array::from_elem((15, 15), Zst); + let rows = a.rows().into_iter() + .enumerate() + .map(|(i, mut row)| { row.slice_collapse(s![..(i as isize)]); row }) + .collect::>(); + bench.iter(|| { + let res = ndarray::stack(Axis(1), &rows); + debug_assert!(res.is_err()); + debug_assert_eq!(res.clone().unwrap_err().kind(), ErrorKind::IncompatibleShape); + res + }); +} From 3a59caafa91c0266d1427f3e072619133569b622 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 28 Mar 2021 13:33:31 +0200 Subject: [PATCH 2/3] FEAT: Encode expected/actual info in ShapeError MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a (limited) way to add specific information to a ShapeError Admittedly wonky, but space efficient. Result<(), ShapeError> used to be 1 byte, and with this change it expands to 16 bytes (2 usize on 64-bit). The remaining 15 bytes are used for optimistically packing as much of extra info into the error message as possible. For example we can store expected/actual index for errors (for example index out of bounds or axis out of bounds, these are not so commonly handled with ShapeError). With this change it is supported: - Expected/actual index with 7 bytes per index - Expected/actual shape with 7 bytes per shape supports storing shapes with one or two bytes (< 256²) per dimension, with limited ndim. --- Cargo.toml | 1 + src/error.rs | 525 +++++++++++++++++++++++++++++++++++++++++++++- tests/array.rs | 8 +- tests/stacking.rs | 6 + 4 files changed, 530 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 61c26d4cd..8288e6dc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ defmac = "0.2" quickcheck = { version = "0.9", default-features = false } approx = "0.4" itertools = { version = "0.10.0", default-features = false, features = ["use_std"] } +matches = "0.1.8" [features] default = ["std"] diff --git a/src/error.rs b/src/error.rs index c45496142..177711c9b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,16 +5,26 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. +#![allow(clippy::identity_op)] use super::Dimension; +use crate::itertools::enumerate; + #[cfg(feature = "std")] use std::error::Error; use std::fmt; +use std::mem::size_of; /// An error related to array shape or layout. +/// +/// The shape error encodes and shows expected/actual indices and shapes in some cases, which +/// is visible in the Display/Debug representation. Since this is done without allocation, it is +/// space-limited and bigger indices and shapes may not be representable. #[derive(Clone)] pub struct ShapeError { - // we want to be able to change this representation later + /// Error category repr: ErrorKind, + /// Additional info + info: InfoType, } impl ShapeError { @@ -24,9 +34,52 @@ impl ShapeError { self.repr } - /// Create a new `ShapeError` + /// Create a new `ShapeError` from the given kind pub fn from_kind(error: ErrorKind) -> Self { - from_kind(error) + Self::from_kind_info(error, info_default()) + } + + fn from_kind_info(repr: ErrorKind, info: InfoType) -> Self { + ShapeError { repr, info } + } + + pub(crate) fn invalid_axis(expected: usize, actual: usize) -> Self { + // TODO: OutOfBounds for compatibility reasons, should be more specific + Self::from_kind_info(ErrorKind::OutOfBounds, encode_indices(expected, actual)) + } + + pub(crate) fn shape_length_exceeds_data_length(expected: usize, actual: usize) -> Self { + // TODO: OutOfBounds for compatibility reasons, should be more specific + Self::from_kind_info(ErrorKind::OutOfBounds, encode_indices(expected, actual)) + } + + pub(crate) fn incompatible_layout(expected: ExpectedLayout) -> Self { + Self::from_kind_info(ErrorKind::IncompatibleLayout, encode_indices(expected as usize, 0)) + } + + pub(crate) fn incompatible_shapes(expected: &D, actual: &E) -> ShapeError + where + D: Dimension, + E: Dimension, + { + Self::from_kind_info(ErrorKind::IncompatibleShape, encode_shapes(expected, actual)) + } + + #[cfg(test)] + fn info_expected_index(&self) -> Option { + let (exp, _) = decode_indices(self.info); + exp + } + + #[cfg(test)] + fn info_actual_index(&self) -> Option { + let (_, actual) = decode_indices(self.info); + actual + } + + #[cfg(test)] + fn decode_shapes(&self) -> (Option, Option) { + decode_shapes(self.info) } } @@ -38,12 +91,15 @@ impl ShapeError { #[derive(Copy, Clone, Debug)] pub enum ErrorKind { /// incompatible shape + // encodes info: expected and actual shape IncompatibleShape = 1, /// incompatible memory layout + // encodes info: expected layout IncompatibleLayout, /// the shape does not fit inside type limits RangeLimited, /// out of bounds indexing + // encodes info: expected and actual index OutOfBounds, /// aliasing array elements Unsupported, @@ -52,8 +108,8 @@ pub enum ErrorKind { } #[inline(always)] -pub fn from_kind(k: ErrorKind) -> ShapeError { - ShapeError { repr: k } +pub fn from_kind(error: ErrorKind) -> ShapeError { + ShapeError::from_kind(error) } impl PartialEq for ErrorKind { @@ -83,7 +139,7 @@ impl fmt::Display for ShapeError { ErrorKind::Unsupported => "unsupported operation", ErrorKind::Overflow => "arithmetic overflow", }; - write!(f, "ShapeError/{:?}: {}", self.kind(), description) + write!(f, "ShapeError/{:?}: {}{}", self.kind(), description, ExtendedInfo(&self)) } } @@ -93,10 +149,463 @@ impl fmt::Debug for ShapeError { } } -pub fn incompatible_shapes(_a: &D, _b: &E) -> ShapeError +pub(crate) enum ExpectedLayout { + ContiguousCF = 1, + Unused, +} + +impl From> for ExpectedLayout { + #[inline] + fn from(x: Option) -> Self { + match x { + Some(1) => ExpectedLayout::ContiguousCF, + _ => ExpectedLayout::Unused, + } + } +} + +pub(crate) fn incompatible_shapes(_a: &D, _b: &E) -> ShapeError where D: Dimension, E: Dimension, { - from_kind(ErrorKind::IncompatibleShape) + ShapeError::incompatible_shapes(_a, _b) +} + +/// The InfoType encodes extra information per error kind, for example expected/actual axis for a +/// given error site, or expected layout for a layout error. +/// +/// It uses a custom and fixed-width (very limited) encoding; in some cases it skips filling in +/// information because it doesn't fit. +/// +/// Two bits in the first byte are reserved for EncodedInformationType, the rest is used +/// for situation-specific encoding of extra info. +/// If the first byte is zero, it shows that there is no encoded info. +type InfoType = [u8; INFO_TYPE_LEN]; + +const INFO_TYPE_LEN: usize = 15; +const INFO_BYTES: usize = INFO_TYPE_LEN - 1; + +const fn info_default() -> InfoType { [0; INFO_TYPE_LEN] } + +#[repr(u8)] +// 2 bits +enum EncodedInformationType { + Nothing = 0, + Expected = 0b1, + Actual = 0b10, +} + +const IXBYTES: usize = INFO_BYTES / 2; + +fn encode_index(x: usize) -> Option<[u8; IXBYTES]> { + let bits = size_of::() * 8; + let used_bits = bits - x.leading_zeros() as usize; + if used_bits > IXBYTES * 8 { + None + } else { + let bytes = x.to_le_bytes(); + let mut result = [0; IXBYTES]; + let len = bytes.len().min(result.len()); + result[..len].copy_from_slice(&bytes[..len]); + Some(result) + } +} + +fn decode_index(x: &[u8]) -> usize { + let mut bytes = 0usize.to_le_bytes(); + let len = x.len().min(bytes.len()); + bytes[..len].copy_from_slice(&x[..len]); + usize::from_le_bytes(bytes) +} + +fn encode_indices(expected: usize, actual: usize) -> InfoType { + let eexp = encode_index(expected); + let eact = encode_index(actual); + let mut info = info_default(); + let mut info_type = EncodedInformationType::Nothing as u8; + if eexp.is_some() { + info_type |= EncodedInformationType::Expected as u8; + } + let (ebytes, abytes) = info[1..].split_at_mut(IXBYTES); + if let Some(exp) = eexp { + ebytes.copy_from_slice(&exp); + } + if eact.is_some() { + info_type |= EncodedInformationType::Actual as u8; + } + if let Some(act) = eact { + abytes.copy_from_slice(&act); + } + info[0] = info_type; + info +} + +fn decode_indices(info: InfoType) -> (Option, Option) { + let (ebytes, abytes) = info[1..].split_at(IXBYTES); + ( + if info[0] & (EncodedInformationType::Expected as u8) != 0 { + Some(decode_index(ebytes)) + } else { None }, + if info[0] & (EncodedInformationType::Actual as u8) != 0 { + Some(decode_index(abytes)) + } else { None }, + ) + +} + +fn encode_shapes(expected: &D, actual: &E) -> InfoType +where + D: Dimension, + E: Dimension, +{ + encode_shapes_impl(expected.slice(), actual.slice()) +} + +// Shape encoding +// +// 15 bytes to use +// 1 byte: +// 1 bit expected shape is two-byte encoded yes/no (a) +// 3 bits expected shape len (len 0..=7) (b) +// 1 bit actual shape is two-byte encoded yes/no (c) +// 3 bits actual shape len (len 0..=7) (d) +// 14 bytes encoding of expected shape and actual shape: +// X bytes for expected: X = (a + 1) * (b) bytes +// then directly following: +// Y bytes for actual: Y = (c + 1) * (d) bytes + +const SHAPE_MAX_LEN: usize = (1 << 3) - 1; + +struct ShapeEncoding { + len: usize, + element_width: usize, + data: [u8; INFO_BYTES - 1], +} + +#[derive(Copy, Clone)] +enum EncodingWidth { + One = 1, + Two = 2, +} + +fn encode_shape(shape: &[usize], use_width: EncodingWidth) -> ShapeEncoding { + let mut info = [0; INFO_BYTES - 1]; + match use_width { + EncodingWidth::One => { + for (i, &d) in enumerate(shape) { + debug_assert!(d < 256); + info[i] = d as u8; + } + ShapeEncoding { + len: shape.len(), + element_width: 1, + data: info, + } + } + EncodingWidth::Two => { + for (i, &d) in enumerate(shape) { + debug_assert!(d < 256 * 256); + let dbytes = d.to_le_bytes(); + info[2 * i] = dbytes[0]; + info[2 * i + 1] = dbytes[1]; + } + ShapeEncoding { + len: shape.len(), + element_width: 2, + data: info, + } + } + } +} + +fn encode_shapes_impl(expected: &[usize], actual: &[usize]) -> InfoType { + let exp_onebyte = expected.iter().all(|&i| i < 256); + let exp_fit = exp_onebyte && expected.len() <= SHAPE_MAX_LEN || + expected.iter().all(|&i| i < 256 * 256) && expected.len() <= (INFO_BYTES - 1) / 2; + let act_onebyte = actual.iter().all(|&i| i < 256); + + let mut info = info_default(); + let mut info_type = EncodedInformationType::Nothing as u8; + let mut shape_header = 0; + + let mut remaining_len = INFO_BYTES - 1; + if exp_fit { + info_type |= EncodedInformationType::Expected as u8; + let eexp = encode_shape(expected, if exp_onebyte { EncodingWidth::One } else { EncodingWidth::Two }); + shape_header |= (!exp_onebyte as u8) << 0; + + info[2..].copy_from_slice(&eexp.data[..]); + remaining_len -= eexp.len * eexp.element_width; + shape_header |= (eexp.len as u8) << 1; + } + + if remaining_len > 0 { + if (act_onebyte && remaining_len >= actual.len()) || + remaining_len / 2 >= actual.len() + { + info_type |= EncodedInformationType::Actual as u8; + let eact = encode_shape(actual, if act_onebyte { EncodingWidth::One } else { EncodingWidth::Two }); + shape_header |= (!act_onebyte as u8) << 4; + let data_start = INFO_BYTES - 1 - remaining_len; + + info[2 + data_start..].copy_from_slice(&eact.data[..remaining_len]); + shape_header |= (eact.len as u8) << 5; + } else { + // skip encoding + } + } + info[0] = info_type; + info[1] = shape_header; + info +} + +#[derive(Default)] +#[cfg_attr(test, derive(Debug))] +struct DecodedShape { + len: usize, + shape: [usize; 8], +} + +impl DecodedShape { + fn as_slice(&self) -> &[usize] { + &self.shape[..self.len] + } +} + +fn decode_shape(data: &[u8], len: usize, width: EncodingWidth) -> DecodedShape { + debug_assert!(len * (width as usize) <= data.len(), + "Too short data when decoding shape"); + let mut shape = DecodedShape { len, ..<_>::default() }; + match width { + EncodingWidth::One => { + for (i, &d) in (0..len).zip(data) { + shape.shape[i] = d as usize; + } + } + EncodingWidth::Two => { + for i in 0..len { + let mut bytes = 0usize.to_le_bytes(); + bytes[0] = data[2 * i]; + bytes[1] = data[2 * i + 1]; + shape.shape[i] = usize::from_le_bytes(bytes); + } + } + } + shape +} + +fn decode_shapes(info: InfoType) -> (Option, Option) { + let exp_present = info[0] & (EncodedInformationType::Expected as u8) != 0; + let act_present = info[0] & (EncodedInformationType::Actual as u8) != 0; + let exp_twobyte = ((info[1] >> 0) & 0b1) != 0; + let act_twobyte = ((info[1] >> 4) & 0b1) != 0; + let exp_len_mask = if !act_present { !0u8 } else { (1u8 << 3) - 1 }; + let exp_len = ((info[1] >> 1) & exp_len_mask) as usize; + let act_len = ((info[1] >> 5) & 0b111) as usize; + let mut start = 2; + let exp = if exp_present { + let width = if !exp_twobyte { EncodingWidth::One } else { EncodingWidth::Two }; + let exp = decode_shape(&info[start..], exp_len, width); + start += exp_len * width as usize; + Some(exp) + } else { None }; + let act = if act_present { + let width = if !act_twobyte { EncodingWidth::One } else { EncodingWidth::Two }; + let act = decode_shape(&info[start..], act_len, width); + Some(act) + } else { None }; + (exp, act) +} + +#[derive(Copy, Clone)] +struct ExtendedInfo<'a>(&'a ShapeError); + +impl<'a> ExtendedInfo<'a> { + fn has_info(&self) -> bool { + self.0.info[0] != EncodedInformationType::Nothing as u8 + } +} + +impl<'a> fmt::Display for ExtendedInfo<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.has_info() { + return Ok(()); + } + // Use the wording of "expected: X, but got: Y" for + // expected and actual part of the error exented info. + match self.0.kind() { + ErrorKind::IncompatibleLayout => { + let (expected, _) = decode_indices(self.0.info); + match ExpectedLayout::from(expected) { + ExpectedLayout::ContiguousCF => { + write!(f, "; expected c- or f-contiguous input")?; + } + ExpectedLayout::Unused => {} + } + } + ErrorKind::IncompatibleShape => { + let (expected, actual) = decode_shapes(self.0.info); + write!(f, "; expected compatible: ")?; + if let Some(value) = expected { + write!(f, "{:?}", value.as_slice())?; + } else { + write!(f, "unknown")?; + } + if let Some(value) = actual { + write!(f, ", but got: {:?}", value.as_slice())?; + } else { + write!(f, "unknown")?; + } + } + _otherwise => { + let (expected, actual) = decode_indices(self.0.info); + write!(f, "; expected: ")?; + if let Some(value) = expected { + write!(f, "{}", value)?; + } else { + write!(f, "unknown")?; + } + + write!(f, ", but got: ")?; + if let Some(value) = actual { + write!(f, "{}", value)?; + } else { + write!(f, "unknown")?; + } + } + } + Ok(()) + } +} + + +#[cfg(test)] +use matches::assert_matches; +#[cfg(test)] +use crate::IntoDimension; + +#[test] +fn test_sizes() { + assert!(size_of::() <= size_of::()); + assert!(size_of::() <= 16); + + assert_eq!(size_of::>(), size_of::()); +} + +#[test] +fn test_encode_decode_format() { + use alloc::string::ToString; + + assert_eq!( + ShapeError::invalid_axis(1, 0).to_string(), + "ShapeError/OutOfBounds: out of bounds indexing; expected: 1, but got: 0"); + + if size_of::() > 4 { + assert_eq!( + ShapeError::invalid_axis(usize::MAX, usize::MAX).to_string(), + "ShapeError/OutOfBounds: out of bounds indexing"); + } + + assert_eq!( + ShapeError::incompatible_shapes(&(1, 2, 3).into_dimension(), &(2, 3).into_dimension()) + .to_string(), + "ShapeError/IncompatibleShape: \ + incompatible shapes; expected compatible: [1, 2, 3], but got: [2, 3]"); +} + +#[test] +fn test_encode_decode() { + for &i in [0, 1, 2, 3, 10, 32, 256, 1736, 16300].iter() { + let err = ShapeError::invalid_axis(i, 0); + assert_eq!(err.info_expected_index(), Some(i)); + let err = ShapeError::invalid_axis(0, i); + assert_eq!(err.info_actual_index(), Some(i)); + } + + let err = ShapeError::invalid_axis(1 << 24, (1 << 24) + 1); + assert_eq!(err.info_expected_index(), Some(1 << 24)); + assert_eq!(err.info_actual_index(), Some((1 << 24) + 1)); + + if size_of::() > 4 { + // use .wrapping_shl(_) for portability + let err = ShapeError::invalid_axis(1usize.wrapping_shl(56) - 1, 0); + assert_eq!(err.info_expected_index(), Some(1usize.wrapping_shl(56) - 1)); + assert_eq!(err.info_actual_index(), Some(0)); + + let err = ShapeError::invalid_axis(1usize.wrapping_shl(56), 1usize.wrapping_shl(56)); + assert_eq!(err.info_expected_index(), None); + assert_eq!(err.info_actual_index(), None); + + let err = ShapeError::invalid_axis(usize::MAX, usize::MAX); + assert_eq!(err.info_expected_index(), None); + assert_eq!(err.info_actual_index(), None); + } +} + +#[test] +fn test_encode_decode_shape() { + let err = ShapeError::incompatible_shapes(&(1, 2).into_dimension(), &(4, 5).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[1, 2]); + assert_eq!(act.unwrap().as_slice(), &[4, 5]); + + let err = ShapeError::incompatible_shapes(&(1, 2, 3).into_dimension(), &(4, 5, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[1, 2, 3]); + assert_eq!(act.unwrap().as_slice(), &[4, 5, 6]); + + let err = ShapeError::incompatible_shapes(&().into_dimension(), &().into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[]); + assert_eq!(act.unwrap().as_slice(), &[]); + + let (m, n) = (256, 768); + let err = ShapeError::incompatible_shapes(&(m, n).into_dimension(), &(m + 1, n + 1).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[m, n]); + assert_eq!(act.unwrap().as_slice(), &[m + 1, n + 1]); + //assert!(act.is_none()); + + let (m, n) = (256, 768); + let err = ShapeError::incompatible_shapes(&(m, n).into_dimension(), &(m + 1).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[m, n]); + assert_eq!(act.unwrap().as_slice(), &[m + 1]); + + let (m, n) = (256, 768); + let err = ShapeError::incompatible_shapes(&m.into_dimension(), &(m + 1, n + 1).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[m]); + assert_eq!(act.unwrap().as_slice(), &[m + 1, n + 1]); + + let err = ShapeError::incompatible_shapes(&(768, 2, 1024).into_dimension(), &(4, 500, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[768, 2, 1024]); + assert_eq!(act.unwrap().as_slice(), &[4, 500, 6]); + + let err = ShapeError::incompatible_shapes(&(768, 2, 1024, 3, 300).into_dimension(), &(4, 500, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[768, 2, 1024, 3, 300]); + assert_matches!(act, None); + + let err = ShapeError::incompatible_shapes(&(768, 2, 1024, 3, 300).into_dimension(), &(4, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[768, 2, 1024, 3, 300]); + assert_eq!(act.unwrap().as_slice(), &[4, 6]); + + let err = ShapeError::incompatible_shapes(&().into_dimension(), &(768, 2, 1024).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[]); + assert_eq!(act.unwrap().as_slice(), &[768, 2, 1024]); + + let err = ShapeError::incompatible_shapes(&[1, 2, 3, 4, 5, 6, 7, 8].into_dimension(), &().into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_matches!(exp, None); + assert_eq!(act.unwrap().as_slice(), &[]); + + let err = ShapeError::incompatible_shapes(&[1, 2, 3, 4, 5, 6, 7].into_dimension(), &(1, 2).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[1, 2, 3, 4, 5, 6, 7]); + assert_eq!(act.unwrap().as_slice(), &[1, 2]); } diff --git a/tests/array.rs b/tests/array.rs index 8e084e49e..87cca71a6 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1377,7 +1377,9 @@ fn reshape() { fn reshape_error1() { let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); - let _u = v.into_shape((2, 5)).unwrap(); + let res = v.into_shape((2, 5)); + println!("{:?}", res); + res.unwrap(); } #[test] @@ -1387,7 +1389,9 @@ fn reshape_error2() { let v = aview1(&data); let mut u = v.into_shape((2, 2, 2)).unwrap(); u.swap_axes(0, 1); - let _s = u.into_shape((2, 4)).unwrap(); + let res = u.into_shape((2, 4)); + println!("{:?}", res); + res.unwrap(); } #[test] diff --git a/tests/stacking.rs b/tests/stacking.rs index 032525ffa..cd262dfe5 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -16,12 +16,15 @@ fn concatenating() { assert_eq!(d, aview1(&[2., 2., 9., 9.])); let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } @@ -36,11 +39,14 @@ fn stacking() { let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]); let res = ndarray::stack(Axis(1), &[a.view(), c.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); let res = ndarray::stack(Axis(3), &[a.view(), a.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); let res: Result, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } From 5f34c4ac868a3089ccabe58a2e353831e289d72a Mon Sep 17 00:00:00 2001 From: bluss Date: Mon, 29 Mar 2021 20:15:05 +0200 Subject: [PATCH 3/3] FIX: Add extra information to errors where possible Where possible, add expected/actual information to the ShapeError. In many places it is identified new places where more specific ErrorKinds and error messages are needed. These are not updated here - a comment is inserted - this will be updated in a future version, when we can accept breaking changes. --- src/dimension/broadcast.rs | 1 + src/dimension/mod.rs | 16 ++++++++++++---- src/impl_methods.rs | 14 +++++++++----- src/slice.rs | 2 ++ src/stacking.rs | 24 ++++++++++++++++-------- 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index dc1513f04..a23d24921 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -27,6 +27,7 @@ where if *out == 1 { *out = *s2 } else if *s2 != 1 { + // TODO More specific error axis length mismatch return Err(from_kind(ErrorKind::IncompatibleShape)); } } diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index f4f46e764..057e39aed 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -89,6 +89,7 @@ pub fn size_of_shape_checked(dim: &D) -> Result .try_fold(1usize, |acc, &d| acc.checked_mul(d)) .ok_or_else(|| from_kind(ErrorKind::Overflow))?; if size_nonzero > ::std::isize::MAX as usize { + // TODO More specific error Err(from_kind(ErrorKind::Overflow)) } else { Ok(dim.size()) @@ -137,7 +138,7 @@ pub(crate) fn can_index_slice_not_custom(data_len: usize, dim: &D) let len = size_of_shape_checked(dim)?; // Condition 2. if len > data_len { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::shape_length_exceeds_data_length(data_len, len)); } Ok(()) } @@ -170,6 +171,7 @@ where { // Condition 1. if dim.ndim() != strides.ndim() { + // TODO More specific error for dimension stride dimensionality mismatch return Err(from_kind(ErrorKind::IncompatibleLayout)); } @@ -185,9 +187,11 @@ where let off = d.saturating_sub(1).checked_mul(s.abs() as usize)?; acc.checked_add(off) }) + // TODO More specific error .ok_or_else(|| from_kind(ErrorKind::Overflow))?; // Condition 2a. if max_offset > isize::MAX as usize { + // TODO More specific error return Err(from_kind(ErrorKind::Overflow)); } @@ -195,9 +199,11 @@ where // greatest address accessible by moving along all axes let max_offset_bytes = max_offset .checked_mul(elem_size) + // TODO More specific error .ok_or_else(|| from_kind(ErrorKind::Overflow))?; // Condition 2b. if max_offset_bytes > isize::MAX as usize { + // TODO More specific error return Err(from_kind(ErrorKind::Overflow)); } @@ -256,15 +262,16 @@ fn can_index_slice_impl( // Check condition 3. let is_empty = dim.slice().iter().any(|&d| d == 0); if is_empty && max_offset > data_len { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::shape_length_exceeds_data_length(data_len, max_offset)); } if !is_empty && max_offset >= data_len { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::shape_length_exceeds_data_length(data_len.wrapping_sub(1), max_offset)); } // Check condition 4. if !is_empty && dim_stride_overlap(dim, strides) { - return Err(from_kind(ErrorKind::Unsupported)); + // TODO: More specific error kind Strides result in overlapping elements + return Err(ShapeError::from_kind(ErrorKind::Unsupported)); } Ok(()) @@ -293,6 +300,7 @@ where { for &stride in strides.slice() { if (stride as isize) < 0 { + // TODO: More specific error kind Non-negative strides required return Err(from_kind(ErrorKind::Unsupported)); } } diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4045f2b59..82cc1cc0a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -23,7 +23,7 @@ use crate::dimension::{ offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::dimension::broadcast::co_broadcast; -use crate::error::{self, ErrorKind, ShapeError, from_kind}; +use crate::error::{self, ErrorKind, ShapeError}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::{IntoNdProducer, Zip}; @@ -1588,7 +1588,7 @@ where } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { Ok(self.with_strides_dim(shape.fortran_strides(), shape)) } else { - Err(error::from_kind(error::ErrorKind::IncompatibleLayout)) + Err(ShapeError::incompatible_layout(error::ExpectedLayout::ContiguousCF)) } } } @@ -1693,6 +1693,7 @@ where } } } + // TODO More specific error incompatible ndim Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) } @@ -1805,11 +1806,14 @@ where { let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; if let Some(view1) = self.broadcast(shape.clone()) { - if let Some(view2) = other.broadcast(shape) { - return Ok((view1, view2)); + if let Some(view2) = other.broadcast(shape.clone()) { + Ok((view1, view2)) + } else { + Err(ShapeError::incompatible_shapes(&other.dim, &shape)) } + } else { + Err(ShapeError::incompatible_shapes(&other.dim, &shape)) } - Err(from_kind(ErrorKind::IncompatibleShape)) } /// Swap axes `ax` and `bx`. diff --git a/src/slice.rs b/src/slice.rs index 3c554a5ca..da76cda74 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -415,11 +415,13 @@ where { if let Some(in_ndim) = Din::NDIM { if in_ndim != indices.in_ndim() { + // TODO More specific error incompatible ndim return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } if let Some(out_ndim) = Dout::NDIM { if out_ndim != indices.out_ndim() { + // TODO More specific error incompatible ndim return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } diff --git a/src/stacking.rs b/src/stacking.rs index 500ded6af..213e2ac1f 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -8,6 +8,7 @@ use crate::error::{from_kind, ErrorKind, ShapeError}; use crate::imp_prelude::*; +use crate::NdProducer; /// Stack arrays along the new axis. /// @@ -72,18 +73,22 @@ where D: RemoveAxis, { if arrays.is_empty() { + // TODO More specific error for empty input not supported return Err(from_kind(ErrorKind::Unsupported)); } let mut res_dim = arrays[0].raw_dim(); if axis.index() >= res_dim.ndim() { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::invalid_axis(res_dim.ndim().wrapping_sub(1), axis.index())); } let common_dim = res_dim.remove_axis(axis); - if arrays - .iter() - .any(|a| a.raw_dim().remove_axis(axis) != common_dim) + if let Some(a) = arrays.iter().find_map(|a| + if a.raw_dim().remove_axis(axis) != common_dim { + Some(a) + } else { + None + }) { - return Err(from_kind(ErrorKind::IncompatibleShape)); + return Err(ShapeError::incompatible_shapes(&common_dim, &a.dim)); } let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis)); @@ -143,17 +148,20 @@ where D::Larger: RemoveAxis, { if arrays.is_empty() { + // TODO More specific error for empty input not supported return Err(from_kind(ErrorKind::Unsupported)); } let common_dim = arrays[0].raw_dim(); // Avoid panic on `insert_axis` call, return an Err instead of it. if axis.index() > common_dim.ndim() { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::invalid_axis(common_dim.ndim(), axis.index())); } let mut res_dim = common_dim.insert_axis(axis); - if arrays.iter().any(|a| a.raw_dim() != common_dim) { - return Err(from_kind(ErrorKind::IncompatibleShape)); + if let Some(array) = arrays.iter().find_map(|array| if !array.equal_dim(&common_dim) { + Some(array) + } else { None }) { + return Err(ShapeError::incompatible_shapes(&common_dim, &array.dim)); } res_dim.set_axis(axis, arrays.len());