-
Notifications
You must be signed in to change notification settings - Fork 152
WIP new @SA
macro and comprehension support
#636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -237,3 +237,112 @@ end | |
function promote_rule(::Type{<:SArray{S,T,N,L}}, ::Type{<:SArray{S,U,N,L}}) where {S,T,U,N,L} | ||
SArray{S,promote_type(T,U),N,L} | ||
end | ||
|
||
|
||
macro SA(ex) | ||
if !isa(ex, Expr) | ||
error("Bad input for @SA") | ||
end | ||
|
||
if ex.head == :vect # vector | ||
return esc(Expr(:call, SArray{Tuple{length(ex.args)}}, Expr(:tuple, ex.args...))) | ||
elseif ex.head == :ref # typed, vector | ||
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | ||
elseif ex.head == :hcat # 1 x n | ||
s1 = 1 | ||
s2 = length(ex.args) | ||
return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, ex.args...))) | ||
elseif ex.head == :typed_hcat # typed, 1 x n | ||
s1 = 1 | ||
s2 = length(ex.args) - 1 | ||
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | ||
elseif ex.head == :vcat | ||
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m | ||
# Validate | ||
s1 = length(ex.args) | ||
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1) | ||
s2 = minimum(s2s) | ||
if maximum(s2s) != s2 | ||
error("Rows must be of matching lengths") | ||
end | ||
|
||
exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2] | ||
return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, exprs...))) | ||
else # n x 1 | ||
return esc(Expr(:call, SArray{Tuple{length(ex.args), 1}}, Expr(:tuple, ex.args...))) | ||
end | ||
elseif ex.head == :typed_vcat | ||
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m | ||
# Validate | ||
s1 = length(ex.args) - 1 | ||
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1) | ||
s2 = minimum(s2s) | ||
if maximum(s2s) != s2 | ||
error("Rows must be of matching lengths") | ||
end | ||
|
||
exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2] | ||
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, exprs...))) | ||
else # typed, n x 1 | ||
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1, 1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | ||
end | ||
elseif isa(ex, Expr) && ex.head == :comprehension | ||
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator | ||
error("Expected generator in comprehension, e.g. [f(i,j) for i in 1:3, j in 1:3]") | ||
end | ||
|
||
ex = ex.args[1] | ||
n_ranges = length(ex.args) - 1 | ||
range_vars = [ex.args[i+1].args[1] for i = 1:n_ranges] | ||
ranges = [ex.args[i+1].args[2] for i = 1:n_ranges] | ||
func = :(($(Expr(:tuple, range_vars...)) -> $(ex.args[1]))) | ||
|
||
return quote | ||
$(Expr(:call, :sarray_comprehension, func, ranges...)) | ||
end | ||
elseif isa(ex, Expr) && ex.head == :typed_comprehension | ||
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator | ||
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]") | ||
end | ||
T = ex.args[1] | ||
ex = ex.args[2] | ||
n_ranges = length(ex.args) - 1 | ||
range_vars = [ex.args[i+1].args[1] for i = 1:n_ranges] | ||
ranges = [ex.args[i+1].args[2] for i = 1:n_ranges] | ||
func = esc(:(($(Expr(:tuple, range_vars...)) -> $(ex.args[1])))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we might be able to remove some of this logic and just pass the generator along to |
||
|
||
return quote | ||
$(Expr(:call, :typed_sarray_comprehension, T, func, ranges...)) | ||
end | ||
else | ||
error("Bad input for @SA") | ||
end | ||
end | ||
|
||
@inline function typed_sarray_comprehension(::Type{T}, f, range) where {T} | ||
L = length(range) | ||
_typed_sarray_comprehension(SVector{L,T}, f, range) | ||
end | ||
|
||
@inline function typed_sarray_comprehension(::Type{T}, f, range1, range2) where {T} | ||
S1 = length(range1) | ||
S2 = length(range2) | ||
L = S1*S2 | ||
_typed_sarray_comprehension(SArray{Tuple{S1,S2},T,2,L}, f, range1,range2) | ||
end | ||
|
||
@generated function _typed_sarray_comprehension(::Type{SArray{S,T,N,L}}, f, range) where {S, T, N, L} | ||
exprs = [:(f(range[$i])) for i in 1:L] | ||
return quote | ||
@_inline_meta | ||
$(Expr(:call, SArray{S, T, N, L}, Expr(:tuple, exprs...))) | ||
end | ||
end | ||
|
||
@generated function _typed_sarray_comprehension(::Type{SArray{Tuple{S1,S2},T,N,L}}, f, range1, range2) where {S1, S2, T, N, L} | ||
exprs = [:(f(range1[$i1], range2[$i2])) for i1 in 1:S1, i2 in 1:S2] | ||
return quote | ||
@_inline_meta | ||
$(Expr(:call, SArray{Tuple{S1,S2}, T, N, L}, Expr(:tuple, exprs...))) | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
SA[ array initializer ] | ||
|
||
A type for initializing static array literals using array construction syntax. | ||
Returns an `SVector` or `SMatrix`. | ||
|
||
# Examples: | ||
|
||
* `SA[x, y]` creates a length-2 SVector | ||
* `SA[a b; c d]` creates a 2×2 SMatrix | ||
* `SA[a b]` creates a 1×2 SMatrix | ||
""" | ||
struct SA ; end | ||
|
||
Base.getindex(::Type{SA}, xs...) = SVector(xs) | ||
Base.typed_vcat(::Type{SA}, xs::Number...) = SVector(xs) | ||
Base.typed_hcat(::Type{SA}, xs::Number...) = SMatrix{1,length(xs)}(xs) | ||
|
||
Base.@pure function _SA_hvcat_transposed_type(rows) | ||
M = rows[1] | ||
if any(r->r != M, rows) | ||
# @pure may not throw... probably. See | ||
# https://discourse.julialang.org/t/can-pure-functions-throw-an-error/18459 | ||
return nothing | ||
end | ||
SMatrix{M,length(rows)} | ||
end | ||
|
||
@inline function Base.typed_hvcat(::Type{SA}, rows::Dims, xs::Number...) | ||
mtype = _SA_hvcat_transposed_type(rows) | ||
if mtype === nothing | ||
throw(ArgumentError("SA[...] matrix rows of length $rows are inconsistent")) | ||
end | ||
# hvcat lowering is row major ordering, so must transpose | ||
transpose(mtype(xs)) | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
SA_test_ref(x) = SA[1,x,x] | ||
@test @inferred(SA_test_ref(2)) === SVector{3,Int}((1,2,2)) | ||
@test @inferred(SA_test_ref(2.0)) === SVector{3,Float64}((1,2,2)) | ||
|
||
SA_test_vcat(x) = SA[1;x;x] | ||
@test @inferred(SA_test_vcat(2)) === SVector{3,Int}((1,2,2)) | ||
@test @inferred(SA_test_vcat(2.0)) === SVector{3,Float64}((1,2,2)) | ||
|
||
SA_test_hcat(x) = SA[1 x x] | ||
@test @inferred(SA_test_hcat(2)) === SMatrix{1,3,Int}((1,2,2)) | ||
@test @inferred(SA_test_hcat(2.0)) === SMatrix{1,3,Float64}((1,2,2)) | ||
|
||
SA_test_hvcat(x) = SA[1 x x; | ||
x 2 x] | ||
@test @inferred(SA_test_hvcat(3)) === SMatrix{2,3,Int}((1,3,3,2,3,3)) | ||
@test @inferred(SA_test_hvcat(3.0)) === SMatrix{2,3,Float64}((1,3,3,2,3,3)) | ||
@test @inferred(SA_test_hvcat(1.0im)) === SMatrix{2,3,ComplexF64}((1,1im,1im,2,1im,1im)) | ||
|
||
@test SA[1] === SVector{1,Int}((1)) | ||
|
||
@test_throws ArgumentError("SA[...] matrix rows of length (3, 2) are inconsistent") SA[1 2 3; | ||
4 5] | ||
@test_throws ArgumentError("SA[...] matrix rows of length (2, 3) are inconsistent") SA[1 2; | ||
3 4 5] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normal vcat returns a Vector rather than an n x 1 matrix.