@@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
6
6
///
7
7
/// Uses the [NumPy broadcasting rules]
8
8
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
9
- fn co_broadcast < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
10
- where
11
- D1 : Dimension ,
12
- D2 : Dimension ,
13
- Output : Dimension ,
9
+ pub ( crate ) fn co_broadcast < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
10
+ where
11
+ D1 : Dimension ,
12
+ D2 : Dimension ,
13
+ Output : Dimension ,
14
14
{
15
15
let ( k, overflow) = shape1. ndim ( ) . overflowing_sub ( shape2. ndim ( ) ) ;
16
16
// Swap the order if d2 is longer.
@@ -37,40 +37,23 @@ fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, Shap
37
37
pub trait DimMax < Other : Dimension > {
38
38
/// The resulting dimension type after broadcasting.
39
39
type Output : Dimension ;
40
-
41
- /// Determines the shape after broadcasting the shapes together.
42
- ///
43
- /// If the shapes are not compatible, returns `Err`.
44
- fn broadcast_shape ( & self , other : & Other ) -> Result < Self :: Output , ShapeError > ;
45
40
}
46
41
47
42
/// Dimensions of the same type remain unchanged when co_broadcast.
48
43
/// So you can directly use D as the resulting type.
49
44
/// (Instead of <D as DimMax<D>>::BroadcastOutput)
50
45
impl < D : Dimension > DimMax < D > for D {
51
46
type Output = D ;
52
-
53
- fn broadcast_shape ( & self , other : & D ) -> Result < Self :: Output , ShapeError > {
54
- co_broadcast :: < D , D , Self :: Output > ( self , other)
55
- }
56
47
}
57
48
58
49
macro_rules! impl_broadcast_distinct_fixed {
59
50
( $smaller: ty, $larger: ty) => {
60
51
impl DimMax <$larger> for $smaller {
61
52
type Output = $larger;
62
-
63
- fn broadcast_shape( & self , other: & $larger) -> Result <Self :: Output , ShapeError > {
64
- co_broadcast:: <Self , $larger, Self :: Output >( self , other)
65
- }
66
53
}
67
54
68
55
impl DimMax <$smaller> for $larger {
69
56
type Output = $larger;
70
-
71
- fn broadcast_shape( & self , other: & $smaller) -> Result <Self :: Output , ShapeError > {
72
- co_broadcast:: <Self , $smaller, Self :: Output >( self , other)
73
- }
74
57
}
75
58
} ;
76
59
}
@@ -103,3 +86,58 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn);
103
86
impl_broadcast_distinct_fixed ! ( Ix4 , IxDyn ) ;
104
87
impl_broadcast_distinct_fixed ! ( Ix5 , IxDyn ) ;
105
88
impl_broadcast_distinct_fixed ! ( Ix6 , IxDyn ) ;
89
+
90
+
91
+ #[ cfg( test) ]
92
+ #[ cfg( feature = "std" ) ]
93
+ mod tests {
94
+ use super :: co_broadcast;
95
+ use crate :: { Dimension , Dim , DimMax , ShapeError , Ix0 , IxDynImpl , ErrorKind } ;
96
+
97
+ #[ test]
98
+ fn test_broadcast_shape ( ) {
99
+ fn test_co < D1 , D2 > (
100
+ d1 : & D1 ,
101
+ d2 : & D2 ,
102
+ r : Result < <D1 as DimMax < D2 > >:: Output , ShapeError > ,
103
+ ) where
104
+ D1 : Dimension + DimMax < D2 > ,
105
+ D2 : Dimension ,
106
+ {
107
+ let d = co_broadcast :: < D1 , D2 , <D1 as DimMax < D2 > >:: Output > ( & d1, d2) ;
108
+ assert_eq ! ( d, r) ;
109
+ }
110
+ test_co ( & Dim ( [ 2 , 3 ] ) , & Dim ( [ 4 , 1 , 3 ] ) , Ok ( Dim ( [ 4 , 2 , 3 ] ) ) ) ;
111
+ test_co (
112
+ & Dim ( [ 1 , 2 , 2 ] ) ,
113
+ & Dim ( [ 1 , 3 , 4 ] ) ,
114
+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
115
+ ) ;
116
+ test_co ( & Dim ( [ 3 , 4 , 5 ] ) , & Ix0 ( ) , Ok ( Dim ( [ 3 , 4 , 5 ] ) ) ) ;
117
+ let v = vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ;
118
+ test_co (
119
+ & Dim ( vec ! [ 1 , 1 , 3 , 1 , 5 , 1 , 7 ] ) ,
120
+ & Dim ( [ 2 , 1 , 4 , 1 , 6 , 1 ] ) ,
121
+ Ok ( Dim ( IxDynImpl :: from ( v. as_slice ( ) ) ) ) ,
122
+ ) ;
123
+ let d = Dim ( [ 1 , 2 , 1 , 3 ] ) ;
124
+ test_co ( & d, & d, Ok ( d) ) ;
125
+ test_co (
126
+ & Dim ( [ 2 , 1 , 2 ] ) . into_dyn ( ) ,
127
+ & Dim ( 0 ) ,
128
+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
129
+ ) ;
130
+ test_co (
131
+ & Dim ( [ 2 , 1 , 1 ] ) ,
132
+ & Dim ( [ 0 , 0 , 1 , 3 , 4 ] ) ,
133
+ Ok ( Dim ( [ 0 , 0 , 2 , 3 , 4 ] ) ) ,
134
+ ) ;
135
+ test_co ( & Dim ( [ 0 ] ) , & Dim ( [ 0 , 0 , 0 ] ) , Ok ( Dim ( [ 0 , 0 , 0 ] ) ) ) ;
136
+ test_co ( & Dim ( 1 ) , & Dim ( [ 1 , 0 , 0 ] ) , Ok ( Dim ( [ 1 , 0 , 0 ] ) ) ) ;
137
+ test_co (
138
+ & Dim ( [ 1 , 3 , 0 , 1 , 1 ] ) ,
139
+ & Dim ( [ 1 , 2 , 3 , 1 ] ) ,
140
+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
141
+ ) ;
142
+ }
143
+ }
0 commit comments