From e165992f1e5eb31347cc98f46bf1c002c82fcd7f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 6 Nov 2024 12:29:38 -0500 Subject: [PATCH 1/2] feat: working towards a generic mapslices using batch --- src/TracedRArray.jl | 106 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a5beddfe5..12468e219 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -44,6 +44,7 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...) end function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N} + error(1) @warn( """Performing scalar indexing on task $(current_task()). Invocation resulted in scalar indexing of a TracedRArray. @@ -371,6 +372,111 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} return traced2_result end +# TODO: once we have a generic implementation of `vmap`/`batch` we can simply call that +# from here +function Base.mapslices(f, A::TracedRArray; dims) + isempty(dims) && return map(f, A) + + for d in dims + d isa Integer || + throw(ArgumentError("mapslices: dimension must be an integer, got $d")) + d >= 1 || throw(ArgumentError("mapslices: dimension must be ≥ 1, got $d")) + # Indexing a matrix M[:,1,:] produces a 1-column matrix, but dims=(1,3) here + # would otherwise ignore 3, and slice M[:,i]. Previously this gave error: + # BoundsError: attempt to access 2-element Vector{Any} at index [3] + d > ndims(A) && throw( + ArgumentError( + "mapslices does not accept dimensions > ndims(A) = $(ndims(A)), got $d" + ), + ) + end + + # Apply the function to the first slice in order to determine the next steps + batch_inputs = eachslice(A; dims) + + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + f, (first(batch_inputs),), (), string(f) * "_mapslice", false + ) + + @assert traced_result isa Union{TracedRArray,TracedRNumber} "Expected TracedRArray or TracedRNumber as result." + + input_shapes = size.(batch_inputs) + output_shape = size(traced_result) + + input_types = map(mlir_type, batch_inputs) + output_types = [mlir_type(traced_result) for _ in 1:length(batch_inputs)] + + @show func2 + + fname = get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + counter = 0 + batch_shape = ntuple(ndims(A)) do d + d in dims && return size(A, d) + counter += 1 + return size(traced_result, counter) + end + + res = MLIR.Dialects.enzyme.batch( + map(get_mlir_data, batch_inputs); + outputs=[mlir_type(traced_result)], + fn=fname, + batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in batch_shape]), + ) + + @show res + + # residx = 1 + + # for a in linear_results + # if has_residx(a) + # path = get_residx(a) + # set!(result, path[2:end], MLIR.IR.result(res, residx)) + # residx += 1 + # else + # idx, path = get_argidx(a) + # if idx == 1 && fnwrap + # set!(f, path[3:end], MLIR.IR.result(res, residx)) + # residx += 1 + # else + # if fnwrap + # idx -= 1 + # end + # set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) + # residx += 1 + # end + # end + # end + + # seen_results = OrderedIdDict() + # traced2_result = make_tracer(seen_results, result, (), TracedSetPath; tobatch=OutShape) + + # func2.operation = MLIR.API.MlirOperation(C_NULL) + + # return traced2_result + + return error(1) +end + +function Base._eachslice( + A::AnyTracedRArray{T,N}, dims::NTuple{M,Integer}, drop::Bool +) where {T,N,M} + Base._slice_check_dims(N, dims...) + all_slices = [A] + for dim in dims + partial_slices = [] + for i in axes(A, dim) + for slice in all_slices + push!(partial_slices, selectdim(slice, dim, i:i)) + end + end + all_slices = partial_slices + end + drop && (all_slices = map(x -> dropdims(x; dims), all_slices)) + return all_slices +end + for (jlop, hloop, hlocomp, merge) in ((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any)) @eval function $jlop( From 0eaf929c4eb5330b94b830ef218437aeac06e17d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 6 Nov 2024 13:03:35 -0500 Subject: [PATCH 2/2] fix: integrate batchdims into the make_mlir_fn --- src/TracedRArray.jl | 124 ++++++++++++++++++++++++++------------------ src/Tracing.jl | 40 ++++++-------- src/utils.jl | 4 ++ 3 files changed, 95 insertions(+), 73 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 12468e219..405362159 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -391,72 +391,96 @@ function Base.mapslices(f, A::TracedRArray; dims) ) end - # Apply the function to the first slice in order to determine the next steps - batch_inputs = eachslice(A; dims) - + args = (A,) fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( - f, (first(batch_inputs),), (), string(f) * "_mapslice", false + f, + args, + (), + string(f) * "_mapslice", + false; + batchdims=filter(d -> d ∉ dims, 1:ndims(A)), ) - @assert traced_result isa Union{TracedRArray,TracedRNumber} "Expected TracedRArray or TracedRNumber as result." + invmap = IdDict() + for (k, v) in seen_args + invmap[v] = k + end + + keys_seen = [k for k in keys(seen_args) if k isa TracedType] + input_shapes = size.(keys_seen) - input_shapes = size.(batch_inputs) - output_shape = size(traced_result) + function shape_fn(x) + counter = 0 + return ntuple(ndims(A)) do d + d in dims && return size(A, d) + counter += 1 + return size(x, counter) + end + end - input_types = map(mlir_type, batch_inputs) - output_types = [mlir_type(traced_result) for _ in 1:length(batch_inputs)] + out_shapes = map(shape_fn, linear_results) + @assert allequal(out_shapes) "out_shapes are $(out_shapes)" + out_shape = first(out_shapes) - @show func2 + in_tys2 = [mlir_type(invmap[arg]) for arg in linear_args] + + out_tys2 = [ + MLIR.IR.TensorType(out_shapes[i], MLIR.IR.Type(eltype(arg))) for + (i, arg) in enumerate(linear_results) + ] fname = get_attribute_by_name(func2, "sym_name") fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - counter = 0 - batch_shape = ntuple(ndims(A)) do d - d in dims && return size(A, d) - counter += 1 - return size(traced_result, counter) + batch_inputs = MLIR.IR.Value[] + + for a in linear_args + idx, path = get_argidx(a) + if idx == 1 && fnwrap + push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + push_val!(batch_inputs, args[idx], path[3:end]) + end end res = MLIR.Dialects.enzyme.batch( - map(get_mlir_data, batch_inputs); - outputs=[mlir_type(traced_result)], + batch_inputs; + outputs=out_tys2, fn=fname, - batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in batch_shape]), + batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in out_shape]), ) - @show res - - # residx = 1 - - # for a in linear_results - # if has_residx(a) - # path = get_residx(a) - # set!(result, path[2:end], MLIR.IR.result(res, residx)) - # residx += 1 - # else - # idx, path = get_argidx(a) - # if idx == 1 && fnwrap - # set!(f, path[3:end], MLIR.IR.result(res, residx)) - # residx += 1 - # else - # if fnwrap - # idx -= 1 - # end - # set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) - # residx += 1 - # end - # end - # end - - # seen_results = OrderedIdDict() - # traced2_result = make_tracer(seen_results, result, (), TracedSetPath; tobatch=OutShape) - - # func2.operation = MLIR.API.MlirOperation(C_NULL) - - # return traced2_result - - return error(1) + residx = 1 + + for a in linear_results + if has_residx(a) + path = get_residx(a) + set!(result, path[2:end], MLIR.IR.result(res, residx)) + residx += 1 + else + idx, path = get_argidx(a) + if idx == 1 && fnwrap + set!(f, path[3:end], MLIR.IR.result(res, residx)) + residx += 1 + else + if fnwrap + idx -= 1 + end + set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) + residx += 1 + end + end + end + + seen_results = OrderedIdDict() + traced2_result = make_tracer(seen_results, result, (), TracedSetPath; tobatch=out_shape) + + func2.operation = MLIR.API.MlirOperation(C_NULL) + + return traced2_result end function Base._eachslice( diff --git a/src/Tracing.jl b/src/Tracing.jl index 30a617cc3..5ba572a79 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -228,13 +228,7 @@ end append_path(path, i) = (path..., i) function make_tracer( - seen, - @nospecialize(prev::RT), - @nospecialize(path), - mode; - toscalar=false, - tobatch=nothing, - kwargs..., + seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... ) where {RT} if haskey(seen, prev) return seen[prev] @@ -251,7 +245,7 @@ function make_tracer( for i in 1:nf if isdefined(prev, i) xi = Base.getfield(prev, i) - xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch) + xi2 = make_tracer(seen, xi, append_path(path, i), mode; kwargs...) if xi !== xi2 changed = true end @@ -274,7 +268,7 @@ function make_tracer( for i in 1:nf if isdefined(prev, i) xi = Base.getfield(prev, i) - xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch) + xi2 = make_tracer(seen, xi, append_path(path, i), mode; kwargs...) if xi !== xi2 changed = true end @@ -333,6 +327,7 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, + batchdims=nothing, kwargs..., ) where {T,N} if mode == ConcreteToTraced @@ -349,10 +344,16 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end + # TODO: Check `toscalar` is not set together with `batchdims`? + # TODO: Unify toscalar and batchdims? res = if toscalar TracedRNumber{T}((path,), nothing) + elseif batchdims !== nothing + TracedRArray{T,length(batchdims)}( + (path,), nothing, map(Base.Fix1(size, prev), batchdims) + ) elseif tobatch !== nothing - error("This should not happen...") + TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) else TracedRArray{T,N}((path,), prev.mlir_data, size(prev)) end @@ -379,6 +380,7 @@ function make_tracer( mode; tobatch=nothing, toscalar=false, + batchdims=nothing, kwargs..., ) where {T} if mode == ConcreteToTraced @@ -397,6 +399,8 @@ function make_tracer( end res = if toscalar TracedRNumber{T}((path,), nothing) + elseif batchdims !== nothing + error("This should not happen...") elseif tobatch !== nothing TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) else @@ -475,21 +479,11 @@ end make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev function make_tracer( - seen, - @nospecialize(prev::Complex{RT}), - @nospecialize(path), - mode; - toscalar=false, - tobatch=nothing, - kwargs..., + seen, @nospecialize(prev::Complex{RT}), @nospecialize(path), mode; kwargs... ) where {RT} return Complex( - make_tracer( - seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs... - ), - make_tracer( - seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs... - ), + make_tracer(seen, prev.re, append_path(path, :re), mode; kwargs...), + make_tracer(seen, prev.im, append_path(path, :im), mode; kwargs...), ) end diff --git a/src/utils.jl b/src/utils.jl index 46d94e75b..8b3663fe1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -33,6 +33,7 @@ function apply(f, args...; kwargs...) return f(args...; kwargs...) end +# TODO: Generalize batchdims to be propagated to args differently? function make_mlir_fn( f, args, @@ -40,6 +41,7 @@ function make_mlir_fn( name="main", concretein=true; toscalar=false, + batchdims=nothing, return_dialect=:func, no_args_in_result::Bool=false, construct_function_without_args::Bool=false, @@ -54,6 +56,7 @@ function make_mlir_fn( name, concretein; toscalar, + batchdims, return_dialect, no_args_in_result, construct_function_without_args, @@ -70,6 +73,7 @@ function make_mlir_fn( (:args, i), concretein ? ConcreteToTraced : TracedSetPath; toscalar, + batchdims, track_numbers=construct_function_without_args ? (Number,) : (), ) end