From 4c24ec21a64dacf853d52077371310425aadd14f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Jan 2025 17:45:11 -0500 Subject: [PATCH] docs: setup batching tutorial --- docs/make.jl | 10 +++++++++- docs/src/.vitepress/config.mts | 15 ++++++++++++++- docs/src/tutorials/batching.md | 3 +++ docs/src/tutorials/index.md | 2 ++ src/Ops.jl | 22 ++++++++++++++++++++-- src/Reactant.jl | 2 -- test/batching.jl | 2 ++ test/runtests.jl | 1 + 8 files changed, 51 insertions(+), 6 deletions(-) create mode 100644 docs/src/tutorials/batching.md create mode 100644 test/batching.jl diff --git a/docs/make.jl b/docs/make.jl index 603701f5b..f14c180bd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,10 @@ examples = [ pages = [ "Reactant.jl" => "index.md", "Introduction" => ["Getting Started" => "introduction/index.md"], - "Tutorials" => ["Overview" => "tutorials/index.md"], + "Tutorials" => [ + "Overview" => "tutorials/index.md", + "Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md", + ], "API Reference" => [ "Reactant API" => "api/api.md", "Ops" => "api/ops.md", @@ -37,6 +40,11 @@ pages = [ "Func" => "api/func.md", "StableHLO" => "api/stablehlo.md", "VHLO" => "api/vhlo.md", + "GPU" => "api/gpu.md", + "LLVM" => "api/llvm.md", + "NVVM" => "api/nvvm.md", + "TPU" => "api/tpu.md", + "Triton" => "api/triton.md", ], "MLIR API" => "api/mlirc.md", "XLA" => "api/xla.md", diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 5f881aa96..d731f9fc0 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -53,7 +53,16 @@ export default defineConfig({ { text: "Home", link: "/" }, { text: "Getting Started", link: "/introduction" }, { text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" }, - { text: "Tutorials", link: "/tutorials/" }, + { + text: "Tutorials", + items: [ + { text: "Overview", link: "/tutorials/" }, + { + text: "Batching Functions with `Reactant.Ops.batch`", + link: "/tutorials/batching" + }, + ], + }, { text: "API", items: [ @@ -105,6 +114,10 @@ export default defineConfig({ collapsed: false, items: [ { text: "Overview", link: "/tutorials/" }, + { + text: "Batching Functions with `Reactant.Ops.batch`", + link: "/tutorials/batching", + }, ], }, "/api/": { diff --git a/docs/src/tutorials/batching.md b/docs/src/tutorials/batching.md new file mode 100644 index 000000000..d3ca77884 --- /dev/null +++ b/docs/src/tutorials/batching.md @@ -0,0 +1,3 @@ +# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial) + + diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index eb2beb1f1..f9d01b86c 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -1,3 +1,5 @@ # Tutorials We are currently working on adding tutorials to Reactant!! Please check back soon! + +- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) diff --git a/src/Ops.jl b/src/Ops.jl index 50d3b2cbf..e001ce9fc 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2013,8 +2013,24 @@ end # This function assumes that the last dimension of each element is the batch dimension by # default. This is the standard Julia ordering for batching. We permutedims the ordering to # make sure the first dimension is the batch dimension when calling `batch_internal` below. -# XXX: Mutation inside a batched function is not supported yet (need to set the results -# correctly) +""" + batch(f, args...; batch_dims=nothing, result_dims=nothing) + +Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users +familiar with `jax`, this operation corresponds to `jax.vmap`.) + +If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension. + +To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`. + +## Examples + +For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial. + +!!! danger + + Mutation inside a batched function is not supported yet and will lead to unexpected results. +""" @noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing) batch_sizes = Int64[] batching_dims = if batch_dims === nothing @@ -2060,6 +2076,8 @@ end end return fmap(results, result_dims) do result, dim + @assert dim !== nothing "Result batch dimension cannot be `nothing`" + order = collect(Int64, 1:ndims(result)) order[dim] = 1 order[1] = dim diff --git a/src/Reactant.jl b/src/Reactant.jl index 2418381a1..38f65f6b1 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -9,8 +9,6 @@ using Functors: @leaf using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` -using Functors: @leaf - export @allowscalar # re-exported from GPUArraysCore # auxiliary types and functions diff --git a/test/batching.jl b/test/batching.jl new file mode 100644 index 000000000..cd6ae6bbf --- /dev/null +++ b/test/batching.jl @@ -0,0 +1,2 @@ +using Reactant, Test + diff --git a/test/runtests.jl b/test/runtests.jl index 7d188fe3d..290e885fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") @safetestset "Sorting" include("sorting.jl") + @safetestset "Batching" include("batching.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"