@@ -61,10 +61,31 @@ unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),)
61
61
unpack_args (:: Any , args:: Tuple{} ) = ()
62
62
63
63
nnodes (A) = 0
64
- nnodes (A:: ASMA ) = length (A. nodes)
64
+ nnodes (A:: AMSA ) = length (A. nodes)
65
65
nnodes (bc:: Broadcast.Broadcasted ) = _nnodes (bc. args)
66
66
nnodes (A, Bs... ) = common_number (nnodes (A), _nnodes (Bs))
67
67
68
68
@inline _nnodes (args:: Tuple ) = common_number (nnodes (args[1 ]), _nnodes (Base. tail (args)))
69
69
_nnodes (args:: Tuple{Any} ) = nnodes (args[1 ])
70
70
_nnodes (args:: Tuple{} ) = 0
71
+
72
+ get_common_type (A) = Nothing
73
+ get_common_type (A:: AMSA ) = typeof (A)
74
+ get_common_type (bc:: Broadcast.Broadcasted ) = _nnodes (bc. args)
75
+ get_common_type (A, Bs... ) = common_type (get_common_type (A), _get_common_type (Bs))
76
+
77
+ @inline _get_common_type (args:: Tuple ) = get_common_type (get_common_type (args[1 ]), _get_common_type (Base. tail (args)))
78
+ _get_common_type (args:: Tuple{Any} ) = get_common_type (args[1 ])
79
+ _get_common_type (args:: Tuple{} ) = Nothing
80
+
81
+ # # utils
82
+ common_number (a, b) =
83
+ a == 0 ? b :
84
+ (b == 0 ? a :
85
+ (a == b ? a :
86
+ throw (DimensionMismatch (" number of nodes must be equal" ))))
87
+
88
+ common_type (a:: T , b:: T ) where T = T
89
+ common_type (a:: T , b:: Nothing ) where T = T
90
+ common_type (a:: Nothing , b:: T ) where T = T
91
+ common_type (a:: Nothing , b:: Nothing ) where T = Nothing
0 commit comments