diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2e7789d660..351b5321be 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -123,6 +123,7 @@ include("gradientutils.jl") include("utils.jl") include("compiler.jl") include("internal_rules.jl") +include("sugar.jl") import .Compiler: CompilationException diff --git a/src/sugar.jl b/src/sugar.jl new file mode 100644 index 0000000000..4561800040 --- /dev/null +++ b/src/sugar.jl @@ -0,0 +1,15 @@ +""" + pick_batchsize(totalsize, mode, ftype, return_activity, argtypes...) + +Return a reasonable batch size for batched differentiation. + +!!! warning + This function is experimental, and not part of the public API. +""" +function pick_batchsize(totalsize::Integer, + mode::Mode, + ftype::Type, + return_activity, ::Type{<:Annotation}, + argtypes::Vararg{Type{<:Annotation},Nargs}) where {Nargs} + return min(totalsize, 16) +end diff --git a/test/runtests.jl b/test/runtests.jl index 902b9e4f65..0325e9ebd2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -84,6 +84,7 @@ end include("abi.jl") include("typetree.jl") +include("sugar.jl") include("rules.jl") include("rrules.jl") diff --git a/test/sugar.jl b/test/sugar.jl new file mode 100644 index 0000000000..0baab151f2 --- /dev/null +++ b/test/sugar.jl @@ -0,0 +1,10 @@ +using Enzyme: Active, Duplicated, Forward, pick_batchsize +using Test + +mode = Forward +ftype = typeof(sum) +argtypes = typeof.((Duplicated(ones(1), zeros(1)),)) +@test pick_batchsize(1, mode, ftype, Active, argtypes...) == 1 + +argtypes = typeof.((Duplicated(ones(100), zeros(100)),)) +@test pick_batchsize(100, mode, ftype, Active, argtypes...) == 16