Skip to content

WIP new @SA macro and comprehension support #636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,112 @@ end
function promote_rule(::Type{<:SArray{S,T,N,L}}, ::Type{<:SArray{S,U,N,L}}) where {S,T,U,N,L}
SArray{S,promote_type(T,U),N,L}
end


macro SA(ex)
if !isa(ex, Expr)
error("Bad input for @SA")
end

if ex.head == :vect # vector
return esc(Expr(:call, SArray{Tuple{length(ex.args)}}, Expr(:tuple, ex.args...)))
elseif ex.head == :ref # typed, vector
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
elseif ex.head == :hcat # 1 x n
s1 = 1
s2 = length(ex.args)
return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, ex.args...)))
elseif ex.head == :typed_hcat # typed, 1 x n
s1 = 1
s2 = length(ex.args) - 1
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
elseif ex.head == :vcat
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m
# Validate
s1 = length(ex.args)
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1)
s2 = minimum(s2s)
if maximum(s2s) != s2
error("Rows must be of matching lengths")
end

exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2]
return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, exprs...)))
else # n x 1
return esc(Expr(:call, SArray{Tuple{length(ex.args), 1}}, Expr(:tuple, ex.args...)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normal vcat returns a Vector rather than an n x 1 matrix.

end
elseif ex.head == :typed_vcat
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m
# Validate
s1 = length(ex.args) - 1
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1)
s2 = minimum(s2s)
if maximum(s2s) != s2
error("Rows must be of matching lengths")
end

exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2]
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, exprs...)))
else # typed, n x 1
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1, 1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
end
elseif isa(ex, Expr) && ex.head == :comprehension
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
error("Expected generator in comprehension, e.g. [f(i,j) for i in 1:3, j in 1:3]")
end

ex = ex.args[1]
n_ranges = length(ex.args) - 1
range_vars = [ex.args[i+1].args[1] for i = 1:n_ranges]
ranges = [ex.args[i+1].args[2] for i = 1:n_ranges]
func = :(($(Expr(:tuple, range_vars...)) -> $(ex.args[1])))

return quote
$(Expr(:call, :sarray_comprehension, func, ranges...))
end
elseif isa(ex, Expr) && ex.head == :typed_comprehension
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
end
T = ex.args[1]
ex = ex.args[2]
n_ranges = length(ex.args) - 1
range_vars = [ex.args[i+1].args[1] for i = 1:n_ranges]
ranges = [ex.args[i+1].args[2] for i = 1:n_ranges]
func = esc(:(($(Expr(:tuple, range_vars...)) -> $(ex.args[1]))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we might be able to remove some of this logic and just pass the generator along to typed_sarray_comprehension. Ie, make this lowering less about the syntax, and more about just lowering to something other than typed_comprehension which is mangled by Base lowering...


return quote
$(Expr(:call, :typed_sarray_comprehension, T, func, ranges...))
end
else
error("Bad input for @SA")
end
end

@inline function typed_sarray_comprehension(::Type{T}, f, range) where {T}
L = length(range)
_typed_sarray_comprehension(SVector{L,T}, f, range)
end

@inline function typed_sarray_comprehension(::Type{T}, f, range1, range2) where {T}
S1 = length(range1)
S2 = length(range2)
L = S1*S2
_typed_sarray_comprehension(SArray{Tuple{S1,S2},T,2,L}, f, range1,range2)
end

@generated function _typed_sarray_comprehension(::Type{SArray{S,T,N,L}}, f, range) where {S, T, N, L}
exprs = [:(f(range[$i])) for i in 1:L]
return quote
@_inline_meta
$(Expr(:call, SArray{S, T, N, L}, Expr(:tuple, exprs...)))
end
end

@generated function _typed_sarray_comprehension(::Type{SArray{Tuple{S1,S2},T,N,L}}, f, range1, range2) where {S1, S2, T, N, L}
exprs = [:(f(range1[$i1], range2[$i2])) for i1 in 1:S1, i2 in 1:S2]
return quote
@_inline_meta
$(Expr(:call, SArray{Tuple{S1,S2}, T, N, L}, Expr(:tuple, exprs...)))
end
end
2 changes: 2 additions & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export SHermitianCompact

export Size, Length

export SA
export @SVector, @SMatrix, @SArray
export @MVector, @MMatrix, @MArray

Expand Down Expand Up @@ -113,6 +114,7 @@ include("SizedArray.jl")
include("SDiagonal.jl")
include("SHermitianCompact.jl")

include("initializers.jl")
include("convert.jl")

include("abstractarray.jl")
Expand Down
38 changes: 38 additions & 0 deletions src/initializers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
SA[ array initializer ]

A type for initializing static array literals using array construction syntax.
Returns an `SVector` or `SMatrix`.

# Examples:

* `SA[x, y]` creates a length-2 SVector
* `SA[a b; c d]` creates a 2×2 SMatrix
* `SA[a b]` creates a 1×2 SMatrix
"""
struct SA ; end

Base.getindex(::Type{SA}, xs...) = SVector(xs)
Base.typed_vcat(::Type{SA}, xs::Number...) = SVector(xs)
Base.typed_hcat(::Type{SA}, xs::Number...) = SMatrix{1,length(xs)}(xs)

Base.@pure function _SA_hvcat_transposed_type(rows)
M = rows[1]
if any(r->r != M, rows)
# @pure may not throw... probably. See
# https://discourse.julialang.org/t/can-pure-functions-throw-an-error/18459
return nothing
end
SMatrix{M,length(rows)}
end

@inline function Base.typed_hvcat(::Type{SA}, rows::Dims, xs::Number...)
mtype = _SA_hvcat_transposed_type(rows)
if mtype === nothing
throw(ArgumentError("SA[...] matrix rows of length $rows are inconsistent"))
end
# hvcat lowering is row major ordering, so must transpose
transpose(mtype(xs))
end


24 changes: 24 additions & 0 deletions test/initializers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
SA_test_ref(x) = SA[1,x,x]
@test @inferred(SA_test_ref(2)) === SVector{3,Int}((1,2,2))
@test @inferred(SA_test_ref(2.0)) === SVector{3,Float64}((1,2,2))

SA_test_vcat(x) = SA[1;x;x]
@test @inferred(SA_test_vcat(2)) === SVector{3,Int}((1,2,2))
@test @inferred(SA_test_vcat(2.0)) === SVector{3,Float64}((1,2,2))

SA_test_hcat(x) = SA[1 x x]
@test @inferred(SA_test_hcat(2)) === SMatrix{1,3,Int}((1,2,2))
@test @inferred(SA_test_hcat(2.0)) === SMatrix{1,3,Float64}((1,2,2))

SA_test_hvcat(x) = SA[1 x x;
x 2 x]
@test @inferred(SA_test_hvcat(3)) === SMatrix{2,3,Int}((1,3,3,2,3,3))
@test @inferred(SA_test_hvcat(3.0)) === SMatrix{2,3,Float64}((1,3,3,2,3,3))
@test @inferred(SA_test_hvcat(1.0im)) === SMatrix{2,3,ComplexF64}((1,1im,1im,2,1im,1im))

@test SA[1] === SVector{1,Int}((1))

@test_throws ArgumentError("SA[...] matrix rows of length (3, 2) are inconsistent") SA[1 2 3;
4 5]
@test_throws ArgumentError("SA[...] matrix rows of length (2, 3) are inconsistent") SA[1 2;
3 4 5]
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ include("convert.jl")
include("core.jl")
include("abstractarray.jl")
include("indexing.jl")
include("initializers.jl")
Random.seed!(42); include("mapreduce.jl")
Random.seed!(42); include("arraymath.jl")
include("broadcast.jl")
Expand Down