Skip to content

Commit

Permalink
fix: missing scalar indexing check for setindex
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 7, 2025
1 parent 12531c9 commit 3ec714d
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,17 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
end

function maybe_assert_scalar_setindexing(
::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})")
end

maybe_assert_scalar_setindexing(args...) = nothing

function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
maybe_assert_scalar_setindexing(a, indices...)

indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
Expand Down

0 comments on commit 3ec714d

Please sign in to comment.