Skip to content

Commit bfc7b58

Browse files
authored
feat: support more set indexing (#625)
* feat: support more set indexing * fix: tests
1 parent 6e4c6a8 commit bfc7b58

File tree

3 files changed

+126
-4
lines changed

3 files changed

+126
-4
lines changed

src/Ops.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,11 +1479,13 @@ instead.
14791479
@noinline function scatter_setindex(
14801480
dest::TracedRArray{T,N},
14811481
scatter_indices::TracedRArray{Int64,2},
1482-
updates::TracedRArray{T,1},
1483-
) where {T,N}
1482+
updates::TracedRArray{T2,1},
1483+
) where {T,N,T2}
14841484
@assert length(updates) == size(scatter_indices, 1)
14851485
@assert size(scatter_indices, 2) == N
14861486

1487+
updates = convert(TracedRArray{T,1}, updates)
1488+
14871489
update_computation = MLIR.IR.Region()
14881490
block = MLIR.IR.Block(
14891491
[mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})],

src/TracedRArray.jl

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,82 @@ end
216216

217217
maybe_assert_scalar_setindexing(args...) = nothing
218218

219+
function Base.setindex!(
220+
a::TracedRArray{T,N}, v, indices::Union{Int,TracedRNumber{Int}}
221+
) where {T,N}
222+
GPUArraysCore.assertscalar(
223+
"setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})"
224+
)
225+
if indices isa Int
226+
indices = TracedUtils.promote_to(TracedRNumber{Int}, indices)
227+
end
228+
indices = scalar_index_to_cartesian(
229+
TracedUtils.broadcast_to_size(indices, (1,)), size(a)
230+
)
231+
v = v isa Number ? v : vec(v)
232+
res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,)))
233+
set_mlir_data!(a, get_mlir_data(res))
234+
return a
235+
end
236+
237+
# Avoid ambiguity
238+
function Base.setindex!(
239+
a::TracedRArray{T,1}, v, indices::Union{Int,TracedRNumber{Int}}
240+
) where {T}
241+
GPUArraysCore.assertscalar(
242+
"setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})"
243+
)
244+
if indices isa Int
245+
indices = TracedUtils.promote_to(TracedRNumber{Int}, indices)
246+
end
247+
indices = scalar_index_to_cartesian(
248+
TracedUtils.broadcast_to_size(indices, (1,)), size(a)
249+
)
250+
v = v isa Number ? v : vec(v)
251+
res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,)))
252+
set_mlir_data!(a, get_mlir_data(res))
253+
return a
254+
end
255+
256+
function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N}
257+
if !(indices isa TracedRArray)
258+
indices = collect(indices)
259+
eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices])
260+
indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices)
261+
end
262+
res = Ops.scatter_setindex(
263+
a,
264+
scalar_index_to_cartesian(vec(indices), size(a)),
265+
materialize_traced_array(vec(v)),
266+
)
267+
set_mlir_data!(a, get_mlir_data(res))
268+
return a
269+
end
270+
271+
function Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N}
272+
v = TracedUtils.broadcast_to_size(v, size(a))
273+
set_mlir_data!(a, get_mlir_data(v))
274+
return a
275+
end
276+
277+
function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) where {T,N}
278+
GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})")
279+
indices =
280+
materialize_traced_array(
281+
reshape(
282+
TracedUtils.promote_to(
283+
TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...))
284+
),
285+
1,
286+
N,
287+
),
288+
) .- 1
289+
v = v isa Number ? v : vec(v)
290+
res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,)))
291+
set_mlir_data!(a, get_mlir_data(res))
292+
return a
293+
end
294+
219295
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
220296
maybe_assert_scalar_setindexing(a, indices...)
221297

@@ -244,7 +320,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
244320
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
245321
indices_list = generate_index_list(indices_list...)
246322
res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v)))
247-
a.mlir_data = res.mlir_data
323+
set_mlir_data!(a, get_mlir_data(res))
248324
return v
249325
end
250326

@@ -275,7 +351,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
275351
),
276352
1,
277353
)
278-
a.mlir_data = res
354+
set_mlir_data!(a, res)
279355
return v
280356
end
281357

test/indexing.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,47 @@ end
233233
@test @jit(fn(x_ra, idx1_ra, idx2_ra, idx3))
234234
fn(Array(x_ra), Array(idx1_ra), Array(idx2_ra), idx3)
235235
end
236+
237+
function issue_617(outf, fr, pr, I)
238+
tmp = fr .* reshape(pr, size(fr))
239+
outv = @view outf[I]
240+
vtmp = vec(tmp)
241+
outv .= vtmp
242+
return outf
243+
end
244+
245+
@testset "issue #617" begin
246+
N, M = 4, 6
247+
248+
f = rand(ComplexF64, N, N)
249+
p = rand(ComplexF64, N * N)
250+
I = 1:(N^2)
251+
out = rand(ComplexF64, M, M)
252+
253+
fr = Reactant.to_rarray(f)
254+
pr = Reactant.to_rarray(p)
255+
outr = Reactant.to_rarray(out)
256+
Ir = Reactant.to_rarray(I)
257+
258+
@test @jit(issue_617(outr, fr, pr, Ir)) issue_617(out, f, p, I)
259+
end
260+
261+
function scalar_setindex(x, idx, val)
262+
@allowscalar x[idx] = val
263+
return x
264+
end
265+
266+
@testset "scalar setindex" begin
267+
x = zeros(4, 4)
268+
x_ra = Reactant.to_rarray(x)
269+
270+
@test @jit(scalar_setindex(x_ra, 1, 1)) scalar_setindex(x, 1, 1)
271+
@test @allowscalar x_ra[1] == 1
272+
273+
x = zeros(4, 4)
274+
x_ra = Reactant.to_rarray(x)
275+
276+
@test @jit(scalar_setindex(x_ra, ConcreteRNumber(1), 1)) scalar_setindex(x, 1, 1)
277+
@test @allowscalar x_ra[1] == 1
278+
end
279+

0 commit comments

Comments
 (0)