Skip to content

Commit

Permalink
docs: setup batching tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 19, 2025
1 parent 83aceac commit 4c24ec2
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 6 deletions.
10 changes: 9 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ examples = [
pages = [
"Reactant.jl" => "index.md",
"Introduction" => ["Getting Started" => "introduction/index.md"],
"Tutorials" => ["Overview" => "tutorials/index.md"],
"Tutorials" => [
"Overview" => "tutorials/index.md",
"Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md",
],
"API Reference" => [
"Reactant API" => "api/api.md",
"Ops" => "api/ops.md",
Expand All @@ -37,6 +40,11 @@ pages = [
"Func" => "api/func.md",
"StableHLO" => "api/stablehlo.md",
"VHLO" => "api/vhlo.md",
"GPU" => "api/gpu.md",
"LLVM" => "api/llvm.md",
"NVVM" => "api/nvvm.md",
"TPU" => "api/tpu.md",
"Triton" => "api/triton.md",
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
Expand Down
15 changes: 14 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,16 @@ export default defineConfig({
{ text: "Home", link: "/" },
{ text: "Getting Started", link: "/introduction" },
{ text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" },
{ text: "Tutorials", link: "/tutorials/" },
{
text: "Tutorials",
items: [
{ text: "Overview", link: "/tutorials/" },
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching"
},
],
},
{
text: "API",
items: [
Expand Down Expand Up @@ -105,6 +114,10 @@ export default defineConfig({
collapsed: false,
items: [
{ text: "Overview", link: "/tutorials/" },
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching",
},
],
},
"/api/": {
Expand Down
3 changes: 3 additions & 0 deletions docs/src/tutorials/batching.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial)


2 changes: 2 additions & 0 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Tutorials

We are currently working on adding tutorials to Reactant!! Please check back soon!

- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial)
22 changes: 20 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2013,8 +2013,24 @@ end
# This function assumes that the last dimension of each element is the batch dimension by
# default. This is the standard Julia ordering for batching. We permutedims the ordering to
# make sure the first dimension is the batch dimension when calling `batch_internal` below.
# XXX: Mutation inside a batched function is not supported yet (need to set the results
# correctly)
"""
batch(f, args...; batch_dims=nothing, result_dims=nothing)
Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
familiar with `jax`, this operation corresponds to `jax.vmap`.)
If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.
To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.
## Examples
For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.
!!! danger
Mutation inside a batched function is not supported yet and will lead to unexpected results.
"""
@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
batch_sizes = Int64[]
batching_dims = if batch_dims === nothing
Expand Down Expand Up @@ -2060,6 +2076,8 @@ end
end

return fmap(results, result_dims) do result, dim
@assert dim !== nothing "Result batch dimension cannot be `nothing`"

order = collect(Int64, 1:ndims(result))
order[dim] = 1
order[1] = dim
Expand Down
2 changes: 0 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ using Functors: @leaf
using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

using Functors: @leaf

export @allowscalar # re-exported from GPUArraysCore

# auxiliary types and functions
Expand Down
2 changes: 2 additions & 0 deletions test/batching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
using Reactant, Test

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
@safetestset "Sorting" include("sorting.jl")
@safetestset "Batching" include("batching.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
Expand Down

0 comments on commit 4c24ec2

Please sign in to comment.