Skip to content

Commit 81dd30b

Browse files
authored
Update constructors (#17)
1 parent ccd35c2 commit 81dd30b

File tree

5 files changed

+57
-44
lines changed

5 files changed

+57
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.6"
4+
version = "0.3.0"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55

66
[compat]
7-
DiagonalArrays = "0.2"
7+
DiagonalArrays = "0.3"
88
Documenter = "1"
99
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
44

55
[compat]
6-
DiagonalArrays = "0.2"
6+
DiagonalArrays = "0.3"
77
Test = "1"

src/diagonalarray/diagonalarray.jl

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,90 +2,103 @@ function getzero(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
22
return zero(eltype(a))
33
end
44

5+
function _DiagonalArray end
6+
57
struct DiagonalArray{T,N,Diag<:AbstractVector{T},F} <: AbstractDiagonalArray{T,N}
68
diag::Diag
7-
dims::NTuple{N,Int}
8-
getunstoredindex::F
9+
dims::Dims{N}
10+
getunstored::F
11+
global @inline function _DiagonalArray(
12+
diag::Diag, dims::Dims{N}, getunstored::F
13+
) where {T,N,Diag<:AbstractVector{T},F}
14+
all((0), dims) || throw(ArgumentError("Invalid dimensions: $dims"))
15+
length(diag) == minimum(dims) ||
16+
throw(ArgumentError("Length of diagonals doesn't match dimensions"))
17+
return new{T,N,Diag,F}(diag, dims, getunstored)
18+
end
919
end
1020

1121
function DiagonalArray{T,N}(
12-
diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero
22+
diag::AbstractVector, dims::Dims{N}; getunstored=getzero
1323
) where {T,N}
14-
return DiagonalArray{T,N,typeof(diag),typeof(getunstoredindex)}(diag, d, getunstoredindex)
24+
return _DiagonalArray(convert(AbstractVector{T}, diag), dims, getunstored)
1525
end
1626

1727
function DiagonalArray{T,N}(
18-
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero
28+
diag::AbstractVector, dims::Vararg{Int,N}; kwargs...
1929
) where {T,N}
20-
return DiagonalArray{T,N}(T.(diag), d, getunstoredindex)
30+
return DiagonalArray{T,N}(diag, dims; kwargs...)
2131
end
2232

23-
function DiagonalArray{T,N}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
24-
return DiagonalArray{T,N}(diag, d)
33+
function DiagonalArray{T}(diag::AbstractVector, dims::Dims{N}; kwargs...) where {T,N}
34+
return DiagonalArray{T,N}(diag, dims; kwargs...)
2535
end
2636

27-
function DiagonalArray{T}(
28-
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero
37+
function DiagonalArray{T}(diag::AbstractVector, dims::Vararg{Int,N}; kwargs...) where {T,N}
38+
return DiagonalArray{T,N}(diag, dims; kwargs...)
39+
end
40+
41+
function DiagonalArray{<:Any,N}(
42+
diag::AbstractVector{T}, dims::Dims{N}; kwargs...
2943
) where {T,N}
30-
return DiagonalArray{T,N}(diag, d, getunstoredindex)
44+
return DiagonalArray{T,N}(diag, dims; kwargs...)
3145
end
3246

33-
function DiagonalArray{T}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
34-
return DiagonalArray{T,N}(diag, d)
47+
function DiagonalArray{<:Any,N}(
48+
diag::AbstractVector{T}, dims::Vararg{Int,N}; kwargs...
49+
) where {T,N}
50+
return DiagonalArray{T,N}(diag, dims; kwargs...)
3551
end
3652

37-
function DiagonalArray(diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}) where {T,N}
38-
return DiagonalArray{T,N}(diag, d)
53+
function DiagonalArray(diag::AbstractVector{T}, dims::Dims{N}; kwargs...) where {T,N}
54+
return DiagonalArray{T,N}(diag, dims; kwargs...)
3955
end
4056

41-
function DiagonalArray(diag::AbstractVector{T}, d::Vararg{Int,N}) where {T,N}
42-
return DiagonalArray{T,N}(diag, d)
57+
function DiagonalArray(diag::AbstractVector{T}, dims::Vararg{Int,N}; kwargs...) where {T,N}
58+
return DiagonalArray{T,N}(diag, dims; kwargs...)
4359
end
4460

4561
# Infer size from diagonal
46-
function DiagonalArray{T,N}(diag::AbstractVector) where {T,N}
47-
return DiagonalArray{T,N}(diag, ntuple(Returns(length(diag)), N))
62+
function DiagonalArray{T,N}(diag::AbstractVector; kwargs...) where {T,N}
63+
return DiagonalArray{T,N}(diag, ntuple(Returns(length(diag)), N); kwargs...)
4864
end
4965

50-
function DiagonalArray{<:Any,N}(diag::AbstractVector{T}) where {T,N}
51-
return DiagonalArray{T,N}(diag)
66+
function DiagonalArray{<:Any,N}(diag::AbstractVector{T}; kwargs...) where {T,N}
67+
return DiagonalArray{T,N}(diag; kwargs...)
5268
end
5369

5470
# undef
55-
function DiagonalArray{T,N}(
56-
::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero
57-
) where {T,N}
58-
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d, getunstoredindex)
71+
function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}; kwargs...) where {T,N}
72+
return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims; kwargs...)
5973
end
6074

61-
function DiagonalArray{T,N}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
62-
return DiagonalArray{T,N}(undef, d)
75+
function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}; kwargs...) where {T,N}
76+
return DiagonalArray{T,N}(undef, dims; kwargs...)
6377
end
6478

65-
function DiagonalArray{T}(
66-
::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=getzero
67-
) where {T,N}
68-
return DiagonalArray{T,N}(undef, d, getunstoredindex)
79+
function DiagonalArray{T}(::UndefInitializer, dims::Dims{N}; kwargs...) where {T,N}
80+
return DiagonalArray{T,N}(undef, dims; kwargs...)
81+
end
82+
83+
function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
84+
return DiagonalArray{T,N}(undef, dims)
6985
end
7086

7187
# Axes version
7288
function DiagonalArray{T}(
73-
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}, getunstoredindex=getzero
89+
::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}; kwargs...
7490
) where {T,N}
75-
@assert all(isone, first.(axes))
76-
return DiagonalArray{T,N}(undef, length.(axes), getunstoredindex)
77-
end
78-
79-
function DiagonalArray{T}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
80-
return DiagonalArray{T,N}(undef, d)
91+
return DiagonalArray{T,N}(undef, length.(axes); kwargs...)
8192
end
8293

8394
# Minimal `AbstractArray` interface
8495
Base.size(a::DiagonalArray) = a.dims
8596

8697
function Base.similar(a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}})
87-
# TODO: Preserve zero element function.
88-
return DiagonalArray{elt}(undef, dims, a.getunstoredindex)
98+
function getzero(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
99+
return convert(elt, a.getunstored(a, I...))
100+
end
101+
return DiagonalArray{elt}(undef, dims; getunstored=getzero)
89102
end
90103

91104
# DiagonalArrays interface.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1010

1111
[compat]
1212
Aqua = "0.8.9"
13-
DiagonalArrays = "0.2"
13+
DiagonalArrays = "0.3"
1414
FillArrays = "1"
1515
LinearAlgebra = "1"
1616
SafeTestsets = "0.1"

0 commit comments

Comments
 (0)