Skip to content

Commit b05115e

Browse files
common type and node calculations
1 parent 6772b7e commit b05115e

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/math.jl

+22-1
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,31 @@ unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),)
6161
unpack_args(::Any, args::Tuple{}) = ()
6262

6363
nnodes(A) = 0
64-
nnodes(A::ASMA) = length(A.nodes)
64+
nnodes(A::AMSA) = length(A.nodes)
6565
nnodes(bc::Broadcast.Broadcasted) = _nnodes(bc.args)
6666
nnodes(A, Bs...) = common_number(nnodes(A), _nnodes(Bs))
6767

6868
@inline _nnodes(args::Tuple) = common_number(nnodes(args[1]), _nnodes(Base.tail(args)))
6969
_nnodes(args::Tuple{Any}) = nnodes(args[1])
7070
_nnodes(args::Tuple{}) = 0
71+
72+
get_common_type(A) = Nothing
73+
get_common_type(A::AMSA) = typeof(A)
74+
get_common_type(bc::Broadcast.Broadcasted) = _nnodes(bc.args)
75+
get_common_type(A, Bs...) = common_type(get_common_type(A), _get_common_type(Bs))
76+
77+
@inline _get_common_type(args::Tuple) = get_common_type(get_common_type(args[1]), _get_common_type(Base.tail(args)))
78+
_get_common_type(args::Tuple{Any}) = get_common_type(args[1])
79+
_get_common_type(args::Tuple{}) = Nothing
80+
81+
## utils
82+
common_number(a, b) =
83+
a == 0 ? b :
84+
(b == 0 ? a :
85+
(a == b ? a :
86+
throw(DimensionMismatch("number of nodes must be equal"))))
87+
88+
common_type(a::T, b::T) where T = T
89+
common_type(a::T, b::Nothing) where T = T
90+
common_type(a::Nothing, b::T) where T = T
91+
common_type(a::Nothing, b::Nothing) where T = Nothing

0 commit comments

Comments
 (0)