diff --git a/src/sparse_coeffs.jl b/src/sparse_coeffs.jl index 1a4300d..3e7236f 100644 --- a/src/sparse_coeffs.jl +++ b/src/sparse_coeffs.jl @@ -13,8 +13,17 @@ function Base.copy(sc::SparseCoefficients) return SparseCoefficients(copy(keys(sc)), copy(values(sc))) end +function _search(keys::Tuple, key) + # `searchsortedfirst` is not defined for `Tuple` + return findfirst(isequal(key), keys) +end + +function _search(keys, key::K) where {K} + return searchsortedfirst(keys, key; lt = comparable(K)) +end + function Base.getindex(sc::SparseCoefficients{K}, key::K) where {K} - k = searchsortedfirst(sc.basis_elements, key; lt = comparable(K)) + k = _search(sc.basis_elements, key) if k in eachindex(sc.basis_elements) v = sc.values[k] if sc.basis_elements[k] == key @@ -30,7 +39,7 @@ end function Base.setindex!(sc::SparseCoefficients{K}, val, key::K) where {K} k = searchsortedfirst(sc.basis_elements, key; lt = comparable(K)) if k in eachindex(sc.basis_elements) && sc.basis_elements[k] == key - sc.values[k] += val + sc.values[k] = val else insert!(sc.basis_elements, k, key) insert!(sc.values, k, val) @@ -38,6 +47,55 @@ function Base.setindex!(sc::SparseCoefficients{K}, val, key::K) where {K} return sc end +################ +# Broadcasting # +################ + +struct BroadcastStyle{K} <: Broadcast.BroadcastStyle end + +Base.broadcastable(sc::SparseCoefficients) = sc +Base.BroadcastStyle(::Type{<:SparseCoefficients{K}}) where {K} = BroadcastStyle{K}() +# Disallow mixing broadcasts. +function Base.BroadcastStyle(::BroadcastStyle, ::Base.BroadcastStyle) + return throw( + ArgumentError( + "Cannot broadcast `StarAlgebras.SparseCoefficients` with" * + " another array of different type", + ), + ) +end + +# Allow broadcasting over scalars. +function Base.BroadcastStyle(style::BroadcastStyle, ::Base.Broadcast.DefaultArrayStyle{0}) + return style +end + +# Used for broadcasting +Base.axes(sc::SparseCoefficients) = (sc.basis_elements,) +#Base.Broadcast.BroadcastStyle(::Type{<:SparseCoefficients{K}}) where {K} = SparseArrays.HigherOrderFns.SparseVecStyle() +#SparseArrays.HigherOrderFns.nonscalararg(::SparseCoefficients) = true + +# `_get_arg` and `getindex` are inspired from `JuMP.Containers.SparseAxisArray` +_getindex(x::SparseCoefficients, index) = getindex(x, index) +_getindex(x::Any, ::Any) = x +_getindex(x::Ref, ::Any) = x[] + +function _get_arg(args::Tuple, index) + return (_getindex(first(args), index), _get_arg(Base.tail(args), index)...) +end +_get_arg(::Tuple{}, _) = () + +function Base.getindex(bc::Broadcast.Broadcasted{<:BroadcastStyle}, index) + return bc.f(_get_arg(bc.args, index)...) +end + +function Base.similar(bc::Broadcast.Broadcasted{<:BroadcastStyle}, ::Type{T}) where {T} + return similar(_first_sparse_coeffs(bc.args...), T) +end + +_first_sparse_coeffs(c::SparseCoefficients, args...) = c +_first_sparse_coeffs(_, args...) = _first_sparse_coeffs(args...) + function Base.zero(sc::SparseCoefficients) return SparseCoefficients(empty(keys(sc)), empty(values(sc))) end @@ -55,7 +113,7 @@ function similar_type(::Type{SparseCoefficients{K,V,Vk,Vv}}, ::Type{T}) where {K end function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T} - return SparseCoefficients(_similar(s.basis_elements), _similar(s.values, T)) + return SparseCoefficients(collect(s.basis_elements), _similar(s.values, T)) end function MA.mutability( diff --git a/test/caching_allocations.jl b/test/caching_allocations.jl index 30a4b4a..862ec22 100644 --- a/test/caching_allocations.jl +++ b/test/caching_allocations.jl @@ -100,4 +100,16 @@ end @test _test_op(+, Z, Y) == _test_op(+, Y, Y) @test _test_op(-, Z, Z) == _test_op(*, 0, Z) @test _test_op(-, Z, Z) == _test_op(-, Y, Z) + + for X in [Y, Z] + c = coeffs(X) + res = 2 .* c + @test c .* 2 == res + @test c .+ 1 == res + @test 1 .+ c == res + @test MA.Zero() .+ c == c + @test c .+ MA.Zero() == c + err = ArgumentError("Cannot broadcast `StarAlgebras.SparseCoefficients` with another array of different type") + @test_throws err c .+ ones(3) + end end