|
216 | 216 |
|
217 | 217 | maybe_assert_scalar_setindexing(args...) = nothing
|
218 | 218 |
|
| 219 | +function Base.setindex!( |
| 220 | + a::TracedRArray{T,N}, v, indices::Union{Int,TracedRNumber{Int}} |
| 221 | +) where {T,N} |
| 222 | + GPUArraysCore.assertscalar( |
| 223 | + "setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})" |
| 224 | + ) |
| 225 | + if indices isa Int |
| 226 | + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) |
| 227 | + end |
| 228 | + indices = scalar_index_to_cartesian( |
| 229 | + TracedUtils.broadcast_to_size(indices, (1,)), size(a) |
| 230 | + ) |
| 231 | + v = v isa Number ? v : vec(v) |
| 232 | + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) |
| 233 | + set_mlir_data!(a, get_mlir_data(res)) |
| 234 | + return a |
| 235 | +end |
| 236 | + |
| 237 | +# Avoid ambiguity |
| 238 | +function Base.setindex!( |
| 239 | + a::TracedRArray{T,1}, v, indices::Union{Int,TracedRNumber{Int}} |
| 240 | +) where {T} |
| 241 | + GPUArraysCore.assertscalar( |
| 242 | + "setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})" |
| 243 | + ) |
| 244 | + if indices isa Int |
| 245 | + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) |
| 246 | + end |
| 247 | + indices = scalar_index_to_cartesian( |
| 248 | + TracedUtils.broadcast_to_size(indices, (1,)), size(a) |
| 249 | + ) |
| 250 | + v = v isa Number ? v : vec(v) |
| 251 | + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) |
| 252 | + set_mlir_data!(a, get_mlir_data(res)) |
| 253 | + return a |
| 254 | +end |
| 255 | + |
| 256 | +function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N} |
| 257 | + if !(indices isa TracedRArray) |
| 258 | + indices = collect(indices) |
| 259 | + eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices]) |
| 260 | + indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices) |
| 261 | + end |
| 262 | + res = Ops.scatter_setindex( |
| 263 | + a, |
| 264 | + scalar_index_to_cartesian(vec(indices), size(a)), |
| 265 | + materialize_traced_array(vec(v)), |
| 266 | + ) |
| 267 | + set_mlir_data!(a, get_mlir_data(res)) |
| 268 | + return a |
| 269 | +end |
| 270 | + |
| 271 | +function Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N} |
| 272 | + v = TracedUtils.broadcast_to_size(v, size(a)) |
| 273 | + set_mlir_data!(a, get_mlir_data(v)) |
| 274 | + return a |
| 275 | +end |
| 276 | + |
| 277 | +function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) where {T,N} |
| 278 | + GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})") |
| 279 | + indices = |
| 280 | + materialize_traced_array( |
| 281 | + reshape( |
| 282 | + TracedUtils.promote_to( |
| 283 | + TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...)) |
| 284 | + ), |
| 285 | + 1, |
| 286 | + N, |
| 287 | + ), |
| 288 | + ) .- 1 |
| 289 | + v = v isa Number ? v : vec(v) |
| 290 | + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) |
| 291 | + set_mlir_data!(a, get_mlir_data(res)) |
| 292 | + return a |
| 293 | +end |
| 294 | + |
219 | 295 | function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
|
220 | 296 | maybe_assert_scalar_setindexing(a, indices...)
|
221 | 297 |
|
@@ -244,7 +320,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
|
244 | 320 | indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
|
245 | 321 | indices_list = generate_index_list(indices_list...)
|
246 | 322 | res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v)))
|
247 |
| - a.mlir_data = res.mlir_data |
| 323 | + set_mlir_data!(a, get_mlir_data(res)) |
248 | 324 | return v
|
249 | 325 | end
|
250 | 326 |
|
@@ -275,7 +351,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
|
275 | 351 | ),
|
276 | 352 | 1,
|
277 | 353 | )
|
278 |
| - a.mlir_data = res |
| 354 | + set_mlir_data!(a, res) |
279 | 355 | return v
|
280 | 356 | end
|
281 | 357 |
|
|
0 commit comments