@@ -2,90 +2,103 @@ function getzero(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
2
2
return zero (eltype (a))
3
3
end
4
4
5
+ function _DiagonalArray end
6
+
5
7
struct DiagonalArray{T,N,Diag<: AbstractVector{T} ,F} <: AbstractDiagonalArray{T,N}
6
8
diag:: Diag
7
- dims:: NTuple{N,Int}
8
- getunstoredindex:: F
9
+ dims:: Dims{N}
10
+ getunstored:: F
11
+ global @inline function _DiagonalArray (
12
+ diag:: Diag , dims:: Dims{N} , getunstored:: F
13
+ ) where {T,N,Diag<: AbstractVector{T} ,F}
14
+ all (≥ (0 ), dims) || throw (ArgumentError (" Invalid dimensions: $dims " ))
15
+ length (diag) == minimum (dims) ||
16
+ throw (ArgumentError (" Length of diagonals doesn't match dimensions" ))
17
+ return new {T,N,Diag,F} (diag, dims, getunstored)
18
+ end
9
19
end
10
20
11
21
function DiagonalArray {T,N} (
12
- diag:: AbstractVector{T} , d :: Tuple{Vararg{Int,N}} , getunstoredindex = getzero
22
+ diag:: AbstractVector , dims :: Dims{N} ; getunstored = getzero
13
23
) where {T,N}
14
- return DiagonalArray {T,N,typeof( diag),typeof(getunstoredindex)} (diag, d, getunstoredindex )
24
+ return _DiagonalArray ( convert (AbstractVector{T}, diag), dims, getunstored )
15
25
end
16
26
17
27
function DiagonalArray {T,N} (
18
- diag:: AbstractVector , d :: Tuple{ Vararg{Int,N}} , getunstoredindex = getzero
28
+ diag:: AbstractVector , dims :: Vararg{Int,N} ; kwargs ...
19
29
) where {T,N}
20
- return DiagonalArray {T,N} (T .( diag), d, getunstoredindex )
30
+ return DiagonalArray {T,N} (diag, dims; kwargs ... )
21
31
end
22
32
23
- function DiagonalArray {T,N } (diag:: AbstractVector , d :: Vararg{Int,N} ) where {T,N}
24
- return DiagonalArray {T,N} (diag, d )
33
+ function DiagonalArray {T} (diag:: AbstractVector , dims :: Dims{N} ; kwargs ... ) where {T,N}
34
+ return DiagonalArray {T,N} (diag, dims; kwargs ... )
25
35
end
26
36
27
- function DiagonalArray {T} (
28
- diag:: AbstractVector , d:: Tuple{Vararg{Int,N}} , getunstoredindex= getzero
37
+ function DiagonalArray {T} (diag:: AbstractVector , dims:: Vararg{Int,N} ; kwargs... ) where {T,N}
38
+ return DiagonalArray {T,N} (diag, dims; kwargs... )
39
+ end
40
+
41
+ function DiagonalArray {<:Any,N} (
42
+ diag:: AbstractVector{T} , dims:: Dims{N} ; kwargs...
29
43
) where {T,N}
30
- return DiagonalArray {T,N} (diag, d, getunstoredindex )
44
+ return DiagonalArray {T,N} (diag, dims; kwargs ... )
31
45
end
32
46
33
- function DiagonalArray {T} (diag:: AbstractVector , d:: Vararg{Int,N} ) where {T,N}
34
- return DiagonalArray {T,N} (diag, d)
47
+ function DiagonalArray {<:Any,N} (
48
+ diag:: AbstractVector{T} , dims:: Vararg{Int,N} ; kwargs...
49
+ ) where {T,N}
50
+ return DiagonalArray {T,N} (diag, dims; kwargs... )
35
51
end
36
52
37
- function DiagonalArray (diag:: AbstractVector{T} , d :: Tuple{Vararg{Int,N}} ) where {T,N}
38
- return DiagonalArray {T,N} (diag, d )
53
+ function DiagonalArray (diag:: AbstractVector{T} , dims :: Dims{N} ; kwargs ... ) where {T,N}
54
+ return DiagonalArray {T,N} (diag, dims; kwargs ... )
39
55
end
40
56
41
- function DiagonalArray (diag:: AbstractVector{T} , d :: Vararg{Int,N} ) where {T,N}
42
- return DiagonalArray {T,N} (diag, d )
57
+ function DiagonalArray (diag:: AbstractVector{T} , dims :: Vararg{Int,N} ; kwargs ... ) where {T,N}
58
+ return DiagonalArray {T,N} (diag, dims; kwargs ... )
43
59
end
44
60
45
61
# Infer size from diagonal
46
- function DiagonalArray {T,N} (diag:: AbstractVector ) where {T,N}
47
- return DiagonalArray {T,N} (diag, ntuple (Returns (length (diag)), N))
62
+ function DiagonalArray {T,N} (diag:: AbstractVector ; kwargs ... ) where {T,N}
63
+ return DiagonalArray {T,N} (diag, ntuple (Returns (length (diag)), N); kwargs ... )
48
64
end
49
65
50
- function DiagonalArray {<:Any,N} (diag:: AbstractVector{T} ) where {T,N}
51
- return DiagonalArray {T,N} (diag)
66
+ function DiagonalArray {<:Any,N} (diag:: AbstractVector{T} ; kwargs ... ) where {T,N}
67
+ return DiagonalArray {T,N} (diag; kwargs ... )
52
68
end
53
69
54
70
# undef
55
- function DiagonalArray {T,N} (
56
- :: UndefInitializer , d:: Tuple{Vararg{Int,N}} , getunstoredindex= getzero
57
- ) where {T,N}
58
- return DiagonalArray {T,N} (Vector {T} (undef, minimum (d)), d, getunstoredindex)
71
+ function DiagonalArray {T,N} (:: UndefInitializer , dims:: Dims{N} ; kwargs... ) where {T,N}
72
+ return DiagonalArray {T,N} (Vector {T} (undef, minimum (dims)), dims; kwargs... )
59
73
end
60
74
61
- function DiagonalArray {T,N} (:: UndefInitializer , d :: Vararg{Int,N} ) where {T,N}
62
- return DiagonalArray {T,N} (undef, d )
75
+ function DiagonalArray {T,N} (:: UndefInitializer , dims :: Vararg{Int,N} ; kwargs ... ) where {T,N}
76
+ return DiagonalArray {T,N} (undef, dims; kwargs ... )
63
77
end
64
78
65
- function DiagonalArray {T} (
66
- :: UndefInitializer , d:: Tuple{Vararg{Int,N}} , getunstoredindex= getzero
67
- ) where {T,N}
68
- return DiagonalArray {T,N} (undef, d, getunstoredindex)
79
+ function DiagonalArray {T} (:: UndefInitializer , dims:: Dims{N} ; kwargs... ) where {T,N}
80
+ return DiagonalArray {T,N} (undef, dims; kwargs... )
81
+ end
82
+
83
+ function DiagonalArray {T} (:: UndefInitializer , dims:: Vararg{Int,N} ) where {T,N}
84
+ return DiagonalArray {T,N} (undef, dims)
69
85
end
70
86
71
87
# Axes version
72
88
function DiagonalArray {T} (
73
- :: UndefInitializer , axes:: Tuple{Vararg{AbstractUnitRange,N}} , getunstoredindex = getzero
89
+ :: UndefInitializer , axes:: NTuple{N,Base.OneTo{Int}} ; kwargs ...
74
90
) where {T,N}
75
- @assert all (isone, first .(axes))
76
- return DiagonalArray {T,N} (undef, length .(axes), getunstoredindex)
77
- end
78
-
79
- function DiagonalArray {T} (:: UndefInitializer , d:: Vararg{Int,N} ) where {T,N}
80
- return DiagonalArray {T,N} (undef, d)
91
+ return DiagonalArray {T,N} (undef, length .(axes); kwargs... )
81
92
end
82
93
83
94
# Minimal `AbstractArray` interface
84
95
Base. size (a:: DiagonalArray ) = a. dims
85
96
86
97
function Base. similar (a:: DiagonalArray , elt:: Type , dims:: Tuple{Vararg{Int}} )
87
- # TODO : Preserve zero element function.
88
- return DiagonalArray {elt} (undef, dims, a. getunstoredindex)
98
+ function getzero (a:: AbstractArray{<:Any,N} , I:: Vararg{Int,N} ) where {N}
99
+ return convert (elt, a. getunstored (a, I... ))
100
+ end
101
+ return DiagonalArray {elt} (undef, dims; getunstored= getzero)
89
102
end
90
103
91
104
# DiagonalArrays interface.
0 commit comments