diff --git a/src/SArray.jl b/src/SArray.jl index 767dad8b..6c6c8c96 100644 --- a/src/SArray.jl +++ b/src/SArray.jl @@ -237,3 +237,112 @@ end function promote_rule(::Type{<:SArray{S,T,N,L}}, ::Type{<:SArray{S,U,N,L}}) where {S,T,U,N,L} SArray{S,promote_type(T,U),N,L} end + + +macro SA(ex) + if !isa(ex, Expr) + error("Bad input for @SA") + end + + if ex.head == :vect # vector + return esc(Expr(:call, SArray{Tuple{length(ex.args)}}, Expr(:tuple, ex.args...))) + elseif ex.head == :ref # typed, vector + return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) + elseif ex.head == :hcat # 1 x n + s1 = 1 + s2 = length(ex.args) + return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, ex.args...))) + elseif ex.head == :typed_hcat # typed, 1 x n + s1 = 1 + s2 = length(ex.args) - 1 + return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) + elseif ex.head == :vcat + if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m + # Validate + s1 = length(ex.args) + s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1) + s2 = minimum(s2s) + if maximum(s2s) != s2 + error("Rows must be of matching lengths") + end + + exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2] + return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, exprs...))) + else # n x 1 + return esc(Expr(:call, SArray{Tuple{length(ex.args), 1}}, Expr(:tuple, ex.args...))) + end + elseif ex.head == :typed_vcat + if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m + # Validate + s1 = length(ex.args) - 1 + s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1) + s2 = minimum(s2s) + if maximum(s2s) != s2 + error("Rows must be of matching lengths") + end + + exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2] + return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, exprs...))) + else # typed, n x 1 + return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1, 1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) + end + elseif isa(ex, Expr) && ex.head == :comprehension + if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator + error("Expected generator in comprehension, e.g. [f(i,j) for i in 1:3, j in 1:3]") + end + + ex = ex.args[1] + n_ranges = length(ex.args) - 1 + range_vars = [ex.args[i+1].args[1] for i = 1:n_ranges] + ranges = [ex.args[i+1].args[2] for i = 1:n_ranges] + func = :(($(Expr(:tuple, range_vars...)) -> $(ex.args[1]))) + + return quote + $(Expr(:call, :sarray_comprehension, func, ranges...)) + end + elseif isa(ex, Expr) && ex.head == :typed_comprehension + if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator + error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]") + end + T = ex.args[1] + ex = ex.args[2] + n_ranges = length(ex.args) - 1 + range_vars = [ex.args[i+1].args[1] for i = 1:n_ranges] + ranges = [ex.args[i+1].args[2] for i = 1:n_ranges] + func = esc(:(($(Expr(:tuple, range_vars...)) -> $(ex.args[1])))) + + return quote + $(Expr(:call, :typed_sarray_comprehension, T, func, ranges...)) + end + else + error("Bad input for @SA") + end +end + +@inline function typed_sarray_comprehension(::Type{T}, f, range) where {T} + L = length(range) + _typed_sarray_comprehension(SVector{L,T}, f, range) +end + +@inline function typed_sarray_comprehension(::Type{T}, f, range1, range2) where {T} + S1 = length(range1) + S2 = length(range2) + L = S1*S2 + _typed_sarray_comprehension(SArray{Tuple{S1,S2},T,2,L}, f, range1,range2) +end + +@generated function _typed_sarray_comprehension(::Type{SArray{S,T,N,L}}, f, range) where {S, T, N, L} + exprs = [:(f(range[$i])) for i in 1:L] + return quote + @_inline_meta + $(Expr(:call, SArray{S, T, N, L}, Expr(:tuple, exprs...))) + end +end + +@generated function _typed_sarray_comprehension(::Type{SArray{Tuple{S1,S2},T,N,L}}, f, range1, range2) where {S1, S2, T, N, L} + exprs = [:(f(range1[$i1], range2[$i2])) for i1 in 1:S1, i2 in 1:S2] + return quote + @_inline_meta + $(Expr(:call, SArray{Tuple{S1,S2}, T, N, L}, Expr(:tuple, exprs...))) + end +end \ No newline at end of file diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 07c47438..54d97807 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -36,6 +36,7 @@ export SHermitianCompact export Size, Length +export SA export @SVector, @SMatrix, @SArray export @MVector, @MMatrix, @MArray @@ -113,6 +114,7 @@ include("SizedArray.jl") include("SDiagonal.jl") include("SHermitianCompact.jl") +include("initializers.jl") include("convert.jl") include("abstractarray.jl") diff --git a/src/initializers.jl b/src/initializers.jl new file mode 100644 index 00000000..81abf879 --- /dev/null +++ b/src/initializers.jl @@ -0,0 +1,38 @@ +""" + SA[ array initializer ] + +A type for initializing static array literals using array construction syntax. +Returns an `SVector` or `SMatrix`. + +# Examples: + +* `SA[x, y]` creates a length-2 SVector +* `SA[a b; c d]` creates a 2×2 SMatrix +* `SA[a b]` creates a 1×2 SMatrix +""" +struct SA ; end + +Base.getindex(::Type{SA}, xs...) = SVector(xs) +Base.typed_vcat(::Type{SA}, xs::Number...) = SVector(xs) +Base.typed_hcat(::Type{SA}, xs::Number...) = SMatrix{1,length(xs)}(xs) + +Base.@pure function _SA_hvcat_transposed_type(rows) + M = rows[1] + if any(r->r != M, rows) + # @pure may not throw... probably. See + # https://discourse.julialang.org/t/can-pure-functions-throw-an-error/18459 + return nothing + end + SMatrix{M,length(rows)} +end + +@inline function Base.typed_hvcat(::Type{SA}, rows::Dims, xs::Number...) + mtype = _SA_hvcat_transposed_type(rows) + if mtype === nothing + throw(ArgumentError("SA[...] matrix rows of length $rows are inconsistent")) + end + # hvcat lowering is row major ordering, so must transpose + transpose(mtype(xs)) +end + + diff --git a/test/initializers.jl b/test/initializers.jl new file mode 100644 index 00000000..20643432 --- /dev/null +++ b/test/initializers.jl @@ -0,0 +1,24 @@ +SA_test_ref(x) = SA[1,x,x] +@test @inferred(SA_test_ref(2)) === SVector{3,Int}((1,2,2)) +@test @inferred(SA_test_ref(2.0)) === SVector{3,Float64}((1,2,2)) + +SA_test_vcat(x) = SA[1;x;x] +@test @inferred(SA_test_vcat(2)) === SVector{3,Int}((1,2,2)) +@test @inferred(SA_test_vcat(2.0)) === SVector{3,Float64}((1,2,2)) + +SA_test_hcat(x) = SA[1 x x] +@test @inferred(SA_test_hcat(2)) === SMatrix{1,3,Int}((1,2,2)) +@test @inferred(SA_test_hcat(2.0)) === SMatrix{1,3,Float64}((1,2,2)) + +SA_test_hvcat(x) = SA[1 x x; + x 2 x] +@test @inferred(SA_test_hvcat(3)) === SMatrix{2,3,Int}((1,3,3,2,3,3)) +@test @inferred(SA_test_hvcat(3.0)) === SMatrix{2,3,Float64}((1,3,3,2,3,3)) +@test @inferred(SA_test_hvcat(1.0im)) === SMatrix{2,3,ComplexF64}((1,1im,1im,2,1im,1im)) + +@test SA[1] === SVector{1,Int}((1)) + +@test_throws ArgumentError("SA[...] matrix rows of length (3, 2) are inconsistent") SA[1 2 3; + 4 5] +@test_throws ArgumentError("SA[...] matrix rows of length (2, 3) are inconsistent") SA[1 2; + 3 4 5] diff --git a/test/runtests.jl b/test/runtests.jl index 279cb7ad..339ee1f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ include("convert.jl") include("core.jl") include("abstractarray.jl") include("indexing.jl") +include("initializers.jl") Random.seed!(42); include("mapreduce.jl") Random.seed!(42); include("arraymath.jl") include("broadcast.jl")