diff --git a/thriller-core/src/buffer.rs b/thriller-core/src/buffer.rs index 27bffd1..d439f4e 100644 --- a/thriller-core/src/buffer.rs +++ b/thriller-core/src/buffer.rs @@ -2,7 +2,7 @@ use crate::shape::Ix; use crate::{next_id, Dim, Layout, Shape}; /// Buffer type. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum BufType { /// Global Tile GlobalTile, diff --git a/thriller-core/src/shape.rs b/thriller-core/src/shape.rs index 23633a1..581c4c6 100644 --- a/thriller-core/src/shape.rs +++ b/thriller-core/src/shape.rs @@ -25,7 +25,7 @@ pub trait Dimension { } /// Stride description. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Layout { /// Row-major RowMajor, @@ -54,7 +54,7 @@ where /// /// [`Dim`] describes the number of axes and the length of each axis /// in an array. It is also used as an index type. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct Dim { dims: SmallVec<[Ix; 4]>, ndim: usize, @@ -140,7 +140,7 @@ impl Dimension for Dim { } /// Shape description. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct Shape { dims: Dim, layout: Layout, diff --git a/thriller-core/src/task/compute/map.rs b/thriller-core/src/task/compute/map.rs index 1ad75f1..1a0c5da 100644 --- a/thriller-core/src/task/compute/map.rs +++ b/thriller-core/src/task/compute/map.rs @@ -21,6 +21,14 @@ impl Convert { src_type: DataType, dst_type: DataType, ) -> Self { + // `src_buf` and `dst_buf` must have the same typing. + // TODO(KuanjuX): Don't use assert here. + assert_eq!(src_buf.get_typing(), dst_buf.get_typing()); + + // `src_buf` and `dst_buf` must have the same shape. + // TODO(KuanjuX): Don't use assert here. + assert_eq!(src_buf.get_shape(), dst_buf.get_shape()); + Self { src_buf, dst_buf,