From 3ec714d9ca07733c34aec0b631cbca917eb45497 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 10:55:55 -0500 Subject: [PATCH 1/6] fix: missing scalar indexing check for setindex --- src/TracedRArray.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index db9f00d11..fe7ba4db2 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -172,7 +172,17 @@ function Base.getindex(a::WrappedTracedRArray, indices...) return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) end +function maybe_assert_scalar_setindexing( + ::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N} +) where {T,N} + GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})") +end + +maybe_assert_scalar_setindexing(args...) = nothing + function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} + maybe_assert_scalar_setindexing(a, indices...) + indices = map(enumerate(indices)) do (idx, i) i isa Colon && return 1:size(a, idx) i isa CartesianIndex && return Tuple(i) From 57ee5fb4b641f2da95f21f870b3d8d5a1f450280 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 11:26:16 -0500 Subject: [PATCH 2/6] fix: missing copyto! --- src/TracedRArray.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index fe7ba4db2..9a2542cf8 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -483,6 +483,10 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T, return dest end +function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T2,N}) where {T,T2,N} + return copyto!(dest, Ops.convert(TracedRArray{T,N}, src)) +end + function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest From 34d72f2b33408c6062b34c01384c5cffb912f3c5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 Jan 2025 12:07:52 -0500 Subject: [PATCH 3/6] Update src/TracedRArray.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TracedRArray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 9a2542cf8..8244bfc24 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -175,7 +175,7 @@ end function maybe_assert_scalar_setindexing( ::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} - GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})") + return GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})") end maybe_assert_scalar_setindexing(args...) = nothing From 68ea368f6d002c6a1eddc87ec8efcaab56bbf862 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 16:18:37 -0500 Subject: [PATCH 4/6] fix: overload Enzyme.onehot to avoid scalar indexing --- src/Enzyme.jl | 7 +++++++ src/Reactant.jl | 3 +++ 2 files changed, 10 insertions(+) create mode 100644 src/Enzyme.jl diff --git a/src/Enzyme.jl b/src/Enzyme.jl new file mode 100644 index 000000000..e0684b9f0 --- /dev/null +++ b/src/Enzyme.jl @@ -0,0 +1,7 @@ +# TODO: move the overload_autodiff here as well + +# The default `onehot` will lead to scalar indexing +function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N} + x_arr = zeros(T, size(x)) + return map(Base.Fix1(TracedUtils.promote_to, TracedRArray{T, N}), Enzyme.onehot(x_arr)) +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 3102c7531..ce1a86a19 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -203,6 +203,9 @@ end include("stdlibs/LinearAlgebra.jl") include("stdlibs/Random.jl") +# Other Integrations +include("Enzyme.jl") + const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} include("ControlFlow.jl") From 0b7ba8201d4d74e2518113fa166371d3d1c496bd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 16:23:51 -0500 Subject: [PATCH 5/6] fix: mark tests with allowscalar --- test/control_flow.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/control_flow.jl b/test/control_flow.jl index 9b4ee9fcf..20b52c4e1 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -355,7 +355,7 @@ function condition10_condition_with_setindex(x) @trace if sum(x) > 0 x[:, 1] = -1.0 else - x[1, 1] = 1.0 + @allowscalar x[1, 1] = 1.0 end return x end @@ -457,7 +457,7 @@ end function for_with_step(x) @trace for i in 10:3:22 - x[i] = i * i + @allowscalar x[i] = i * i end return x end @@ -539,7 +539,7 @@ function cumsum!(x) v = zero(eltype(x)) @trace for i in 1:length(x) v += @allowscalar x[i] - x[i] = v + @allowscalar x[i] = v end return x end From cc92473a9511a38aadbddcdbb02a918af203c8c3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 16:24:31 -0500 Subject: [PATCH 6/6] chore: formatting --- src/Enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index e0684b9f0..366352b5d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -3,5 +3,5 @@ # The default `onehot` will lead to scalar indexing function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N} x_arr = zeros(T, size(x)) - return map(Base.Fix1(TracedUtils.promote_to, TracedRArray{T, N}), Enzyme.onehot(x_arr)) + return map(Base.Fix1(TracedUtils.promote_to, TracedRArray{T,N}), Enzyme.onehot(x_arr)) end