diff --git a/src/Enzyme.jl b/src/Enzyme.jl new file mode 100644 index 000000000..366352b5d --- /dev/null +++ b/src/Enzyme.jl @@ -0,0 +1,7 @@ +# TODO: move the overload_autodiff here as well + +# The default `onehot` will lead to scalar indexing +function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N} + x_arr = zeros(T, size(x)) + return map(Base.Fix1(TracedUtils.promote_to, TracedRArray{T,N}), Enzyme.onehot(x_arr)) +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 3102c7531..ce1a86a19 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -203,6 +203,9 @@ end include("stdlibs/LinearAlgebra.jl") include("stdlibs/Random.jl") +# Other Integrations +include("Enzyme.jl") + const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} include("ControlFlow.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index db9f00d11..8244bfc24 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -172,7 +172,17 @@ function Base.getindex(a::WrappedTracedRArray, indices...) return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) end +function maybe_assert_scalar_setindexing( + ::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N} +) where {T,N} + return GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})") +end + +maybe_assert_scalar_setindexing(args...) = nothing + function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} + maybe_assert_scalar_setindexing(a, indices...) + indices = map(enumerate(indices)) do (idx, i) i isa Colon && return 1:size(a, idx) i isa CartesianIndex && return Tuple(i) @@ -473,6 +483,10 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T, return dest end +function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T2,N}) where {T,T2,N} + return copyto!(dest, Ops.convert(TracedRArray{T,N}, src)) +end + function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest diff --git a/test/control_flow.jl b/test/control_flow.jl index 9b4ee9fcf..20b52c4e1 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -355,7 +355,7 @@ function condition10_condition_with_setindex(x) @trace if sum(x) > 0 x[:, 1] = -1.0 else - x[1, 1] = 1.0 + @allowscalar x[1, 1] = 1.0 end return x end @@ -457,7 +457,7 @@ end function for_with_step(x) @trace for i in 10:3:22 - x[i] = i * i + @allowscalar x[i] = i * i end return x end @@ -539,7 +539,7 @@ function cumsum!(x) v = zero(eltype(x)) @trace for i in 1:length(x) v += @allowscalar x[i] - x[i] = v + @allowscalar x[i] = v end return x end