Skip to content

Commit 004e9ad

Browse files
committed
Add parameterized SA{T} initializer
1 parent 479c86d commit 004e9ad

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

src/initializers.jl

+26-18
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,46 @@
11
"""
2-
SA[ array initializer ]
2+
SA[ elements ]
3+
SA{T}[ elements ]
34
4-
A type for initializing static array literals using array construction syntax.
5-
Returns an `SVector` or `SMatrix`.
5+
Create `SArray` literals using array construction syntax. The element type is
6+
inferred by promoting `elements` to a common type or set to `T` when `T` is
7+
provided explicitly.
68
79
# Examples:
810
9-
* `SA[x, y]` creates a length-2 SVector
10-
* `SA[a b; c d]` creates a 2×2 SMatrix
11-
* `SA[a b]` creates a 1×2 SMatrix
11+
* `SA[1.0, 2.0]` creates a length-2 `SVector` of `Float64` elements.
12+
* `SA[1 2; 3 4]` creates a 2×2 SMatrix of `Int`s.
13+
* `SA[1 2]` creates a 1×2 SMatrix of `Int`s.
14+
* `SA{Float32}[1, 2]` creates a length-2 `SVector` of `Float32` elements.
1215
"""
13-
struct SA ; end
16+
struct SA{T} ; end
1417

15-
Base.getindex(::Type{SA}, xs...) = SVector(xs)
16-
Base.typed_vcat(::Type{SA}, xs::Number...) = SVector(xs)
17-
Base.typed_hcat(::Type{SA}, xs::Number...) = SMatrix{1,length(xs)}(xs)
18+
@inline similar_type(::Type{SA}, ::Size{S}) where {S} = SArray{Tuple{S...}}
19+
@inline similar_type(::Type{SA{T}}, ::Size{S}) where {T,S} = SArray{Tuple{S...}, T}
1820

19-
Base.@pure function _SA_hvcat_transposed_type(rows)
21+
Base.@pure _SA_type(sa::Type{SA}, len::Int) = SVector{len}
22+
Base.@pure _SA_type(sa::Type{SA{T}}, len::Int) where {T} = SVector{len,T}
23+
24+
@inline Base.getindex(sa::Type{<:SA}, xs...) where T = similar_type(sa, Size(length(xs)))(xs)
25+
@inline Base.typed_vcat(sa::Type{<:SA}, xs::Number...) where T = similar_type(sa, Size(length(xs)))(xs)
26+
@inline Base.typed_hcat(sa::Type{<:SA}, xs::Number...) where T = similar_type(sa, Size(1,length(xs)))(xs)
27+
28+
Base.@pure function _SA_hvcat_transposed_size(rows)
2029
M = rows[1]
2130
if any(r->r != M, rows)
2231
# @pure may not throw... probably. See
2332
# https://discourse.julialang.org/t/can-pure-functions-throw-an-error/18459
2433
return nothing
2534
end
26-
SMatrix{M,length(rows)}
35+
Size(M, length(rows))
2736
end
2837

29-
@inline function Base.typed_hvcat(::Type{SA}, rows::Dims, xs::Number...)
30-
mtype = _SA_hvcat_transposed_type(rows)
31-
if mtype === nothing
38+
@inline function Base.typed_hvcat(sa::Type{<:SA}, rows::Dims, xs::Number...) where T
39+
msize = _SA_hvcat_transposed_size(rows)
40+
if msize === nothing
3241
throw(ArgumentError("SA[...] matrix rows of length $rows are inconsistent"))
3342
end
34-
# hvcat lowering is row major ordering, so must transpose
35-
transpose(mtype(xs))
43+
# hvcat lowering is row major ordering, so we must transpose
44+
transpose(similar_type(sa, msize)(xs))
3645
end
3746

38-

test/initializers.jl

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
1-
SA_test_ref(x) = SA[1,x,x]
1+
SA_test_ref(x) = SA[1,x,x]
2+
SA_test_ref(x,T) = SA{T}[1,x,x]
23
@test @inferred(SA_test_ref(2)) === SVector{3,Int}((1,2,2))
34
@test @inferred(SA_test_ref(2.0)) === SVector{3,Float64}((1,2,2))
5+
@test @inferred(SA_test_ref(2,Float32)) === SVector{3,Float32}((1,2,2))
46

5-
SA_test_vcat(x) = SA[1;x;x]
7+
SA_test_vcat(x) = SA[1;x;x]
8+
SA_test_vcat(x,T) = SA{T}[1;x;x]
69
@test @inferred(SA_test_vcat(2)) === SVector{3,Int}((1,2,2))
710
@test @inferred(SA_test_vcat(2.0)) === SVector{3,Float64}((1,2,2))
11+
@test @inferred(SA_test_vcat(2,Float32)) === SVector{3,Float32}((1,2,2))
812

9-
SA_test_hcat(x) = SA[1 x x]
13+
SA_test_hcat(x) = SA[1 x x]
14+
SA_test_hcat(x,T) = SA{T}[1 x x]
1015
@test @inferred(SA_test_hcat(2)) === SMatrix{1,3,Int}((1,2,2))
1116
@test @inferred(SA_test_hcat(2.0)) === SMatrix{1,3,Float64}((1,2,2))
17+
@test @inferred(SA_test_hcat(2,Float32)) === SMatrix{1,3,Float32}((1,2,2))
1218

1319
SA_test_hvcat(x) = SA[1 x x;
1420
x 2 x]
21+
SA_test_hvcat(x,T) = SA{T}[1 x x;
22+
x 2 x]
1523
@test @inferred(SA_test_hvcat(3)) === SMatrix{2,3,Int}((1,3,3,2,3,3))
1624
@test @inferred(SA_test_hvcat(3.0)) === SMatrix{2,3,Float64}((1,3,3,2,3,3))
1725
@test @inferred(SA_test_hvcat(1.0im)) === SMatrix{2,3,ComplexF64}((1,1im,1im,2,1im,1im))
26+
@test @inferred(SA_test_hvcat(3,Float32)) === SMatrix{2,3,Float32}((1,3,3,2,3,3))
1827

1928
@test SA[1] === SVector{1,Int}((1))
2029

0 commit comments

Comments
 (0)