diff --git a/Project.toml b/Project.toml index 24fcf73..e0513bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.2.6" +version = "0.3.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/docs/Project.toml b/docs/Project.toml index 92b001e..228ca1b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" [compat] -DiagonalArrays = "0.2" +DiagonalArrays = "0.3" Documenter = "1" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index df0fc98..b559d07 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -3,5 +3,5 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -DiagonalArrays = "0.2" +DiagonalArrays = "0.3" Test = "1" diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index fcb098e..a7358c8 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -2,90 +2,103 @@ function getzero(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} return zero(eltype(a)) end +function _DiagonalArray end + struct DiagonalArray{T,N,Diag<:AbstractVector{T},F} <: AbstractDiagonalArray{T,N} diag::Diag - dims::NTuple{N,Int} - getunstoredindex::F + dims::Dims{N} + getunstored::F + global @inline function _DiagonalArray( + diag::Diag, dims::Dims{N}, getunstored::F + ) where {T,N,Diag<:AbstractVector{T},F} + all(≥(0), dims) || throw(ArgumentError("Invalid dimensions: $dims")) + length(diag) == minimum(dims) || + throw(ArgumentError("Length of diagonals doesn't match dimensions")) + return new{T,N,Diag,F}(diag, dims, getunstored) + end end function DiagonalArray{T,N}( - diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero + diag::AbstractVector, dims::Dims{N}; getunstored=getzero ) where {T,N} - return DiagonalArray{T,N,typeof(diag),typeof(getunstoredindex)}(diag, d, getunstoredindex) + return _DiagonalArray(convert(AbstractVector{T}, diag), dims, getunstored) end function DiagonalArray{T,N}( - diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero + diag::AbstractVector, dims::Vararg{Int,N}; kwargs... ) where {T,N} - return DiagonalArray{T,N}(T.(diag), d, getunstoredindex) + return DiagonalArray{T,N}(diag, dims; kwargs...) end -function DiagonalArray{T,N}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(diag, d) +function DiagonalArray{T}(diag::AbstractVector, dims::Dims{N}; kwargs...) where {T,N} + return DiagonalArray{T,N}(diag, dims; kwargs...) end -function DiagonalArray{T}( - diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero +function DiagonalArray{T}(diag::AbstractVector, dims::Vararg{Int,N}; kwargs...) where {T,N} + return DiagonalArray{T,N}(diag, dims; kwargs...) +end + +function DiagonalArray{<:Any,N}( + diag::AbstractVector{T}, dims::Dims{N}; kwargs... ) where {T,N} - return DiagonalArray{T,N}(diag, d, getunstoredindex) + return DiagonalArray{T,N}(diag, dims; kwargs...) end -function DiagonalArray{T}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(diag, d) +function DiagonalArray{<:Any,N}( + diag::AbstractVector{T}, dims::Vararg{Int,N}; kwargs... +) where {T,N} + return DiagonalArray{T,N}(diag, dims; kwargs...) end -function DiagonalArray(diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}) where {T,N} - return DiagonalArray{T,N}(diag, d) +function DiagonalArray(diag::AbstractVector{T}, dims::Dims{N}; kwargs...) where {T,N} + return DiagonalArray{T,N}(diag, dims; kwargs...) end -function DiagonalArray(diag::AbstractVector{T}, d::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(diag, d) +function DiagonalArray(diag::AbstractVector{T}, dims::Vararg{Int,N}; kwargs...) where {T,N} + return DiagonalArray{T,N}(diag, dims; kwargs...) end # Infer size from diagonal -function DiagonalArray{T,N}(diag::AbstractVector) where {T,N} - return DiagonalArray{T,N}(diag, ntuple(Returns(length(diag)), N)) +function DiagonalArray{T,N}(diag::AbstractVector; kwargs...) where {T,N} + return DiagonalArray{T,N}(diag, ntuple(Returns(length(diag)), N); kwargs...) end -function DiagonalArray{<:Any,N}(diag::AbstractVector{T}) where {T,N} - return DiagonalArray{T,N}(diag) +function DiagonalArray{<:Any,N}(diag::AbstractVector{T}; kwargs...) where {T,N} + return DiagonalArray{T,N}(diag; kwargs...) end # undef -function DiagonalArray{T,N}( - ::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero -) where {T,N} - return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d, getunstoredindex) +function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}; kwargs...) where {T,N} + return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims; kwargs...) end -function DiagonalArray{T,N}(::UndefInitializer, d::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(undef, d) +function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}; kwargs...) where {T,N} + return DiagonalArray{T,N}(undef, dims; kwargs...) end -function DiagonalArray{T}( - ::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero -) where {T,N} - return DiagonalArray{T,N}(undef, d, getunstoredindex) +function DiagonalArray{T}(::UndefInitializer, dims::Dims{N}; kwargs...) where {T,N} + return DiagonalArray{T,N}(undef, dims; kwargs...) +end + +function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} + return DiagonalArray{T,N}(undef, dims) end # Axes version function DiagonalArray{T}( - ::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}, getunstoredindex=getzero + ::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}; kwargs... ) where {T,N} - @assert all(isone, first.(axes)) - return DiagonalArray{T,N}(undef, length.(axes), getunstoredindex) -end - -function DiagonalArray{T}(::UndefInitializer, d::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(undef, d) + return DiagonalArray{T,N}(undef, length.(axes); kwargs...) end # Minimal `AbstractArray` interface Base.size(a::DiagonalArray) = a.dims function Base.similar(a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}}) - # TODO: Preserve zero element function. - return DiagonalArray{elt}(undef, dims, a.getunstoredindex) + function getzero(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} + return convert(elt, a.getunstored(a, I...)) + end + return DiagonalArray{elt}(undef, dims; getunstored=getzero) end # DiagonalArrays interface. diff --git a/test/Project.toml b/test/Project.toml index 053c41f..54990e4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8.9" -DiagonalArrays = "0.2" +DiagonalArrays = "0.3" FillArrays = "1" LinearAlgebra = "1" SafeTestsets = "0.1"