Skip to content

Commit 9779dd5

Browse files
authored
Merge pull request #274 from wsshin/broadcast
Fix several broadcast issues (fixes #197, #199, #200, #242)
2 parents bd09cd6 + 9cf3d2a commit 9779dd5

File tree

3 files changed

+23
-19
lines changed

3 files changed

+23
-19
lines changed

src/broadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Base.Broadcast:
1010
# This isn't the precise output type, just a placeholder to return from
1111
# promote_containertype, which will control dispatch to our broadcast_c.
1212
_containertype(::Type{<:StaticArray}) = StaticArray
13+
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray
1314

1415
# With the above, the default promote_containertype gives reasonable defaults:
1516
# StaticArray, StaticArray -> StaticArray
@@ -32,6 +33,7 @@ broadcast_indices(::Type{StaticArray}, A) = indices(A)
3233
_broadcast(f, broadcast_sizes(as...), as...)
3334
end
3435

36+
@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...)
3537
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
3638
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
3739
@inline broadcast_sizes() = ()
@@ -66,9 +68,9 @@ end
6668
for i = 1:length(sizes)
6769
s = sizes[i]
6870
for j = 1:length(s)
69-
if newsize[j] == 1 || newsize[j] == s[j]
71+
if newsize[j] == 1
7072
newsize[j] = s[j]
71-
else
73+
elseif newsize[j] s[j] && s[j] 1
7274
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
7375
end
7476
end

src/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ end
105105
N = length(S)
106106
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)
107107
T0 = eltype(a)
108-
T = :((T1 = Base.promote_op(f, $T0); Base.promote_op(op, T1, T1)))
108+
T = :((T1 = Core.Inference.return_type(f, Tuple{$T0}); Core.Inference.return_type(op, Tuple{T1,T1})))
109109

110110
exprs = Array{Expr}(Snew)
111111
itr = [1:n for n Snew]
@@ -235,7 +235,7 @@ end
235235
@generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D}
236236
N = length(S)
237237
Snew = ([n==D ? S[n]-1 : S[n] for n = 1:N]...)
238-
T = Base.promote_op(-, eltype(a), eltype(a))
238+
T = typeof(one(eltype(a)) - one(eltype(a)))
239239

240240
exprs = Array{Expr}(Snew)
241241
itr = [1:n for n = Snew]

test/broadcast.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,34 +43,36 @@ end
4343
end
4444

4545
@testset "2x2 StaticMatrix with 1x2 StaticMatrix" begin
46+
# Issues #197, #242: broadcast between SArray and row-like SMatrix
4647
m1 = @SMatrix [1 2; 3 4]
4748
m2 = @SMatrix [1 4]
48-
@test_broken @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
49-
@test_broken @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
49+
@test @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8]
50+
@test @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8]
5051
@test @inferred(m2 .+ m1) === @SMatrix [2 6; 4 8]
51-
@test_broken @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
52+
@test @inferred(m1 .* m2) === @SMatrix [1 8; 3 16]
5253
@test @inferred(m2 .* m1) === @SMatrix [1 8; 3 16]
53-
@test_broken @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
54+
@test @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1]
5455
@test @inferred(m2 ./ m1) === @SMatrix [1 2; 1/3 1]
55-
@test_broken @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
56+
@test @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0]
5657
@test @inferred(m2 .- m1) === @SMatrix [0 2; -2 0]
57-
@test_broken @inferred(m1 .^ m2) === @SMatrix [1 16; 1 256] #197
58+
@test @inferred(m1 .^ m2) === @SMatrix [1 16; 3 256]
5859
end
5960

6061
@testset "1x2 StaticMatrix with StaticVector" begin
62+
# Issues #197, #242: broadcast between SVector and row-like SMatrix
6163
m = @SMatrix [1 2]
6264
v = SVector(1, 4)
6365
@test @inferred(broadcast(+, m, v)) === @SMatrix [2 3; 5 6]
6466
@test @inferred(m .+ v) === @SMatrix [2 3; 5 6]
65-
@test_broken @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
67+
@test @inferred(v .+ m) === @SMatrix [2 3; 5 6]
6668
@test @inferred(m .* v) === @SMatrix [1 2; 4 8]
67-
@test_broken @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
69+
@test @inferred(v .* m) === @SMatrix [1 2; 4 8]
6870
@test @inferred(m ./ v) === @SMatrix [1 2; 1/4 1/2]
69-
@test_broken @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
71+
@test @inferred(v ./ m) === @SMatrix [1 1/2; 4 2]
7072
@test @inferred(m .- v) === @SMatrix [0 1; -3 -2]
71-
@test_broken @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
73+
@test @inferred(v .- m) === @SMatrix [0 -1; 3 2]
7274
@test @inferred(m .^ v) === @SMatrix [1 2; 1 16]
73-
@test_broken @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
75+
@test @inferred(v .^ m) === @SMatrix [1 1; 4 16]
7476
end
7577

7678
@testset "StaticVector with StaticVector" begin
@@ -87,11 +89,11 @@ end
8789
@test @inferred(v2 .- v1) === SVector(0, 2)
8890
@test @inferred(v1 .^ v2) === SVector(1, 16)
8991
@test @inferred(v2 .^ v1) === SVector(1, 16)
90-
# test case issue #199
92+
# Issue #199: broadcast with empty SArray
9193
@test @inferred(SVector(1) .+ SVector()) === SVector()
92-
@test_broken @inferred(SVector() .+ SVector(1)) === SVector()
93-
# test case issue #200
94-
@test_broken @inferred(v1 .+ v2') === @SMatrix [2 5; 3 5]
94+
@test @inferred(SVector() .+ SVector(1)) === SVector()
95+
# Issue #200: broadcast with RowVector
96+
@test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6]
9597
end
9698

9799
@testset "StaticVector with Scalar" begin

0 commit comments

Comments
 (0)