diff --git a/src/bidiag.jl b/src/bidiag.jl index 2fb1415f..52cd552f 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -1461,7 +1461,13 @@ eigen(M::Bidiagonal) = Eigen(eigvals(M), eigvecs(M)) Base._sum(A::Bidiagonal, ::Colon) = sum(A.dv) + sum(A.ev) function Base._sum(A::Bidiagonal, dims::Integer) - res = Base.reducedim_initarray(A, dims, zero(eltype(A))) + Base._check_valid_region(dims) + ax = (dims == 1) ? (1, axes(A, 2)) : + (dims == 2) ? (axes(A, 1), 1) : + axes(A) + res = Base.mapreduce_similar(A, eltype(A), ax) + fill!(res, zero(eltype(A))) + n = length(A.dv) if n == 0 # Just to be sure. This shouldn't happen since there is a check whether diff --git a/src/diagonal.jl b/src/diagonal.jl index ec1cd89c..23a66187 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -1134,7 +1134,13 @@ end Base._sum(A::Diagonal, ::Colon) = sum(A.diag) function Base._sum(A::Diagonal, dims::Integer) - res = Base.reducedim_initarray(A, dims, zero(eltype(A))) + Base._check_valid_region(dims) + ax = (dims == 1) ? (1, axes(A, 2)) : + (dims == 2) ? (axes(A, 1), 1) : + axes(A) + res = Base.mapreduce_similar(A, eltype(A), ax) + fill!(res, zero(eltype(A))) + if dims <= 2 for i = 1:length(A.diag) @inbounds res[i] = A.diag[i] diff --git a/src/tridiag.jl b/src/tridiag.jl index a24cc50b..9699e207 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -893,7 +893,13 @@ function Base._sum(A::SymTridiagonal, ::Colon) end function Base._sum(A::Tridiagonal, dims::Integer) - res = Base.reducedim_initarray(A, dims, zero(eltype(A))) + Base._check_valid_region(dims) + ax = (dims == 1) ? (1, axes(A, 2)) : + (dims == 2) ? (axes(A, 1), 1) : + axes(A) + res = Base.mapreduce_similar(A, eltype(A), ax) + fill!(res, zero(eltype(A))) + n = length(A.d) if n == 0 return res @@ -927,7 +933,13 @@ function Base._sum(A::Tridiagonal, dims::Integer) end function Base._sum(A::SymTridiagonal, dims::Integer) - res = Base.reducedim_initarray(A, dims, zero(eltype(A))) + Base._check_valid_region(dims) + ax = (dims == 1) ? (1, axes(A, 2)) : + (dims == 2) ? (axes(A, 1), 1) : + axes(A) + res = Base.mapreduce_similar(A, eltype(A), ax) + fill!(res, zero(eltype(A))) + n = length(A.dv) if n == 0 return res