Skip to content

Commit 38f7341

Browse files
committed
FIX: Remove broadcast_shape from the DimMax trait
While calling co_broadcast directly is less convenient, for now they are two different functions.
1 parent b39593e commit 38f7341

File tree

4 files changed

+64
-73
lines changed

4 files changed

+64
-73
lines changed

Diff for: src/dimension/broadcast.rs

+60-22
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
66
///
77
/// Uses the [NumPy broadcasting rules]
88
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
9-
fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
10-
where
11-
D1: Dimension,
12-
D2: Dimension,
13-
Output: Dimension,
9+
pub(crate) fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
10+
where
11+
D1: Dimension,
12+
D2: Dimension,
13+
Output: Dimension,
1414
{
1515
let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim());
1616
// Swap the order if d2 is longer.
@@ -37,40 +37,23 @@ fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, Shap
3737
pub trait DimMax<Other: Dimension> {
3838
/// The resulting dimension type after broadcasting.
3939
type Output: Dimension;
40-
41-
/// Determines the shape after broadcasting the shapes together.
42-
///
43-
/// If the shapes are not compatible, returns `Err`.
44-
fn broadcast_shape(&self, other: &Other) -> Result<Self::Output, ShapeError>;
4540
}
4641

4742
/// Dimensions of the same type remain unchanged when co_broadcast.
4843
/// So you can directly use D as the resulting type.
4944
/// (Instead of <D as DimMax<D>>::BroadcastOutput)
5045
impl<D: Dimension> DimMax<D> for D {
5146
type Output = D;
52-
53-
fn broadcast_shape(&self, other: &D) -> Result<Self::Output, ShapeError> {
54-
co_broadcast::<D, D, Self::Output>(self, other)
55-
}
5647
}
5748

5849
macro_rules! impl_broadcast_distinct_fixed {
5950
($smaller:ty, $larger:ty) => {
6051
impl DimMax<$larger> for $smaller {
6152
type Output = $larger;
62-
63-
fn broadcast_shape(&self, other: &$larger) -> Result<Self::Output, ShapeError> {
64-
co_broadcast::<Self, $larger, Self::Output>(self, other)
65-
}
6653
}
6754

6855
impl DimMax<$smaller> for $larger {
6956
type Output = $larger;
70-
71-
fn broadcast_shape(&self, other: &$smaller) -> Result<Self::Output, ShapeError> {
72-
co_broadcast::<Self, $smaller, Self::Output>(self, other)
73-
}
7457
}
7558
};
7659
}
@@ -103,3 +86,58 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn);
10386
impl_broadcast_distinct_fixed!(Ix4, IxDyn);
10487
impl_broadcast_distinct_fixed!(Ix5, IxDyn);
10588
impl_broadcast_distinct_fixed!(Ix6, IxDyn);
89+
90+
91+
#[cfg(test)]
92+
#[cfg(feature = "std")]
93+
mod tests {
94+
use super::co_broadcast;
95+
use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind};
96+
97+
#[test]
98+
fn test_broadcast_shape() {
99+
fn test_co<D1, D2>(
100+
d1: &D1,
101+
d2: &D2,
102+
r: Result<<D1 as DimMax<D2>>::Output, ShapeError>,
103+
) where
104+
D1: Dimension + DimMax<D2>,
105+
D2: Dimension,
106+
{
107+
let d = co_broadcast::<D1, D2, <D1 as DimMax<D2>>::Output>(&d1, d2);
108+
assert_eq!(d, r);
109+
}
110+
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
111+
test_co(
112+
&Dim([1, 2, 2]),
113+
&Dim([1, 3, 4]),
114+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
115+
);
116+
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
117+
let v = vec![1, 2, 3, 4, 5, 6, 7];
118+
test_co(
119+
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
120+
&Dim([2, 1, 4, 1, 6, 1]),
121+
Ok(Dim(IxDynImpl::from(v.as_slice()))),
122+
);
123+
let d = Dim([1, 2, 1, 3]);
124+
test_co(&d, &d, Ok(d));
125+
test_co(
126+
&Dim([2, 1, 2]).into_dyn(),
127+
&Dim(0),
128+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
129+
);
130+
test_co(
131+
&Dim([2, 1, 1]),
132+
&Dim([0, 0, 1, 3, 4]),
133+
Ok(Dim([0, 0, 2, 3, 4])),
134+
);
135+
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
136+
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
137+
test_co(
138+
&Dim([1, 3, 0, 1, 1]),
139+
&Dim([1, 2, 3, 1]),
140+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
141+
);
142+
}
143+
}

Diff for: src/dimension/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use std::mem;
2929
mod macros;
3030
mod axes;
3131
mod axis;
32-
mod broadcast;
32+
pub(crate) mod broadcast;
3333
mod conversion;
3434
pub mod dim;
3535
mod dimension_trait;

Diff for: src/impl_methods.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::dimension::{
2121
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
2222
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2323
};
24+
use crate::dimension::broadcast::co_broadcast;
2425
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2526
use crate::math_cell::MathCell;
2627
use crate::itertools::zip;
@@ -1778,7 +1779,7 @@ where
17781779
D: Dimension + DimMax<E>,
17791780
E: Dimension,
17801781
{
1781-
let shape = self.dim.broadcast_shape(&other.dim)?;
1782+
let shape = co_broadcast::<D, E, <D as DimMax<E>>::Output>(&self.dim, &other.dim)?;
17821783
if let Some(view1) = self.broadcast(shape.clone()) {
17831784
if let Some(view2) = other.broadcast(shape) {
17841785
return Ok((view1, view2))

Diff for: tests/dimension.rs

+1-49
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
use defmac::defmac;
44

5-
use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, Ix0, IxDyn, IxDynImpl, RemoveAxis,
6-
ErrorKind, ShapeError, DimMax};
5+
use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IxDyn, RemoveAxis};
76

87
use std::hash::{Hash, Hasher};
98

@@ -341,50 +340,3 @@ fn test_all_ndindex() {
341340
ndindex!(10, 4, 3, 2, 2);
342341
ndindex!(10, 4, 3, 2, 2, 2);
343342
}
344-
345-
#[test]
346-
fn test_broadcast_shape() {
347-
fn test_co<D1, D2>(
348-
d1: &D1,
349-
d2: &D2,
350-
r: Result<<D1 as DimMax<D2>>::Output, ShapeError>,
351-
) where
352-
D1: Dimension + DimMax<D2>,
353-
D2: Dimension,
354-
{
355-
let d = d1.broadcast_shape(d2);
356-
assert_eq!(d, r);
357-
}
358-
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
359-
test_co(
360-
&Dim([1, 2, 2]),
361-
&Dim([1, 3, 4]),
362-
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
363-
);
364-
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
365-
let v = vec![1, 2, 3, 4, 5, 6, 7];
366-
test_co(
367-
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
368-
&Dim([2, 1, 4, 1, 6, 1]),
369-
Ok(Dim(IxDynImpl::from(v.as_slice()))),
370-
);
371-
let d = Dim([1, 2, 1, 3]);
372-
test_co(&d, &d, Ok(d));
373-
test_co(
374-
&Dim([2, 1, 2]).into_dyn(),
375-
&Dim(0),
376-
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
377-
);
378-
test_co(
379-
&Dim([2, 1, 1]),
380-
&Dim([0, 0, 1, 3, 4]),
381-
Ok(Dim([0, 0, 2, 3, 4])),
382-
);
383-
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
384-
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
385-
test_co(
386-
&Dim([1, 3, 0, 1, 1]),
387-
&Dim([1, 2, 3, 1]),
388-
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
389-
);
390-
}

0 commit comments

Comments
 (0)