diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 07c47438..54d97807 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -36,6 +36,7 @@ export SHermitianCompact export Size, Length +export SA export @SVector, @SMatrix, @SArray export @MVector, @MMatrix, @MArray @@ -113,6 +114,7 @@ include("SizedArray.jl") include("SDiagonal.jl") include("SHermitianCompact.jl") +include("initializers.jl") include("convert.jl") include("abstractarray.jl") diff --git a/src/initializers.jl b/src/initializers.jl new file mode 100644 index 00000000..2eb471c1 --- /dev/null +++ b/src/initializers.jl @@ -0,0 +1,46 @@ +""" + SA[ elements ] + SA{T}[ elements ] + +Create `SArray` literals using array construction syntax. The element type is +inferred by promoting `elements` to a common type or set to `T` when `T` is +provided explicitly. + +# Examples: + +* `SA[1.0, 2.0]` creates a length-2 `SVector` of `Float64` elements. +* `SA[1 2; 3 4]` creates a 2×2 SMatrix of `Int`s. +* `SA[1 2]` creates a 1×2 SMatrix of `Int`s. +* `SA{Float32}[1, 2]` creates a length-2 `SVector` of `Float32` elements. +""" +struct SA{T} ; end + +@inline similar_type(::Type{SA}, ::Size{S}) where {S} = SArray{Tuple{S...}} +@inline similar_type(::Type{SA{T}}, ::Size{S}) where {T,S} = SArray{Tuple{S...}, T} + +Base.@pure _SA_type(sa::Type{SA}, len::Int) = SVector{len} +Base.@pure _SA_type(sa::Type{SA{T}}, len::Int) where {T} = SVector{len,T} + +@inline Base.getindex(sa::Type{<:SA}, xs...) where T = similar_type(sa, Size(length(xs)))(xs) +@inline Base.typed_vcat(sa::Type{<:SA}, xs::Number...) where T = similar_type(sa, Size(length(xs)))(xs) +@inline Base.typed_hcat(sa::Type{<:SA}, xs::Number...) where T = similar_type(sa, Size(1,length(xs)))(xs) + +Base.@pure function _SA_hvcat_transposed_size(rows) + M = rows[1] + if any(r->r != M, rows) + # @pure may not throw... probably. See + # https://discourse.julialang.org/t/can-pure-functions-throw-an-error/18459 + return nothing + end + Size(M, length(rows)) +end + +@inline function Base.typed_hvcat(sa::Type{<:SA}, rows::Dims, xs::Number...) where T + msize = _SA_hvcat_transposed_size(rows) + if msize === nothing + throw(ArgumentError("SA[...] matrix rows of length $rows are inconsistent")) + end + # hvcat lowering is row major ordering, so we must transpose + transpose(similar_type(sa, msize)(xs)) +end + diff --git a/test/initializers.jl b/test/initializers.jl new file mode 100644 index 00000000..4d37ce64 --- /dev/null +++ b/test/initializers.jl @@ -0,0 +1,33 @@ +SA_test_ref(x) = SA[1,x,x] +SA_test_ref(x,T) = SA{T}[1,x,x] +@test @inferred(SA_test_ref(2)) === SVector{3,Int}((1,2,2)) +@test @inferred(SA_test_ref(2.0)) === SVector{3,Float64}((1,2,2)) +@test @inferred(SA_test_ref(2,Float32)) === SVector{3,Float32}((1,2,2)) + +SA_test_vcat(x) = SA[1;x;x] +SA_test_vcat(x,T) = SA{T}[1;x;x] +@test @inferred(SA_test_vcat(2)) === SVector{3,Int}((1,2,2)) +@test @inferred(SA_test_vcat(2.0)) === SVector{3,Float64}((1,2,2)) +@test @inferred(SA_test_vcat(2,Float32)) === SVector{3,Float32}((1,2,2)) + +SA_test_hcat(x) = SA[1 x x] +SA_test_hcat(x,T) = SA{T}[1 x x] +@test @inferred(SA_test_hcat(2)) === SMatrix{1,3,Int}((1,2,2)) +@test @inferred(SA_test_hcat(2.0)) === SMatrix{1,3,Float64}((1,2,2)) +@test @inferred(SA_test_hcat(2,Float32)) === SMatrix{1,3,Float32}((1,2,2)) + +SA_test_hvcat(x) = SA[1 x x; + x 2 x] +SA_test_hvcat(x,T) = SA{T}[1 x x; + x 2 x] +@test @inferred(SA_test_hvcat(3)) === SMatrix{2,3,Int}((1,3,3,2,3,3)) +@test @inferred(SA_test_hvcat(3.0)) === SMatrix{2,3,Float64}((1,3,3,2,3,3)) +@test @inferred(SA_test_hvcat(1.0im)) === SMatrix{2,3,ComplexF64}((1,1im,1im,2,1im,1im)) +@test @inferred(SA_test_hvcat(3,Float32)) === SMatrix{2,3,Float32}((1,3,3,2,3,3)) + +@test SA[1] === SVector{1,Int}((1)) + +@test_throws ArgumentError("SA[...] matrix rows of length (3, 2) are inconsistent") SA[1 2 3; + 4 5] +@test_throws ArgumentError("SA[...] matrix rows of length (2, 3) are inconsistent") SA[1 2; + 3 4 5] diff --git a/test/runtests.jl b/test/runtests.jl index 279cb7ad..339ee1f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ include("convert.jl") include("core.jl") include("abstractarray.jl") include("indexing.jl") +include("initializers.jl") Random.seed!(42); include("mapreduce.jl") Random.seed!(42); include("arraymath.jl") include("broadcast.jl")