Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: automatic batching of code [currently very wip] #233

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
end

function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
error(1)
@warn(
"""Performing scalar indexing on task $(current_task()).
Invocation resulted in scalar indexing of a TracedRArray.
Expand Down Expand Up @@ -371,6 +372,135 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
return traced2_result
end

# TODO: once we have a generic implementation of `vmap`/`batch` we can simply call that
# from here
function Base.mapslices(f, A::TracedRArray; dims)
isempty(dims) && return map(f, A)

for d in dims
d isa Integer ||
throw(ArgumentError("mapslices: dimension must be an integer, got $d"))
d >= 1 || throw(ArgumentError("mapslices: dimension must be ≥ 1, got $d"))
# Indexing a matrix M[:,1,:] produces a 1-column matrix, but dims=(1,3) here
# would otherwise ignore 3, and slice M[:,i]. Previously this gave error:
# BoundsError: attempt to access 2-element Vector{Any} at index [3]
d > ndims(A) && throw(
ArgumentError(
"mapslices does not accept dimensions > ndims(A) = $(ndims(A)), got $d"
),
)
end

args = (A,)
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
f,
args,
(),
string(f) * "_mapslice",
false;
batchdims=filter(d -> d ∉ dims, 1:ndims(A)),
)

invmap = IdDict()
for (k, v) in seen_args
invmap[v] = k
end

keys_seen = [k for k in keys(seen_args) if k isa TracedType]
input_shapes = size.(keys_seen)

function shape_fn(x)
counter = 0
return ntuple(ndims(A)) do d
d in dims && return size(A, d)
counter += 1
return size(x, counter)
end
end

out_shapes = map(shape_fn, linear_results)
@assert allequal(out_shapes) "out_shapes are $(out_shapes)"
out_shape = first(out_shapes)

in_tys2 = [mlir_type(invmap[arg]) for arg in linear_args]

out_tys2 = [
MLIR.IR.TensorType(out_shapes[i], MLIR.IR.Type(eltype(arg))) for
(i, arg) in enumerate(linear_results)
]

fname = get_attribute_by_name(func2, "sym_name")
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))

batch_inputs = MLIR.IR.Value[]

for a in linear_args
idx, path = get_argidx(a)
if idx == 1 && fnwrap
push_val!(batch_inputs, f, path[3:end])
else
if fnwrap
idx -= 1
end
push_val!(batch_inputs, args[idx], path[3:end])
end
end

res = MLIR.Dialects.enzyme.batch(
batch_inputs;
outputs=out_tys2,
fn=fname,
batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in out_shape]),
)

residx = 1

for a in linear_results
if has_residx(a)
path = get_residx(a)
set!(result, path[2:end], MLIR.IR.result(res, residx))
residx += 1
else
idx, path = get_argidx(a)
if idx == 1 && fnwrap
set!(f, path[3:end], MLIR.IR.result(res, residx))
residx += 1
else
if fnwrap
idx -= 1
end
set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
residx += 1
end
end
end

seen_results = OrderedIdDict()
traced2_result = make_tracer(seen_results, result, (), TracedSetPath; tobatch=out_shape)

func2.operation = MLIR.API.MlirOperation(C_NULL)

return traced2_result
end

function Base._eachslice(
A::AnyTracedRArray{T,N}, dims::NTuple{M,Integer}, drop::Bool
) where {T,N,M}
Base._slice_check_dims(N, dims...)
all_slices = [A]
for dim in dims
partial_slices = []
for i in axes(A, dim)
for slice in all_slices
push!(partial_slices, selectdim(slice, dim, i:i))
end
end
all_slices = partial_slices
end
drop && (all_slices = map(x -> dropdims(x; dims), all_slices))
return all_slices
end

for (jlop, hloop, hlocomp, merge) in
((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any))
@eval function $jlop(
Expand Down
40 changes: 17 additions & 23 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,7 @@ end
append_path(path, i) = (path..., i)

function make_tracer(
seen,
@nospecialize(prev::RT),
@nospecialize(path),
mode;
toscalar=false,
tobatch=nothing,
kwargs...,
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
) where {RT}
if haskey(seen, prev)
return seen[prev]
Expand All @@ -251,7 +245,7 @@ function make_tracer(
for i in 1:nf
if isdefined(prev, i)
xi = Base.getfield(prev, i)
xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch)
xi2 = make_tracer(seen, xi, append_path(path, i), mode; kwargs...)
if xi !== xi2
changed = true
end
Expand All @@ -274,7 +268,7 @@ function make_tracer(
for i in 1:nf
if isdefined(prev, i)
xi = Base.getfield(prev, i)
xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch)
xi2 = make_tracer(seen, xi, append_path(path, i), mode; kwargs...)
if xi !== xi2
changed = true
end
Expand Down Expand Up @@ -333,6 +327,7 @@ function make_tracer(
mode;
toscalar=false,
tobatch=nothing,
batchdims=nothing,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why add a batchdims to linearization? shouldn't be needed and can make it harder to linearize

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

temporary means for prototyping. It is just a generalization of toscalar, so I want to fuse those options

kwargs...,
) where {T,N}
if mode == ConcreteToTraced
Expand All @@ -349,10 +344,16 @@ function make_tracer(
if haskey(seen, prev)
return seen[prev]
end
# TODO: Check `toscalar` is not set together with `batchdims`?
# TODO: Unify toscalar and batchdims?
res = if toscalar
TracedRNumber{T}((path,), nothing)
elseif batchdims !== nothing
TracedRArray{T,length(batchdims)}(
(path,), nothing, map(Base.Fix1(size, prev), batchdims)
)
elseif tobatch !== nothing
error("This should not happen...")
TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
else
TracedRArray{T,N}((path,), prev.mlir_data, size(prev))
end
Expand All @@ -379,6 +380,7 @@ function make_tracer(
mode;
tobatch=nothing,
toscalar=false,
batchdims=nothing,
kwargs...,
) where {T}
if mode == ConcreteToTraced
Expand All @@ -397,6 +399,8 @@ function make_tracer(
end
res = if toscalar
TracedRNumber{T}((path,), nothing)
elseif batchdims !== nothing
error("This should not happen...")
elseif tobatch !== nothing
TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
else
Expand Down Expand Up @@ -475,21 +479,11 @@ end
make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev

function make_tracer(
seen,
@nospecialize(prev::Complex{RT}),
@nospecialize(path),
mode;
toscalar=false,
tobatch=nothing,
kwargs...,
seen, @nospecialize(prev::Complex{RT}), @nospecialize(path), mode; kwargs...
) where {RT}
return Complex(
make_tracer(
seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs...
),
make_tracer(
seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs...
),
make_tracer(seen, prev.re, append_path(path, :re), mode; kwargs...),
make_tracer(seen, prev.im, append_path(path, :im), mode; kwargs...),
)
end

Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ function apply(f, args...; kwargs...)
return f(args...; kwargs...)
end

# TODO: Generalize batchdims to be propagated to args differently?
function make_mlir_fn(
f,
args,
kwargs,
name="main",
concretein=true;
toscalar=false,
batchdims=nothing,
return_dialect=:func,
no_args_in_result::Bool=false,
construct_function_without_args::Bool=false,
Expand All @@ -54,6 +56,7 @@ function make_mlir_fn(
name,
concretein;
toscalar,
batchdims,
return_dialect,
no_args_in_result,
construct_function_without_args,
Expand All @@ -70,6 +73,7 @@ function make_mlir_fn(
(:args, i),
concretein ? ConcreteToTraced : TracedSetPath;
toscalar,
batchdims,
track_numbers=construct_function_without_args ? (Number,) : (),
)
end
Expand Down
Loading