feat: tracing Random.jl functionality correctly #363

Merged on Dec 18, 2024 (22 commits)
Changes from 18 commits
6 changes: 6 additions & 0 deletions .github/workflows/CI.yml
@@ -50,6 +50,12 @@ jobs:
version: '1.10'
assertions: true
test_group: neural_networks
- os: ubuntu-20.04
arch: x64
libReactant: packaged
version: '1.10'
assertions: true
test_group: integration
- os: ubuntu-20.04
arch: x86
libReactant: packaged
9 changes: 7 additions & 2 deletions Project.toml
@@ -14,6 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
@@ -23,17 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

- [sources.ReactantCore]
- path = "lib/ReactantCore"
+ [sources]
+ ReactantCore = {path = "lib/ReactantCore"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantRandom123Ext = "Random123"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"

@@ -50,6 +53,8 @@ LinearAlgebra = "1.10"
NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.26"
Scratch = "1.2"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -43,6 +43,7 @@ pages = [
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
"Internal API" => "api/internal.md",
],
]

4 changes: 3 additions & 1 deletion docs/src/.vitepress/config.mts
@@ -78,7 +78,8 @@ export default defineConfig({
{ text: "MLIR API", link: "/api/mlirc" },
{ text: "XLA", link: "/api/xla" },
],
- }
+ },
{ text: "Internal API", link: "/api/internal" },
],
},
{
@@ -132,6 +133,7 @@ export default defineConfig({
{ text: "XLA", link: "/api/xla" },
],
},
{ text: "Internal API", link: "/api/internal" },
],
},
},
12 changes: 12 additions & 0 deletions docs/src/api/internal.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# Internal API

These functions are not part of the public API and are subject to change at any time.

```@docs
Reactant.REDUB_ARGUMENTS_NAME
Reactant.within_reactant_interpreter
```
11 changes: 11 additions & 0 deletions ext/ReactantRandom123Ext.jl
@@ -0,0 +1,11 @@
module ReactantRandom123Ext

using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
using Reactant: TracedRandom

TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"
TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"

end
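
The extension above maps each Random123 generator type to a StableHLO algorithm name through multiple dispatch. A minimal standalone sketch of that pattern follows. The `Fake*` types and the `"DEFAULT"` fallback method are hypothetical stand-ins that do not appear in the diff; they are used only so the snippet runs without Random123 or Reactant.

```julia
# Hypothetical stand-ins for Random123's generator types (not the real types).
abstract type FakeRNG end
struct FakeThreefry <: FakeRNG end
struct FakePhilox <: FakeRNG end
struct FakeOther <: FakeRNG end

# One method per generator family, mirroring the extension above.
rng_algorithm(::FakeThreefry) = "THREE_FRY"
rng_algorithm(::FakePhilox) = "PHILOX"
# Assumed fallback for generators without a dedicated algorithm.
rng_algorithm(::FakeRNG) = "DEFAULT"

# Dispatch picks the most specific method for each argument.
algorithms = map(rng_algorithm, (FakeThreefry(), FakePhilox(), FakeOther()))
```

Because the mapping lives in methods rather than in a lookup table, a package extension can add entries for its own RNG types without touching the core package.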
141 changes: 136 additions & 5 deletions src/Ops.jl
@@ -1016,19 +1016,150 @@ end
end

# random ops
"""
rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type `T` with the given shape and seed from a uniform random
distribution between 0 and 1. Returns a NamedTuple with the following fields:

- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.

# Arguments

- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
- )
- output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape)
+ ) where {T<:Integer}
@assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
if algorithm == "PHILOX"
@assert length(seed) ∈ (2, 3)
elseif algorithm == "THREE_FRY"
@assert length(seed) == 2
end

output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64))
rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm)
- op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location)
+ op = stablehlo.rng_bit_generator(
+ seed.mlir_data; output, output_state, rng_algorithm, location
+ )
return (;
- output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)),
- output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape),
+ output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)),
+ output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)),
)
end
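
The `@assert` lines in the integer method constrain the algorithm name and the seed length. A standalone sketch of the same checks in plain Julia is shown here. The helper name `check_rng_seed` and the use of `ArgumentError` instead of `@assert` are assumptions made for illustration and are not taken from the diff.

```julia
# Hypothetical helper: checks the same conditions as the @assert lines above,
# but on an ordinary Vector{UInt64} instead of a traced array.
function check_rng_seed(algorithm::String, seed::AbstractVector{UInt64})
    algorithm in ("DEFAULT", "PHILOX", "THREE_FRY") ||
        throw(ArgumentError("unknown RNG algorithm: $algorithm"))
    if algorithm == "PHILOX" && !(length(seed) in (2, 3))
        throw(ArgumentError("PHILOX expects a seed of length 2 or 3"))
    elseif algorithm == "THREE_FRY" && length(seed) != 2
        throw(ArgumentError("THREE_FRY expects a seed of length 2"))
    end
    return seed
end

# A three-element seed is accepted for PHILOX.
philox_seed = check_rng_seed("PHILOX", UInt64[1, 2, 3])
```

Throwing a typed error rather than asserting is one possible design choice; it keeps the check active even when assertions are disabled.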

@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
) where {T<:AbstractFloat}
nbits = sizeof(T) * 8
uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
output = divide(
convert(TracedRArray{T,ndims(output)}, output),
constant(fill(T(typemax(uT)), Tuple(shape)); location),
)
return (; output_state, output)
end
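
The floating-point method above generates unsigned integers of the same bit width as `T` and divides them by `typemax` of that integer type. The sketch below reproduces that idea with Julia's own `Random` stdlib instead of the traced ops. The names `uint_for` and `uniform_from_bits` are hypothetical, and only `Float32` is exercised here.

```julia
using Random

# Pick the unsigned integer type with the same bit width as the float type,
# following the 16/32/64-bit selection in the diff above.
uint_for(::Type{T}) where {T<:AbstractFloat} =
    sizeof(T) == 2 ? UInt16 : (sizeof(T) == 4 ? UInt32 : UInt64)

# Scale raw random bits into [0, 1] by dividing by the largest integer value.
function uniform_from_bits(::Type{T}, rng::AbstractRNG, dims::Dims) where {T<:AbstractFloat}
    U = uint_for(T)
    bits = rand(rng, U, dims)
    return T.(bits) ./ T(typemax(U))
end

u = uniform_from_bits(Float32, Xoshiro(0), (2, 3))
```

Note that with this scaling the upper endpoint 1 is attainable, unlike Julia's built-in `rand(Float32)`, which samples from the half-open interval [0, 1).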

"""
randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type `T` with the given shape and seed from a standard normal
distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
fields:

- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.

# Arguments

- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
scaled_uniform = subtract(
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
constant(fill(T(1), size(rand_uniform))),
)
probit = erf_inv(scaled_uniform)
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
return (; output_state=seed, output=rand_normal)
end
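
The `randn` body above is an inverse-CDF (probit) transform of a uniform sample. As a short derivation, with $u$ denoting the uniform output of `rng_bit_generator` and $\Phi$ the standard normal CDF (symbols introduced here for illustration only):

```latex
\Phi(z) = \tfrac{1}{2}\left[1 + \operatorname{erf}\!\left(\frac{z}{\sqrt{2}}\right)\right],
\qquad
\Phi(z) = u
\;\Longrightarrow\;
z = \sqrt{2}\,\operatorname{erf}^{-1}(2u - 1).
```

This matches the code term by term: `scaled_uniform` is $2u - 1$, `probit` is $\operatorname{erf}^{-1}(2u - 1)$, and the final multiplication supplies the factor $\sqrt{2}$.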

"""
randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type `T` with the given shape and seed from an exponential
distribution with rate 1. Returns a NamedTuple with the following fields:

- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.

# Arguments

- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
rand_exp = negate(log_plus_one(negate(rand_uniform)))
return (; output_state=seed, output=rand_exp)
end
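
The `randexp` body above applies inverse-transform sampling: `negate(log_plus_one(negate(u)))` computes `-log1p(-u)`, which follows an exponential distribution with rate 1 when `u` is uniform. A plain-Julia sketch of the same formula, using the `Random` stdlib rather than traced ops (the function name `exp_from_uniform` is hypothetical):

```julia
using Random

# Inverse-transform sampling: for u uniform on [0, 1), -log1p(-u) is Exp(1).
exp_from_uniform(u::AbstractArray{<:AbstractFloat}) = -log1p.(-u)

u = rand(Xoshiro(1), Float64, 100_000)
x = exp_from_uniform(u)

# The sample mean of an Exp(1) variable should be close to 1.
sample_mean = sum(x) / length(x)
```

Using `log1p` instead of `log(1 - u)` keeps precision for small `u`. An input of exactly 1 maps to `Inf`, which is worth keeping in mind when the uniform source can return the upper endpoint.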

# functional ops
68 changes: 67 additions & 1 deletion src/Overlay.jl
@@ -3,14 +3,23 @@
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Helper Function to determine if we are inside the ReactantInterpreter
"""
within_reactant_interpreter()

Returns `true` if we are currently inside the ReactantInterpreter.
"""
@noinline within_reactant_interpreter() = false
@reactant_overlay @noinline within_reactant_interpreter() = true

# Compiling within a compile should return simply the original function
@reactant_overlay function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

- # Enzyme overrides
+ # Enzyme.jl overlays
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
@@ -22,3 +31,60 @@ end
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

# Random.jl overlays
@reactant_overlay @noinline function Random.default_rng()
return call_with_reactant(TracedRandom.default_rng)
end

## Only problematic edge case here is the direct `<randfun!>(rng, A::AbstractArray)` call
## We can't directly overlay that call without breaking the semantics of inplace update
for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
overload_randfun = Symbol(:overload_, randfun)
overload_randfun! = Symbol(:overload_, randfun!)

@eval begin
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dims::Dims
) where {T <: ReactantPrimitive}
return TracedRandom.$(overload_randfun)(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, dim1::Integer, dims::Integer...
)
return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T <: ReactantPrimitive}
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
) where {T <: ReactantPrimitive}
return TracedRandom.$(overload_randfun)(rng, T)
end

# inplace
@reactant_overlay @noinline function Random.$(randfun!)(
rng::AbstractRNG, A::AnyTracedRArray
)
return TracedRandom.$(overload_randfun!)(rng, A)
end

# warn about direct writing to arrays
@reactant_overlay @noinline function Random.$(randfun!)(
rng::AbstractRNG, A::AbstractArray
)
@warn "Directly writing to an array using Random.jl functions inside \
ReactantInterpreter will generate a constant array in the IR. Use with \
caution." maxlog = 1
return Random.$(randfun!)(rng, A)
Review comment (Collaborator, Author):

@wsmoses my understanding was that this should call the non-overlaid version, but here I get:

┌ Warning: Directly writing to an array using Random.jl functions inside ReactantInterpreter will generate a constant array in the IR. Use with caution.
└ @ Reactant /mnt/software/lux/Reactant.jl/src/Overlay.jl:84
Unreachable reached at 0x7a64877b6b16

[1417783] signal 4 (2): Illegal instruction
in expression starting at REPL[7]:1
fn2 at ./REPL[6]:4 [inlined]
opaque closure at ./<missing>:0
unknown function (ip: 0x7a64877b6bff)
fn2 at ./REPL[6]:2 [inlined]
call_with_reactant at /mnt/software/lux/Reactant.jl/src/utils.jl:0
#8 at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:210
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
unknown function (ip: 0x7a64877b6316)
#make_mlir_fn#1 at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:197
make_mlir_fn at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:117 [inlined]
#10 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:295 [inlined]
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
#9 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:294 [inlined]
mmodule! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:92
unknown function (ip: 0x7a64877b5976)
#compile_mlir!#8 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:291
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:290 [inlined]
#6 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:285 [inlined]
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
unknown function (ip: 0x7a64877b2a76)
#compile_mlir#5 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:283
compile_mlir at /mnt/software/lux/Reactant.jl/src/Compiler.jl:280
unknown function (ip: 0x7a64877b0a66)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:625
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:539
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10705 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_FGbh7.so (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_15174 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_FGbh7.so (unknown line)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73406.1 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
true_main at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/jlapi.c:1059
main at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/cli/loader_exe.c:58
unknown function (ip: 0x7a65341e2e07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 48616895 (Pool: 48615481; Big: 1414); GC: 47
[1]    1417783 illegal hardware instruction (core dumped)  julia --project=envs --threads=4 --check-bounds=yes

Review comment by @avik-pal (Collaborator, Author), Dec 17, 2024:

Using the trick from #369 (comment) seems to work correctly, though it has other weird edge cases.

Review comment (Collaborator, Author):

This is definitely a general issue with any kind of recursion of an overlayed function. Do we have a way to force using the NativeInterpreter?

Review comment (Member):

ccing @gbaraldi and @aviatesk

end
end
end
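
The comment at the top of this loop notes that the generic `<randfun!>(rng, A::AbstractArray)` method cannot simply be redirected, because callers rely on in-place mutation. A small standalone illustration of that contract in plain Julia follows. No Reactant is involved, and the variable names are chosen only for the example.

```julia
using Random

rng = Xoshiro(42)
A = zeros(Float64, 4)

# `rand!` overwrites the caller's array and returns that same object.
B = rand!(rng, A)

# `rand` allocates a fresh array and leaves existing arrays untouched.
C = rand(rng, Float64, 4)
```

Any replacement for the in-place method has to preserve this behavior: the array passed in must hold the new values afterwards, and the returned value must be that same array.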
11 changes: 10 additions & 1 deletion src/Reactant.jl
@@ -3,6 +3,8 @@ module Reactant
using ReactantCore: ReactantCore, @trace, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Random: Random, AbstractRNG

using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

@@ -122,7 +124,14 @@ include("TracedRArray.jl")

include("ConcreteRArray.jl")

- include("linear_algebra.jl")
+ mutable struct TracedRNG <: Random.AbstractRNG
+ seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
+ const algorithm::String
+ end
+
+ # StdLib Overloads
+ include("stdlibs/LinearAlgebra.jl")
+ include("stdlibs/Random.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

File renamed without changes.