Skip to content

Commit

Permalink
Rename and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored and wsmoses committed Jul 9, 2024
1 parent 34d8a1b commit a35fb46
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
16 changes: 7 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
"""
pick_chunksize(totalsize, mode, ftype, return_activity, argtypes...)
pick_batchsize(totalsize, mode, ftype, return_activity, argtypes...)
Return a reasonable chunk size for batched differentiation.
Return a reasonable batch size for batched differentiation.
!!! warning
This function is experimental, and not part of the public API.
"""
function pick_chunksize(
totalsize::Integer,
mode::Mode,
ftype::Type,
return_activity,::Type{<:Annotation},
argtypes::Vararg{Type{<:Annotation}, Nargs}
) where {Nargs}
function pick_batchsize(totalsize::Integer,
mode::Mode,
ftype::Type,
return_activity, ::Type{<:Annotation},
argtypes::Vararg{Type{<:Annotation},Nargs}) where {Nargs}
return min(totalsize, 16)
end

Expand Down
6 changes: 3 additions & 3 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
using Enzyme: pick_chunksize
using Enzyme: Active, Duplicated, pick_batchsize
using EnzymeCore: Forward
using Test

mode = Forward
ftype = typeof(sum)
argtypes = typeof.((Duplicated(ones(1), zeros(1)),))
@test pick_chunksize(1, mode, ftype, Active, argtypes...) == 1
@test pick_batchsize(1, mode, ftype, Active, argtypes...) == 1

argtypes = typeof.((Duplicated(ones(100), zeros(100)),))
@test pick_chunksize(100, mode, ftype, Active, argtypes...) == 16
@test pick_batchsize(100, mode, ftype, Active, argtypes...) == 16

0 comments on commit a35fb46

Please sign in to comment.