-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathArrayInterfaceStaticArraysCoreExt.jl
44 lines (35 loc) · 1.51 KB
/
ArrayInterfaceStaticArraysCoreExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
module ArrayInterfaceStaticArraysCoreExt

# When `Base.get_extension` exists (package-extension loading), import the dependencies
# directly; otherwise resolve them relative to the parent module (the older loading
# path, presumably via Requires.jl — confirm).
if isdefined(Base, :get_extension)
    import ArrayInterface
    using LinearAlgebra
    import StaticArraysCore
else
    import ..ArrayInterface
    using ..LinearAlgebra
    import ..StaticArraysCore
end

# Allocate an uninitialized `L × L` mutable static matrix, where `L` is the total number
# of elements of the input `MArray`.
function ArrayInterface.undefmatrix(
        ::StaticArraysCore.MArray{S, T, N, L}) where {S, T, N, L}
    return StaticArraysCore.MMatrix{L, L, T, L * L}(undef)
end

# SArray doesn't have an undef constructor and is going to be small enough that this is
# fine: build the `L × L` matrix as the outer product of the flattened input. The
# entries are only placeholders; callers should treat them as uninitialized.
function ArrayInterface.undefmatrix(s::StaticArraysCore.SArray)
    v = vec(s)
    return v .* v'
end

# Static arrays are immutable by default; `MArray` and `SizedArray` override this via
# the more specific methods below.
ArrayInterface.ismutable(::Type{<:StaticArraysCore.StaticArray}) = false
ArrayInterface.ismutable(::Type{<:StaticArraysCore.MArray}) = true
ArrayInterface.ismutable(::Type{<:StaticArraysCore.SizedArray}) = true

# NOTE(review): this also reports `false` for `MArray`/`SizedArray`, which are marked
# mutable above; confirm that excluding them from `setindex!`-based paths is intended.
ArrayInterface.can_setindex(::Type{<:StaticArraysCore.StaticArray}) = false

# Return the `:data` field backing an `SArray`/`MArray`.
function ArrayInterface.buffer(A::Union{StaticArraysCore.SArray, StaticArraysCore.MArray})
    return getfield(A, :data)
end

# Return an LU factorization of the identity matrix of the same static type as `_A`,
# giving a representative instance of the factorization result type without
# factorizing `_A` itself.
function ArrayInterface.lu_instance(_A::StaticArraysCore.StaticMatrix{N, N}) where {N}
    return lu(one(_A))
end

# Rebuild `y` as an `SArray` with the same static size `S` as `x`.
function ArrayInterface.restructure(x::StaticArraysCore.SArray{S}, y) where {S}
    return StaticArraysCore.SArray{S}(y)
end

# Return the statically known size of a `StaticArray` type as a tuple of `Int`s. If the
# size parameter `S` is not bound for the given type, return one `nothing` per
# dimension instead. The argument is bound to `T` so that fallback can query `ndims(T)`.
function ArrayInterface.known_size(T::Type{<:StaticArraysCore.StaticArray{S}}) where {S}
    return @isdefined(S) ? tuple(S.parameters...) : ntuple(_ -> nothing, ndims(T))
end

# Return the statically known number of elements of a `StaticArray` type, or `nothing`
# if any dimension's extent is unknown.
function ArrayInterface.known_length(T::Type{<:StaticArraysCore.StaticArray})
    sz = ArrayInterface.known_size(T)
    # Check each entry rather than `sz isa Tuple{Vararg{Nothing}}`: that test is also
    # true for `()`, which would misreport a 0-dimensional array (length 1) as unknown,
    # and it would let a partially known size such as `(2, nothing)` reach `prod`.
    return any(isnothing, sz) ? nothing : prod(sz)
end

end