Skip to content

Commit

Permalink
docs: setup batching tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 20, 2025
1 parent 6e684c3 commit 02bdfb1
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 9 deletions.
12 changes: 10 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ examples = [
pages = [
"Reactant.jl" => "index.md",
"Introduction" => ["Getting Started" => "introduction/index.md"],
"Tutorials" =>
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
"Tutorials" => [
"Overview" => "tutorials/index.md",
"Profiling" => "tutorials/profiling.md",
"Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md",
],
"API Reference" => [
"Reactant API" => "api/api.md",
"Ops" => "api/ops.md",
Expand All @@ -38,6 +41,11 @@ pages = [
"Func" => "api/func.md",
"StableHLO" => "api/stablehlo.md",
"VHLO" => "api/vhlo.md",
"GPU" => "api/gpu.md",
"LLVM" => "api/llvm.md",
"NVVM" => "api/nvvm.md",
"TPU" => "api/tpu.md",
"Triton" => "api/triton.md",
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
Expand Down
10 changes: 9 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ export default defineConfig({
{
text: "Tutorials",
items: [
{text: "Overview", link: "/tutorials/"},
{ text: "Overview", link: "/tutorials/" },
{text: "Profiling", link: "/tutorials/profiling"},
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching"
},
],
},
{
Expand Down Expand Up @@ -112,6 +116,10 @@ export default defineConfig({
items: [
{ text: "Overview", link: "/tutorials/" },
{ text: "Profiling", link: "/tutorials/profiling" },
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching",
},
],
},
"/api/": {
Expand Down
3 changes: 3 additions & 0 deletions docs/src/tutorials/batching.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial)


1 change: 1 addition & 0 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Tutorials

- [Profiling](@ref profiling).
- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial)

We are currently working on adding more tutorials to Reactant!! Please check back soon!
4 changes: 2 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ function codegen_unflatten!(
paths = (
(
p for p in Reactant.TracedUtils.get_paths(result) if
length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs)
length(p) > 0 && (p[1] == :result || p[1] == :resargs)
)...,
)
for path in paths
Expand Down Expand Up @@ -846,7 +846,7 @@ function codegen_unflatten!(
paths = (
(
p for p in Reactant.TracedUtils.get_paths(result) if
length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
length(p) > 0 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
)...,
)

Expand Down
22 changes: 20 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2013,8 +2013,24 @@ end
# This function assumes that the last dimension of each element is the batch dimension by
# default. This is the standard Julia ordering for batching. We permutedims the ordering to
# make sure the first dimension is the batch dimension when calling `batch_internal` below.
# XXX: Mutation inside a batched function is not supported yet (need to set the results
# correctly)
"""
batch(f, args...; batch_dims=nothing, result_dims=nothing)
Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
familiar with `jax`, this operation corresponds to `jax.vmap`.)
If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.
To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.
## Examples
For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.
!!! danger
Mutation inside a batched function is not supported yet and will lead to unexpected results.
"""
@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
batch_sizes = Int64[]
batching_dims = if batch_dims === nothing
Expand Down Expand Up @@ -2060,6 +2076,8 @@ end
end

return fmap(results, result_dims) do result, dim
@assert dim !== nothing "Result batch dimension cannot be `nothing`"

order = collect(Int64, 1:ndims(result))
order[dim] = 1
order[1] = dim
Expand Down
2 changes: 0 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ using Functors: @leaf
using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

using Functors: @leaf

export @allowscalar # re-exported from GPUArraysCore

# auxiliary types and functions
Expand Down
2 changes: 2 additions & 0 deletions test/batching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
using Reactant, Test

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
@safetestset "Sorting" include("sorting.jl")
@safetestset "Batching" include("batching.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
Expand Down

0 comments on commit 02bdfb1

Please sign in to comment.