From 5b40fdf6b5be1408a862dbdb3ca07fa957df79ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 09:27:03 +0530 Subject: [PATCH 01/22] refactor: move stdlib overloads to a different directory --- Project.toml | 6 ++++-- src/Reactant.jl | 6 +++++- src/{linear_algebra.jl => stdlibs/LinearAlgebra.jl} | 0 src/stdlibs/Random.jl | 0 4 files changed, 9 insertions(+), 3 deletions(-) rename src/{linear_algebra.jl => stdlibs/LinearAlgebra.jl} (100%) create mode 100644 src/stdlibs/Random.jl diff --git a/Project.toml b/Project.toml index 9a1277d9d..5e3dceff0 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" Scratch = "6c6a2e73-6563-6170-7368-637461726353" @@ -26,8 +27,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources.ReactantCore] -path = "lib/ReactantCore" +[sources] +ReactantCore = {path = "lib/ReactantCore"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" @@ -50,6 +51,7 @@ LinearAlgebra = "1.10" NNlib = "0.9.26" OrderedCollections = "1" Preferences = "1.4" +Random = "1.10" ReactantCore = "0.1.3" Reactant_jll = "0.0.26" Scratch = "1.2" diff --git a/src/Reactant.jl b/src/Reactant.jl index e7c8805de..29c17a056 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -3,6 +3,8 @@ module Reactant using ReactantCore: ReactantCore, @trace, MissingTracedValue using LinearAlgebra: LinearAlgebra +using Random: Random + using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` @@ -122,7 +124,9 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") -include("linear_algebra.jl") +# StdLib Overloads +include("stdlibs/LinearAlgebra.jl") +include("stdlibs/Random.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/src/linear_algebra.jl b/src/stdlibs/LinearAlgebra.jl similarity index 100% rename from src/linear_algebra.jl rename to src/stdlibs/LinearAlgebra.jl diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl new file mode 100644 index 000000000..e69de29bb From a584127d5ce2b1890c92a87c9cd815e2df8679bf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 09:59:59 +0530 Subject: [PATCH 02/22] fix: Ops.rng_bit_generator --- src/Ops.jl | 21 ++++++++++++++++----- src/stdlibs/Random.jl | 4 ++++ test/ops.jl | 34 ++++++++++++++++++++++++++++++++-- 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index fa8c17b3c..04adfbd6d 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1017,17 +1017,28 @@ end # random ops @noinline function rng_bit_generator( + ::Type{T}, seed::TracedRArray{UInt64,1}, shape; algorithm::String="DEFAULT", location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), -) - output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape) +) where {T<:Integer} + @assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY") + if algorithm == "PHILOX" + @assert length(seed) ∈ (2, 3) + elseif algorithm == "THREE_FRY" + @assert length(seed) == 2 + end + + output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) + output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64)) rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm) - op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location) + op = stablehlo.rng_bit_generator( + seed.mlir_data; output, output_state, rng_algorithm, location + ) return (; - output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)), - output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape), + output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)), + output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)), ) end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index e69de29bb..36a91707a 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -0,0 +1,4 @@ +# Implementation based on the following: +# 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl +# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl#L125 + diff --git a/test/ops.jl b/test/ops.jl index 07f911e88..4e4816728 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -538,8 +538,38 @@ end end @testset "rng_bit_generator" begin - # seed = ConcreteRArray([0, 0]) - # @jit Ops.rng_bit_generator(seed, [2]) + genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4]) + genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4]) + genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4]) + + @testset for (alg, sz) in [ + ("DEFAULT", 2), + ("PHILOX", 2), + ("PHILOX", 3), + ("THREE_FRY", 2), + ] + seed = ConcreteRArray(zeros(UInt64, sz)) + + res = @jit genInt32(seed) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Int32,2} + @test size(res.output) == (2, 4) + + seed = res.output_state + res = @jit genInt64(seed) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Int64,2} + @test size(res.output) == (2, 4) + + seed = res.output_state + res = @jit genUInt64(seed) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{UInt64,2} + @test size(res.output) == (2, 4) + end end @testset "round_nearest_afz" begin From 86991318b9d4a192fb408084ef8dff4dda81a52f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 13:29:23 +0530 Subject: [PATCH 03/22] feat: initial prototype for random number generation --- src/Interpreter.jl | 12 ++++++ src/Ops.jl | 11 ++++- src/stdlibs/Random.jl | 93 +++++++++++++++++++++++++++++++++++++++++++ test/ops.jl | 8 +--- 4 files changed, 116 insertions(+), 8 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 4b71a1341..f8c124714 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -46,6 +46,18 @@ function set_reactant_abi( return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) end + # ensures we are not generating a constant array in the trace + # https://github.com/EnzymeAD/Reactant.jl/issues/356 + if (f === Random.default_rng || f === default_rng) && length(argtypes) == 1 + arginfo2 = ArgInfo( + fargs isa Nothing ? nothing : Any[:($(default_rng_inside_interpreter))], + Any[Core.Const(default_rng_inside_interpreter)], + ) + return abstract_call_known( + interp, default_rng_inside_interpreter, arginfo2, si, sv, max_methods + ) + end + return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f::Any, diff --git a/src/Ops.jl b/src/Ops.jl index 04adfbd6d..475d6a5f4 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -4,7 +4,14 @@ module Ops using ..MLIR: MLIR using ..MLIR.Dialects: stablehlo, chlo, enzyme -using ..Reactant: Reactant, TracedRArray, TracedRNumber, RArray, RNumber, MissingTracedValue +using ..Reactant: + Reactant, + TracedRArray, + TracedRNumber, + RArray, + RNumber, + MissingTracedValue, + ReactantPrimitive function mlir_type(x::RArray{T,N}) where {T,N} return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) @@ -1022,7 +1029,7 @@ end shape; algorithm::String="DEFAULT", location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), -) where {T<:Integer} +) where {T<:ReactantPrimitive} @assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY") if algorithm == "PHILOX" @assert length(seed) ∈ (2, 3) diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 36a91707a..a2651d1fd 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -2,3 +2,96 @@ # 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl # 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl#L125 +mutable struct TracedRNG <: Random.AbstractRNG + seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} + const algorithm::String +end + +# TODO: Base.seed! + +make_seed() = rand(Random.RandomDevice(), UInt64, 2) + +TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) +TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") + +default_rng() = TracedRNG() +function default_rng_inside_interpreter() + return TracedRNG(promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") +end + +# XXX: Currently we get an illegal instruction if we don't call Random.default_rng() + +# TODO: scalar rand functions should return a TracedRNumber + +# TODO: Implement `randexp` +# TODO: Implement `randexp!` + +function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} + length(A) == 0 && return A + res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) + rng.seed = res.output_state + set_mlir_data!(A, res.output.mlir_data) + return A +end + +function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} + length(A) == 0 && return A + Random.rand!(rng, A) + scaled_uniform = Ops.subtract( + Ops.multiply(A, Ops.constant(fill(T(2), size(A)))), + Ops.constant(fill(T(1), size(A))), + ) + probit = Ops.erf_inv(scaled_uniform) + rand_normal = Ops.multiply(probit, Ops.constant(fill(sqrt(T(2)), size(A)))) + set_mlir_data!(A, rand_normal.mlir_data) + return A +end + +function Random.rand(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} + return Random.rand!(rng, TracedRArray{T,length(dims)}((), nothing, dims)) +end +function Random.randn(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} + return Random.randn!(rng, TracedRArray{T,length(dims)}((), nothing, dims)) +end + +function Random.rand(rng::TracedRNG, dim1::Integer, dims::Integer...) + return Random.rand(rng, Dims((dim1, dims...))) +end +function Random.randn(rng::TracedRNG, dim1::Integer, dims::Integer...) + return Random.randn(rng, Dims((dim1, dims...))) +end + +function Random.rand(rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer...) where {T} + return Random.rand(rng, T, Dims((dim1, dims...))) +end +function Random.randn(rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer...) where {T} + return Random.randn(rng, T, Dims((dim1, dims...))) +end + +# # CPU arrays +# function Random.rand!(rng::RNG, A::AbstractArray{T}) where {T} +# B = CuArray{T}(undef, size(A)) +# rand!(rng, B) +# copyto!(A, B) +# end +# function Random.randn!(rng::RNG, A::AbstractArray{T}) where {T} +# B = CuArray{T}(undef, size(A)) +# randn!(rng, B) +# copyto!(A, B) +# end + +# # scalars +# Random.rand(rng::RNG, T::Type=Float32) = Random.rand(rng, T, 1)[] +# Random.randn(rng::RNG, T::Type=Float32) = Random.randn(rng, T, 1)[] + +# # resolve ambiguities +# Random.randn(rng::RNG, T::Random.BitFloatType) = Random.randn(rng, T, 1)[] + +Random.rand!(A::AnyTracedRArray) = Random.rand!(default_rng(), A) +Random.randn!(A::AnyTracedRArray) = Random.randn!(default_rng(), A) + +# TODO: At some later point we might want to implement the sampler API as well since it +# makes all RNG implementation work by default. From the post-optimize IR we need to +# confirm that the dynamic_update_slice calls are optimized away into a single +# `stablehlo.rng_bit_generator` call -- confirm that this should be the case based on +# how the seeding should work? diff --git a/test/ops.jl b/test/ops.jl index 4e4816728..0a17086b2 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -542,12 +542,8 @@ end genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4]) genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4]) - @testset for (alg, sz) in [ - ("DEFAULT", 2), - ("PHILOX", 2), - ("PHILOX", 3), - ("THREE_FRY", 2), - ] + @testset for (alg, sz) in + [("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)] seed = ConcreteRArray(zeros(UInt64, sz)) res = @jit genInt32(seed) From 9da77c7ca49c1c514462d73cdbaf85c3e9d130b6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 18:29:43 +0530 Subject: [PATCH 04/22] feat: add support for scalar sampling --- src/stdlibs/Random.jl | 63 +++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index a2651d1fd..bfec4380e 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -21,11 +21,6 @@ end # XXX: Currently we get an illegal instruction if we don't call Random.default_rng() -# TODO: scalar rand functions should return a TracedRNumber - -# TODO: Implement `randexp` -# TODO: Implement `randexp!` - function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) @@ -47,25 +42,39 @@ function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} return A end -function Random.rand(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} - return Random.rand!(rng, TracedRArray{T,length(dims)}((), nothing, dims)) -end -function Random.randn(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} - return Random.randn!(rng, TracedRArray{T,length(dims)}((), nothing, dims)) -end - -function Random.rand(rng::TracedRNG, dim1::Integer, dims::Integer...) - return Random.rand(rng, Dims((dim1, dims...))) -end -function Random.randn(rng::TracedRNG, dim1::Integer, dims::Integer...) - return Random.randn(rng, Dims((dim1, dims...))) +for randfun in (:rand, :randn) + randfun! = Symbol(randfun, :!) + @eval begin + function Random.$(randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} + return Random.$(randfun!)(rng, TracedRArray{T,length(dims)}((), nothing, dims)) + end + + function Random.$(randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...) + return Random.$(randfun)(rng, Dims((dim1, dims...))) + end + + function Random.$(randfun)( + rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer... + ) where {T} + return Random.$(randfun)(rng, T, Dims((dim1, dims...))) + end + + Random.$(randfun!)(A::AnyTracedRArray) = Random.$(randfun!)(default_rng(), A) + + # scalars + function Random.$(randfun)(rng::TracedRNG, ::Type{T} = Float64) where {T} + A = promote_to(TracedRArray{T,0}, fill(T(0))) + Random.$(randfun!)(rng, A) + return A[] + end + end end -function Random.rand(rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer...) where {T} - return Random.rand(rng, T, Dims((dim1, dims...))) -end -function Random.randn(rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer...) where {T} - return Random.randn(rng, T, Dims((dim1, dims...))) +# resolve ambiguities +function Random.randn(rng::TracedRNG, T::Random.BitFloatType) + A = promote_to(TracedRArray{T,0}, fill(T(0))) + Random.randn!(rng, A) + return A[] end # # CPU arrays @@ -80,16 +89,6 @@ end # copyto!(A, B) # end -# # scalars -# Random.rand(rng::RNG, T::Type=Float32) = Random.rand(rng, T, 1)[] -# Random.randn(rng::RNG, T::Type=Float32) = Random.randn(rng, T, 1)[] - -# # resolve ambiguities -# Random.randn(rng::RNG, T::Random.BitFloatType) = Random.randn(rng, T, 1)[] - -Random.rand!(A::AnyTracedRArray) = Random.rand!(default_rng(), A) -Random.randn!(A::AnyTracedRArray) = Random.randn!(default_rng(), A) - # TODO: At some later point we might want to implement the sampler API as well since it # makes all RNG implementation work by default. From the post-optimize IR we need to # confirm that the dynamic_update_slice calls are optimized away into a single From 6dcad12ee9316df5feeba879f3f7f70903f8ac54 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 18:45:01 +0530 Subject: [PATCH 05/22] feat: efficient sampling for non-native RNGs --- src/stdlibs/Random.jl | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index bfec4380e..fccc583bf 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -7,7 +7,18 @@ mutable struct TracedRNG <: Random.AbstractRNG const algorithm::String end -# TODO: Base.seed! +function Random.seed!(rng::TracedRNG, seed::Number) + seed = reinterpret(UInt64, Random.hash_seed(seed)) + # TODO: Using `seed!` inside tracing should generate a TracedRArray + return Random.seed!(rng, ConcreteRArray(seed[1:length(rng.seed)])) +end + +function Random.seed!( + rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} +) + rng.seed = seed + return rng +end make_seed() = rand(Random.RandomDevice(), UInt64, 2) @@ -21,6 +32,8 @@ end # XXX: Currently we get an illegal instruction if we don't call Random.default_rng() +# XXX: rng_bit_generator doesn't support floating point types + function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) @@ -62,11 +75,21 @@ for randfun in (:rand, :randn) Random.$(randfun!)(A::AnyTracedRArray) = Random.$(randfun!)(default_rng(), A) # scalars - function Random.$(randfun)(rng::TracedRNG, ::Type{T} = Float64) where {T} + function Random.$(randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} A = promote_to(TracedRArray{T,0}, fill(T(0))) Random.$(randfun!)(rng, A) return A[] end + + # Non-Traced RNGs if used will lead to disastrous performance. We attempt to fix + # that but with a warning + function Random.$(randfun!)(rng::Random.AbstractRNG, A::AnyTracedRArray) + @warn "`rng` is not a `TracedRNG`. We will use this to seed the `TracedRNG` \ + instead of generating samples from this RNG type." maxlog = 1 + seed = promote_to(TracedRArray{UInt64,1}, rand(rng, UInt64, 2)) + trng = TracedRNG(seed, "DEFAULT") + return Random.$(randfun!)(trng, A) + end end end @@ -77,18 +100,6 @@ function Random.randn(rng::TracedRNG, T::Random.BitFloatType) return A[] end -# # CPU arrays -# function Random.rand!(rng::RNG, A::AbstractArray{T}) where {T} -# B = CuArray{T}(undef, size(A)) -# rand!(rng, B) -# copyto!(A, B) -# end -# function Random.randn!(rng::RNG, A::AbstractArray{T}) where {T} -# B = CuArray{T}(undef, size(A)) -# randn!(rng, B) -# copyto!(A, B) -# end - # TODO: At some later point we might want to implement the sampler API as well since it # makes all RNG implementation work by default. From the post-optimize IR we need to # confirm that the dynamic_update_slice calls are optimized away into a single From bbf2a48699cf45655994b559614fb58e0e9917b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 19:24:40 +0530 Subject: [PATCH 06/22] fix: handling floating point sampling --- src/Ops.jl | 28 +++++++++++++++++++--------- src/stdlibs/Random.jl | 6 ++++-- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 475d6a5f4..ba79f5cb9 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -4,14 +4,7 @@ module Ops using ..MLIR: MLIR using ..MLIR.Dialects: stablehlo, chlo, enzyme -using ..Reactant: - Reactant, - TracedRArray, - TracedRNumber, - RArray, - RNumber, - MissingTracedValue, - ReactantPrimitive +using ..Reactant: Reactant, TracedRArray, TracedRNumber, RArray, RNumber, MissingTracedValue function mlir_type(x::RArray{T,N}) where {T,N} return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) @@ -1029,7 +1022,7 @@ end shape; algorithm::String="DEFAULT", location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), -) where {T<:ReactantPrimitive} +) where {T<:Integer} @assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY") if algorithm == "PHILOX" @assert length(seed) ∈ (2, 3) @@ -1049,6 +1042,23 @@ end ) end +function rng_bit_generator( + ::Type{T}, + seed::TracedRArray{UInt64,1}, + shape; + algorithm::String="DEFAULT", + location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), +) where {T<:Union{Float16,Float32,Float64}} + nbits = sizeof(T) * 8 + uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64) + (; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location) + output = divide( + convert(TracedRArray{T,ndims(output)}, output), + constant(fill(T(typemax(uT)), Tuple(shape)); location), + ) + return (; output_state, output) +end + # functional ops @noinline function return_( results::Union{TracedRArray,TracedRNumber}...; diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index fccc583bf..090b72bb3 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -32,8 +32,6 @@ end # XXX: Currently we get an illegal instruction if we don't call Random.default_rng() -# XXX: rng_bit_generator doesn't support floating point types - function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) @@ -62,6 +60,10 @@ for randfun in (:rand, :randn) return Random.$(randfun!)(rng, TracedRArray{T,length(dims)}((), nothing, dims)) end + function Random.$(randfun)(rng::TracedRNG, dims::Dims) + return Random.$(randfun)(rng, Float64, dims) + end + function Random.$(randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...) return Random.$(randfun)(rng, Dims((dim1, dims...))) end From e82c2c0f4ddc7e40ed1c84c0c51567649b88c197 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 08:59:18 +0530 Subject: [PATCH 07/22] feat: use the override macro --- src/Interpreter.jl | 12 ------------ src/stdlibs/Random.jl | 9 +++++---- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index f8c124714..4b71a1341 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -46,18 +46,6 @@ function set_reactant_abi( return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) end - # ensures we are not generating a constant array in the trace - # https://github.com/EnzymeAD/Reactant.jl/issues/356 - if (f === Random.default_rng || f === default_rng) && length(argtypes) == 1 - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : Any[:($(default_rng_inside_interpreter))], - Any[Core.Const(default_rng_inside_interpreter)], - ) - return abstract_call_known( - interp, default_rng_inside_interpreter, arginfo2, si, sv, max_methods - ) - end - return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f::Any, diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 090b72bb3..656d61383 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -27,16 +27,17 @@ TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") default_rng() = TracedRNG() function default_rng_inside_interpreter() - return TracedRNG(promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") + return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") end -# XXX: Currently we get an illegal instruction if we don't call Random.default_rng() +@reactant_override @noinline Random.default_rng() = default_rng_inside_interpreter() +@reactant_override @noinline default_rng() = default_rng_inside_interpreter() function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) rng.seed = res.output_state - set_mlir_data!(A, res.output.mlir_data) + TracedUtils.set_mlir_data!(A, res.output.mlir_data) return A end @@ -49,7 +50,7 @@ function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} ) probit = Ops.erf_inv(scaled_uniform) rand_normal = Ops.multiply(probit, Ops.constant(fill(sqrt(T(2)), size(A)))) - set_mlir_data!(A, rand_normal.mlir_data) + TracedUtils.set_mlir_data!(A, rand_normal.mlir_data) return A end From 4b7c504892e398ccac66e8929520cbb3a00bd8a8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 09:50:36 +0530 Subject: [PATCH 08/22] fix: use `@noinline` --- src/Ops.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ops.jl b/src/Ops.jl index ba79f5cb9..96a537e66 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1042,7 +1042,7 @@ end ) end -function rng_bit_generator( +@noinline function rng_bit_generator( ::Type{T}, seed::TracedRArray{UInt64,1}, shape; From 7d73538cafb2f0ac2ed45e2c28906883ac01587c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 10:58:08 +0530 Subject: [PATCH 09/22] feat: support randexp --- src/stdlibs/Random.jl | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 656d61383..0b44c93b9 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -54,7 +54,19 @@ function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} return A end -for randfun in (:rand, :randn) +function Random.randexp!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} + length(A) == 0 && return A + Random.rand!(rng, A) + TracedUtils.set_mlir_data!( + A, + Ops.negate( + Ops.log_plus_one(Ops.negate(TracedUtils.materialize_traced_array(A))) + ).mlir_data, + ) + return A +end + +for randfun in (:rand, :randn, :randexp) randfun! = Symbol(randfun, :!) @eval begin function Random.$(randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} @@ -97,10 +109,12 @@ for randfun in (:rand, :randn) end # resolve ambiguities -function Random.randn(rng::TracedRNG, T::Random.BitFloatType) - A = promote_to(TracedRArray{T,0}, fill(T(0))) - Random.randn!(rng, A) - return A[] +for randfun in (:randn, :randexp) + @eval function Random.$(randfun)(rng::TracedRNG, T::Random.BitFloatType) + A = promote_to(TracedRArray{T,0}, fill(T(0))) + Random.randn!(rng, A) + return A[] + end end # TODO: At some later point we might want to implement the sampler API as well since it From 591320c723a7f1d5a620cbc9b3d132044a2e80d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 11:01:31 +0530 Subject: [PATCH 10/22] feat: override seeding inside interpreter --- src/stdlibs/Random.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 0b44c93b9..954bd29e5 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -9,10 +9,16 @@ end function Random.seed!(rng::TracedRNG, seed::Number) seed = reinterpret(UInt64, Random.hash_seed(seed)) - # TODO: Using `seed!` inside tracing should generate a TracedRArray return Random.seed!(rng, ConcreteRArray(seed[1:length(rng.seed)])) end +@reactant_override @noinline function Random.seed!(rng::TracedRNG, seed::Number) + seed = reinterpret(UInt64, Random.hash_seed(seed)) + return Random.seed!( + rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) + ) +end + function Random.seed!( rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} ) From 1b45638549f4eadd7dda8801c6f72533108f660b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 11:14:59 +0530 Subject: [PATCH 11/22] refactor: move things into a module --- src/Reactant.jl | 5 +++++ src/stdlibs/Random.jl | 20 ++++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 29c17a056..2441fa50e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -124,6 +124,11 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") +mutable struct TracedRNG <: Random.AbstractRNG + seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} + const algorithm::String +end + # StdLib Overloads include("stdlibs/LinearAlgebra.jl") include("stdlibs/Random.jl") diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 954bd29e5..eb652b16c 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -1,11 +1,21 @@ +module TracedRandom + # Implementation based on the following: # 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl # 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl#L125 -mutable struct TracedRNG <: Random.AbstractRNG - seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} - const algorithm::String -end +using ..Reactant: + Reactant, + TracedRArray, + TracedRNumber, + TracedRNG, + AnyTracedRArray, + Reactant, + TracedUtils, + @reactant_override, + Ops, + ConcreteRArray +using Random: Random function Random.seed!(rng::TracedRNG, seed::Number) seed = reinterpret(UInt64, Random.hash_seed(seed)) @@ -128,3 +138,5 @@ end # confirm that the dynamic_update_slice calls are optimized away into a single # `stablehlo.rng_bit_generator` call -- confirm that this should be the case based on # how the seeding should work? + +end From 51f91e49f17e8dbae8029991bb938c804113a98d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 Dec 2024 22:55:37 +0530 Subject: [PATCH 12/22] refactor: rework how the overlays are implemented --- src/Overlay.jl | 16 +++++++++++++++- src/stdlibs/Random.jl | 22 ++++++++-------------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 6d4752acd..18da97f94 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -3,6 +3,15 @@ # correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved # we should move all the reactant_overrides to relevant files. +# Helper Function to determine if we are inside the ReactantInterpreter +""" + within_reactant_interpreter() + +Returns `true` if we are currently inside the ReactantInterpreter. +""" +@noinline within_reactant_interpreter() = false +@reactant_overlay @noinline within_reactant_interpreter() = true + # Compiling within a compile should return simply the original function @reactant_overlay function Compiler.compile( f, args; client=nothing, optimize=true, sync=false @@ -10,7 +19,7 @@ return f end -# Enzyme overrides +# Enzyme.jl overlays @reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} @@ -22,3 +31,8 @@ end ) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end + +# Random.jl overlays +@reactant_overlay @noinline function Random.default_rng() + return call_with_reactant(TracedRandom.default_rng) +end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index eb652b16c..d5bda428f 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -12,21 +12,18 @@ using ..Reactant: AnyTracedRArray, Reactant, TracedUtils, - @reactant_override, Ops, ConcreteRArray using Random: Random function Random.seed!(rng::TracedRNG, seed::Number) seed = reinterpret(UInt64, Random.hash_seed(seed)) - return Random.seed!(rng, ConcreteRArray(seed[1:length(rng.seed)])) -end - -@reactant_override @noinline function Random.seed!(rng::TracedRNG, seed::Number) - seed = reinterpret(UInt64, Random.hash_seed(seed)) - return Random.seed!( - rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) - ) + seed = if Reactant.within_reactant_interpreter() + TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) + else + ConcreteRArray(seed[1:length(rng.seed)]) + end + return Random.seed!(rng, seed) end function Random.seed!( @@ -41,14 +38,11 @@ make_seed() = rand(Random.RandomDevice(), UInt64, 2) TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") -default_rng() = TracedRNG() -function default_rng_inside_interpreter() +function default_rng() + Reactant.within_reactant_interpreter() || return TracedRNG() return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") end -@reactant_override @noinline Random.default_rng() = default_rng_inside_interpreter() -@reactant_override @noinline default_rng() = default_rng_inside_interpreter() - function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) From 5a94f1440533b50876c8b6f49849043e4efcb062 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 10:18:27 +0530 Subject: [PATCH 13/22] docs: add internal api to the docs --- docs/make.jl | 1 + docs/src/.vitepress/config.mts | 4 +++- docs/src/api/internal.md | 12 ++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 docs/src/api/internal.md diff --git a/docs/make.jl b/docs/make.jl index 7515a566d..fcbaca60e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -43,6 +43,7 @@ pages = [ ], "MLIR API" => "api/mlirc.md", "XLA" => "api/xla.md", + "Internal API" => "api/internal.md", ], ] diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 942a9415d..1dc25f2ad 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -78,7 +78,8 @@ export default defineConfig({ { text: "MLIR API", link: "/api/mlirc" }, { text: "XLA", link: "/api/xla" }, ], - } + }, + { text: "Internal API", link: "/api/internal" }, ], }, { @@ -132,6 +133,7 @@ export default defineConfig({ { text: "XLA", link: "/api/xla" }, ], }, + { text: "Internal API", link: "/api/internal" }, ], }, }, diff --git a/docs/src/api/internal.md b/docs/src/api/internal.md new file mode 100644 index 000000000..a8788e5fb --- /dev/null +++ b/docs/src/api/internal.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# Internal API + +These functions are not part of the public API and are subject to change at any time. + +```@docs +Reactant.REDUB_ARGUMENTS_NAME +Reactant.within_reactant_interpreter +``` From d61d2696379433c1f20bf3542d5d430f7c2bc88b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 10:38:42 +0530 Subject: [PATCH 14/22] test: include floating point tests --- src/Ops.jl | 2 +- test/ops.jl | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/Ops.jl b/src/Ops.jl index 96a537e66..799724ba7 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1048,7 +1048,7 @@ end shape; algorithm::String="DEFAULT", location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), -) where {T<:Union{Float16,Float32,Float64}} +) where {T<:AbstractFloat} nbits = sizeof(T) * 8 uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64) (; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location) diff --git a/test/ops.jl b/test/ops.jl index 0a17086b2..82ec4cc8b 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -541,6 +541,8 @@ end genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4]) genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4]) genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4]) + genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4]) + genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4]) @testset for (alg, sz) in [("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)] @@ -565,6 +567,20 @@ end @test size(res.output_state) == (sz,) @test res.output isa ConcreteRArray{UInt64,2} @test size(res.output) == (2, 4) + + seed = res.output_state + res = @jit genFloat32(seed) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Float32,2} + @test size(res.output) == (2, 4) + + seed = res.output_state + res = @jit genFloat64(seed) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Float64,2} + @test size(res.output) == (2, 4) end end From 36e56cb060eeadfc86fed08896a5b56ea9ef4669 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 11:16:37 +0530 Subject: [PATCH 15/22] test: setup testing --- test/integration/random.jl | 0 test/runtests.jl | 1 + 2 files changed, 1 insertion(+) create mode 100644 test/integration/random.jl diff --git a/test/integration/random.jl b/test/integration/random.jl new file mode 100644 index 000000000..e69de29bb diff --git a/test/runtests.jl b/test/runtests.jl index fddc963ce..68dfcaead 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") + @safetestset "Random" include("integration/random.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From 6c21721d629c090309bbeac534d61214355616fb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 17:40:39 +0530 Subject: [PATCH 16/22] feat: overlay all generators --- Project.toml | 3 + ext/ReactantRandom123Ext.jl | 11 ++++ src/Ops.jl | 103 ++++++++++++++++++++++++++++++++ src/Overlay.jl | 52 ++++++++++++++++ src/Reactant.jl | 2 +- src/stdlibs/Random.jl | 116 +++++++++++++++++++++--------------- test/Project.toml | 2 + test/integration/random.jl | 12 ++++ 8 files changed, 252 insertions(+), 49 deletions(-) create mode 100644 ext/ReactantRandom123Ext.jl diff --git a/Project.toml b/Project.toml index 5e3dceff0..2ecce09b3 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" @@ -35,6 +36,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" +ReactantRandom123Ext = "Random123" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" @@ -52,6 +54,7 @@ NNlib = "0.9.26" OrderedCollections = "1" Preferences = "1.4" Random = "1.10" +Random123 = "1.7" ReactantCore = "0.1.3" Reactant_jll = "0.0.26" Scratch = "1.2" diff --git a/ext/ReactantRandom123Ext.jl b/ext/ReactantRandom123Ext.jl new file mode 100644 index 000000000..d701fdc7e --- /dev/null +++ b/ext/ReactantRandom123Ext.jl @@ -0,0 +1,11 @@ +module ReactantRandom123Ext + +using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x +using Reactant: TracedRandom + +TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY" +TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY" +TracedRandom.rng_algorithm(::Philox4x) = "PHILOX" +TracedRandom.rng_algorithm(::Philox2x) = "PHILOX" + +end diff --git a/src/Ops.jl b/src/Ops.jl index 799724ba7..18ab2d7d4 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1016,6 +1016,29 @@ end end # random ops +""" + rng_bit_generator( + ::Type{T}, + seed::TracedRArray{UInt64,1}, + shape; + algorithm::String="DEFAULT", + location=mlir_stacktrace("rand", @__FILE__, @__LINE__), + ) + +Generate a random array of type `T` with the given shape and seed from a uniform random +distribution between 0 and 1. Returns a NamedTuple with the following fields: + +- `output_state`: The state of the random number generator after the operation. +- `output`: The generated array. + +# Arguments + +- `T`: The type of the generated array. +- `seed`: The seed for the random number generator. +- `shape`: The shape of the generated array. +- `algorithm`: The algorithm to use for generating the random numbers. Defaults to + "DEFAULT". Other options include "PHILOX" and "THREE_FRY". +""" @noinline function rng_bit_generator( ::Type{T}, seed::TracedRArray{UInt64,1}, @@ -1059,6 +1082,86 @@ end return (; output_state, output) end +""" + randn( + ::Type{T}, + seed::TracedRArray{UInt64,1}, + shape; + algorithm::String="DEFAULT", + location=mlir_stacktrace("rand", @__FILE__, @__LINE__), + ) + +Generate a random array of type `T` with the given shape and seed from a standard normal +distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following +fields: + +- `output_state`: The state of the random number generator after the operation. +- `output`: The generated array. + +# Arguments + +- `T`: The type of the generated array. +- `seed`: The seed for the random number generator. +- `shape`: The shape of the generated array. +- `algorithm`: The algorithm to use for generating the random numbers. Defaults to + "DEFAULT". Other options include "PHILOX" and "THREE_FRY". +""" +@noinline function randn( + ::Type{T}, + seed::TracedRArray{UInt64,1}, + shape; + algorithm::String="DEFAULT", + location=mlir_stacktrace("rand", @__FILE__, @__LINE__), +) where {T} + res = rng_bit_generator(T, seed, shape; algorithm, location) + rand_uniform = res.output + seed = res.output_state + scaled_uniform = subtract( + multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))), + constant(fill(T(1), size(rand_uniform))), + ) + probit = erf_inv(scaled_uniform) + rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform)))) + return (; output_state=seed, output=rand_normal) +end + +""" + randexp( + ::Type{T}, + seed::TracedRArray{UInt64,1}, + shape; + algorithm::String="DEFAULT", + location=mlir_stacktrace("rand", @__FILE__, @__LINE__), + ) + +Generate a random array of type `T` with the given shape and seed from an exponential +distribution with rate 1. Returns a NamedTuple with the following fields: + +- `output_state`: The state of the random number generator after the operation. +- `output`: The generated array. + +# Arguments + +- `T`: The type of the generated array. +- `seed`: The seed for the random number generator. +- `shape`: The shape of the generated array. +- `algorithm`: The algorithm to use for generating the random numbers. Defaults to + "DEFAULT". Other options include "PHILOX" and "THREE_FRY". +""" +@noinline function randexp( + ::Type{T}, + seed::TracedRArray{UInt64,1}, + shape; + algorithm::String="DEFAULT", + location=mlir_stacktrace("rand", @__FILE__, @__LINE__), +) where {T} + res = rng_bit_generator(T, seed, shape; algorithm, location) + rand_uniform = res.output + seed = res.output_state + rand_exp = negate(log_plus_one(negate(rand_uniform))) + return (; output_state=seed, output=rand_exp) +end + # functional ops @noinline function return_( results::Union{TracedRArray,TracedRNumber}...; diff --git a/src/Overlay.jl b/src/Overlay.jl index 18da97f94..402976746 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -36,3 +36,55 @@ end @reactant_overlay @noinline function Random.default_rng() return call_with_reactant(TracedRandom.default_rng) end + +## Only problematic edge case here is the direct `(rng, A::AbstractArray)` call +## We can't directly overlay that call without breaking the semantics of inplace update +for randfun in (:rand, :randn, :randexp) + randfun! = Symbol(randfun, :!) + overload_randfun = Symbol(:overload_, randfun) + overload_randfun! = Symbol(:overload_, randfun!) + + @eval begin + @reactant_overlay @noinline function Random.$(randfun)( + rng::AbstractRNG, ::Type{T}, dims::Dims + ) where {T} + return TracedRandom.$(overload_randfun)(rng, T, dims) + end + + @reactant_overlay @noinline function Random.$(randfun)( + rng::AbstractRNG, dim1::Integer, dims::Integer... + ) + return TracedRandom.$(overload_randfun)(rng, dim1, dims...) + end + + @reactant_overlay @noinline function Random.$(randfun)( + rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... + ) where {T} + return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) + end + + # scalars + @reactant_overlay @noinline function Random.$(randfun)( + rng::AbstractRNG, ::Type{T}=Float64 + ) where {T} + return TracedRandom.$(overload_randfun)(rng, T) + end + + # inplace + @reactant_overlay @noinline function Random.$(randfun!)( + rng::AbstractRNG, A::AnyTracedRArray + ) + return TracedRandom.$(overload_randfun!)(rng, A) + end + + # warn about direct writing to arrays + @reactant_overlay @noinline function Random.$(randfun!)( + rng::AbstractRNG, A::AbstractArray + ) + @warn "Directly writing to an array using Random.jl functions inside \ + ReactantInterpreter will generate a constant array in the IR. Use with \ + caution." maxlog = 1 + return Random.$(randfun!)(rng, A) + end + end +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 2441fa50e..bea015074 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -3,7 +3,7 @@ module Reactant using ReactantCore: ReactantCore, @trace, MissingTracedValue using LinearAlgebra: LinearAlgebra -using Random: Random +using Random: Random, AbstractRNG using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index d5bda428f..2c19427cd 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -14,9 +14,14 @@ using ..Reactant: TracedUtils, Ops, ConcreteRArray -using Random: Random +using Random: Random, AbstractRNG function Random.seed!(rng::TracedRNG, seed::Number) + if seed isa TracedRNumber + error("Passing in `TracedRNumber` as a seed is not supported. Please pass in a \ + `TracedRArray` of the appropriate size instead.") + end + seed = reinterpret(UInt64, Random.hash_seed(seed)) seed = if Reactant.within_reactant_interpreter() TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) @@ -26,6 +31,14 @@ function Random.seed!(rng::TracedRNG, seed::Number) return Random.seed!(rng, seed) end +function Random.seed!(rng::TracedRNG, seed::AbstractArray{<:Integer,1}) + return Random.seed!(rng, UInt64.(seed)) +end + +function Random.seed!(rng::TracedRNG, seed::AbstractArray{UInt64,1}) + return Random.seed!(rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed)) +end + function Random.seed!( rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} ) @@ -43,7 +56,10 @@ function default_rng() return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") end -function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} +rng_algorithm(rng::TracedRNG) = rng.algorithm +rng_algorithm(::AbstractRNG) = "DEFAULT" + +function internal_overload_rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) rng.seed = res.output_state @@ -51,79 +67,83 @@ function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} return A end -function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} +function internal_overload_randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A - Random.rand!(rng, A) - scaled_uniform = Ops.subtract( - Ops.multiply(A, Ops.constant(fill(T(2), size(A)))), - Ops.constant(fill(T(1), size(A))), - ) - probit = Ops.erf_inv(scaled_uniform) - rand_normal = Ops.multiply(probit, Ops.constant(fill(sqrt(T(2)), size(A)))) - TracedUtils.set_mlir_data!(A, rand_normal.mlir_data) + res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm) + rng.seed = res.output_state + TracedUtils.set_mlir_data!(A, res.output.mlir_data) return A end -function Random.randexp!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} +function internal_overload_randexp!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A - Random.rand!(rng, A) - TracedUtils.set_mlir_data!( - A, - Ops.negate( - Ops.log_plus_one(Ops.negate(TracedUtils.materialize_traced_array(A))) - ).mlir_data, - ) + res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm) + rng.seed = res.output_state + TracedUtils.set_mlir_data!(A, res.output.mlir_data) return A end for randfun in (:rand, :randn, :randexp) randfun! = Symbol(randfun, :!) + overload_randfun = Symbol(:internal_overload_, randfun) + overload_randfun! = Symbol(:internal_overload_, randfun!) + @eval begin - function Random.$(randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} - return Random.$(randfun!)(rng, TracedRArray{T,length(dims)}((), nothing, dims)) + function $(overload_randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} + return $(overload_randfun!)( + rng, TracedRArray{T,length(dims)}((), nothing, dims) + ) end - function Random.$(randfun)(rng::TracedRNG, dims::Dims) - return Random.$(randfun)(rng, Float64, dims) + function $(overload_randfun)(rng::TracedRNG, dims::Dims) + return $(overload_randfun)(rng, Float64, dims) end - function Random.$(randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...) - return Random.$(randfun)(rng, Dims((dim1, dims...))) + function $(overload_randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...) + return $(overload_randfun)(rng, Dims((dim1, dims...))) end - function Random.$(randfun)( + function $(overload_randfun)( rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer... ) where {T} - return Random.$(randfun)(rng, T, Dims((dim1, dims...))) + return $(overload_randfun)(rng, T, Dims((dim1, dims...))) end - Random.$(randfun!)(A::AnyTracedRArray) = Random.$(randfun!)(default_rng(), A) + $(overload_randfun!)(A::AnyTracedRArray) = $(overload_randfun!)(default_rng(), A) # scalars - function Random.$(randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} - A = promote_to(TracedRArray{T,0}, fill(T(0))) - Random.$(randfun!)(rng, A) - return A[] - end - - # Non-Traced RNGs if used will lead to disastrous performance. We attempt to fix - # that but with a warning - function Random.$(randfun!)(rng::Random.AbstractRNG, A::AnyTracedRArray) - @warn "`rng` is not a `TracedRNG`. We will use this to seed the `TracedRNG` \ - instead of generating samples from this RNG type." maxlog = 1 - seed = promote_to(TracedRArray{UInt64,1}, rand(rng, UInt64, 2)) - trng = TracedRNG(seed, "DEFAULT") - return Random.$(randfun!)(trng, A) + function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} + A = TracedUtils.promote_to(TracedRArray{T,0}, fill(T(0))) + $(overload_randfun!)(rng, A) + return TracedRNumber{T}((), A.mlir_data) end end end -# resolve ambiguities -for randfun in (:randn, :randexp) - @eval function Random.$(randfun)(rng::TracedRNG, T::Random.BitFloatType) - A = promote_to(TracedRArray{T,0}, fill(T(0))) - Random.randn!(rng, A) - return A[] +# call from overlay-ed variants. we write this with 2 tiers -- overload_* and +# internal_overload_* -- to avoid method ambiguities +for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!) + overload_randfun = Symbol(:overload_, randfun) + internal_overload_randfun = Symbol(:internal_overload_, randfun) + @eval begin + function $(overload_randfun)(rng::AbstractRNG, args...) + seed_uint64 = Array{UInt64}(undef, 2) + sampler = Random.Sampler(rng, UInt64, Val(1)) + seed_uint64[1] = rand(rng, sampler) + seed_uint64[2] = rand(rng, sampler) + # XXX: Ideally the following should just work but currently it gives an illegal + # instruction error. Maybe an issue with Julia's AbsInt? + # Random.rand!(rng, seed_uint64) + rng = TracedRNG( + TracedUtils.promote_to(TracedRArray{UInt64,1}, seed_uint64), + rng_algorithm(rng), + ) + return $(internal_overload_randfun)(rng, args...) + end + + function $(overload_randfun)(rng::TracedRNG, args...) + return $(internal_overload_randfun)(rng, args...) + end end end diff --git a/test/Project.toml b/test/Project.toml index 9b3c5a6b4..7d1b45daa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -19,6 +20,7 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/integration/random.jl b/test/integration/random.jl index e69de29bb..73645b32c 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -0,0 +1,12 @@ +using Reactant, Test, Random +using StatsBase, Statistics, HypothesisTests + +# First Testing overlay works correctly + +# Next we test that the random number generators actually generate data from the correct +# distributions +@testset "Uniform Random" begin end + +@testset "Normal Distribution" begin end + +@testset "Exponential Distribution" begin end From d0064f25b55090710608d554310d29c4d237059c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 18:43:35 +0530 Subject: [PATCH 17/22] test: ensure distributions are correct --- .github/workflows/CI.yml | 6 ++ test/Project.toml | 3 + test/integration/random.jl | 126 +++++++++++++++++++++++++++++++++++-- 3 files changed, 130 insertions(+), 5 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 66882fb6a..508ff06b9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,6 +50,12 @@ jobs: version: '1.10' assertions: true test_group: neural_networks + - os: ubuntu-20.04 + arch: x64 + libReactant: packaged + version: '1.10' + assertions: true + test_group: integration - os: ubuntu-20.04 arch: x86 libReactant: packaged diff --git a/test/Project.toml b/test/Project.toml index 7d1b45daa..cb0ccc4f6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -16,9 +17,11 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration/random.jl b/test/integration/random.jl index 73645b32c..321810223 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -1,12 +1,128 @@ -using Reactant, Test, Random -using StatsBase, Statistics, HypothesisTests +using Reactant, Test, Random, Random123, StableRNGs, Statistics +using StatsBase, Statistics, HypothesisTests, Distributions # First Testing overlay works correctly # Next we test that the random number generators actually generate data from the correct # distributions -@testset "Uniform Random" begin end +@testset "Uniform Random" begin + @testset "Deterministic Seed" begin + seed1 = ConcreteRArray(UInt64[1, 3]) + seed2 = ConcreteRArray(UInt64[1, 5]) -@testset "Normal Distribution" begin end + fn(seed) = begin + rng = Random.default_rng() + Random.seed!(rng, seed) + return rand(rng, 10000) + end -@testset "Exponential Distribution" begin end + fn_compiled = @compile fn(seed1) + @test fn_compiled(seed1) ≈ fn_compiled(seed1) + @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2)))) + end + + @testset "Correct Distribution" begin + X = Array(@jit(rand(StableRNG(0), 10000))) + ks_test = ExactOneSampleKSTest(X, Uniform(0.0, 1.0)) + @test pvalue(ks_test) > 0.05 + end + + @testset "AutoCorrelation" begin + X = Array(@jit(rand(StableRNG(0), 10000))) + autocorr = cor(X[1:(end - 1)], X[2:end]) + @test abs(autocorr) < 0.05 + end + + @testset "Correct Range" begin + X = Array(@jit(rand(StableRNG(0), 10000))) + X_min, X_max = extrema(X) + @test X_min ≥ 0.0 + @test X_max ≤ 1.0 + end + + @testset "Mean & Variance" begin + X = Array(@jit(rand(StableRNG(0), 10000))) + μ = mean(X) + σ² = var(X) + @test μ ≈ 0.5 atol = 0.05 rtol = 0.05 + @test σ² ≈ (1//12) atol = 0.05 rtol = 0.05 + end +end + +@testset "Normal Distribution" begin + @testset "Deterministic Seed" begin + seed1 = ConcreteRArray(UInt64[1, 3]) + seed2 = ConcreteRArray(UInt64[1, 5]) + + fn(seed) = begin + rng = Random.default_rng() + Random.seed!(rng, seed) + return randn(rng, 10000) + end + + fn_compiled = @compile fn(seed1) + @test fn_compiled(seed1) ≈ fn_compiled(seed1) + @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2)))) + end + + @testset "Correct Distribution" begin + X = Array(@jit(randn(StableRNG(0), 10000))) + sw_test = ShapiroWilkTest(X) + @test pvalue(sw_test) > 0.05 + end + + @testset "AutoCorrelation" begin + X = Array(@jit(randn(StableRNG(0), 10000))) + autocorr = cor(X[1:(end - 1)], X[2:end]) + @test abs(autocorr) < 0.05 + end + + @testset "Mean & Variance" begin + X = Array(@jit(randn(StableRNG(0), 10000))) + μ = mean(X) + σ² = var(X) + @test μ ≈ 0.0 atol = 0.05 rtol = 0.05 + @test σ² ≈ 1.0 atol = 0.05 rtol = 0.05 + end +end + +@testset "Exponential Distribution" begin + @testset "Deterministic Seed" begin + seed1 = ConcreteRArray(UInt64[1, 3]) + seed2 = ConcreteRArray(UInt64[1, 5]) + + fn(seed) = begin + rng = Random.default_rng() + Random.seed!(rng, seed) + return randexp(rng, 10000) + end + + fn_compiled = @compile fn(seed1) + @test fn_compiled(seed1) ≈ fn_compiled(seed1) + @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2)))) + end + + @testset "Correct Distribution" begin + X = Array(@jit(randexp(StableRNG(0), 10000))) + ks_test = ExactOneSampleKSTest(X, Exponential(1.0)) + @test pvalue(ks_test) > 0.05 + end + + @testset "AutoCorrelation" begin + X = Array(@jit(randexp(StableRNG(0), 10000))) + autocorr = cor(X[1:(end - 1)], X[2:end]) + @test abs(autocorr) < 0.05 + end + + @testset "Correct Range" begin + X = Array(@jit(randexp(StableRNG(0), 10000))) + X_min, X_max = extrema(X) + @test X_min ≥ 0.0 + end + + @testset "Mean" begin + X = Array(@jit(randexp(StableRNG(0), 10000))) + μ = mean(X) + @test μ ≈ 1.0 atol = 0.05 rtol = 0.05 + end +end From 9129ce2ef1ac30a9cf5435e70f945c8b4f1f177a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 20:21:00 +0530 Subject: [PATCH 18/22] test: overlay generation --- src/Overlay.jl | 6 +++--- src/stdlibs/Random.jl | 20 ++++++++++-------- test/integration/random.jl | 42 ++++++++++++++++++++++++++++++++++++++ test/nn/lux.jl | 2 +- 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 402976746..1bbb3cbe2 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -47,7 +47,7 @@ for randfun in (:rand, :randn, :randexp) @eval begin @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dims::Dims - ) where {T} + ) where {T <: ReactantPrimitive} return TracedRandom.$(overload_randfun)(rng, T, dims) end @@ -59,14 +59,14 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... - ) where {T} + ) where {T <: ReactantPrimitive} return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end # scalars @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}=Float64 - ) where {T} + ) where {T <: ReactantPrimitive} return TracedRandom.$(overload_randfun)(rng, T) end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 2c19427cd..fcc8a8d5e 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -2,7 +2,7 @@ module TracedRandom # Implementation based on the following: # 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl -# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl#L125 +# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl using ..Reactant: Reactant, @@ -16,6 +16,14 @@ using ..Reactant: ConcreteRArray using Random: Random, AbstractRNG +function make_seed(rng::AbstractRNG = Random.RandomDevice()) + seed_uint64 = Array{UInt64}(undef, 2) + sampler = Random.Sampler(rng, UInt64, Val(1)) + seed_uint64[1] = rand(rng, sampler) + seed_uint64[2] = rand(rng, sampler) + return seed_uint64 +end + function Random.seed!(rng::TracedRNG, seed::Number) if seed isa TracedRNumber error("Passing in `TracedRNumber` as a seed is not supported. Please pass in a \ @@ -46,8 +54,6 @@ function Random.seed!( return rng end -make_seed() = rand(Random.RandomDevice(), UInt64, 2) - TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") @@ -127,15 +133,11 @@ for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!) internal_overload_randfun = Symbol(:internal_overload_, randfun) @eval begin function $(overload_randfun)(rng::AbstractRNG, args...) - seed_uint64 = Array{UInt64}(undef, 2) - sampler = Random.Sampler(rng, UInt64, Val(1)) - seed_uint64[1] = rand(rng, sampler) - seed_uint64[2] = rand(rng, sampler) # XXX: Ideally the following should just work but currently it gives an illegal # instruction error. Maybe an issue with Julia's AbsInt? - # Random.rand!(rng, seed_uint64) + # seed_uint64 = rand(rng, UInt64, 2) rng = TracedRNG( - TracedUtils.promote_to(TracedRArray{UInt64,1}, seed_uint64), + TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed(rng)), rng_algorithm(rng), ) return $(internal_overload_randfun)(rng, args...) diff --git a/test/integration/random.jl b/test/integration/random.jl index 321810223..128b05954 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -2,6 +2,48 @@ using Reactant, Test, Random, Random123, StableRNGs, Statistics using StatsBase, Statistics, HypothesisTests, Distributions # First Testing overlay works correctly +@testset "Random.jl Overlay" begin + hlo = @code_hlo rand(Float32, 2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + hlo = @code_hlo rand(MersenneTwister(), Float32, 2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + hlo = @code_hlo rand(2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + hlo = @code_hlo rand(MersenneTwister(), 2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + hlo = @code_hlo rand(MersenneTwister(), Float64, (2, 3)) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + hlo = @code_hlo rand(MersenneTwister(), Float64) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + hlo = @code_hlo rand(MersenneTwister()) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + fn(x) = begin + rng = MersenneTwister() + Random.rand!(rng, x) + return x + end + hlo = @code_hlo fn(Reactant.to_rarray(rand(Float64, 2, 3))) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + # XXX: This crashes with Unreachable reached at 0x7a64877b6b16 + # fn2() = begin + # rng = MersenneTwister() + # x = zeros(Float64, 2, 3) + # Random.rand!(rng, x) + # return x + # end + # hlo = @code_hlo fn2() + # @test !contains(repr(hlo), "stablehlo.rng_bit_generator") +end + +@testset "Random123" begin end # Next we test that the random number generators actually generate data from the correct # distributions diff --git a/test/nn/lux.jl b/test/nn/lux.jl index 49fa37f52..7916ce10f 100644 --- a/test/nn/lux.jl +++ b/test/nn/lux.jl @@ -8,7 +8,7 @@ end function gradient_loss_function(model, x, y, ps, st) dps = Enzyme.make_zero(ps) _, res = Enzyme.autodiff( - ReverseWithPrimal, + set_runtime_activity(ReverseWithPrimal), loss_function, Active, Const(model), From 867324e6de9618b8cd6eb06b7e00a6c4a27d5bec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 21:08:03 +0530 Subject: [PATCH 19/22] fix: test whether we can call into the non-overlayed version --- src/Overlay.jl | 37 +++++++++++++++++++++---------------- src/stdlibs/Random.jl | 11 +---------- test/integration/random.jl | 35 +++++++++++++++++++++++++---------- 3 files changed, 47 insertions(+), 36 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 1bbb3cbe2..15cb5ba74 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -47,8 +47,13 @@ for randfun in (:rand, :randn, :randexp) @eval begin @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dims::Dims - ) where {T <: ReactantPrimitive} - return TracedRandom.$(overload_randfun)(rng, T, dims) + ) where {T} + if T <: ReactantPrimitive + return TracedRandom.$(overload_randfun)(rng, T, dims) + end + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Random.$(randfun)(rng, T, dims) end @reactant_overlay @noinline function Random.$(randfun)( @@ -59,15 +64,25 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... - ) where {T <: ReactantPrimitive} - return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) + ) where {T} + if T <: ReactantPrimitive + return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) + end + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Random.$(randfun)(rng, T, dim1, dims...) end # scalars @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}=Float64 - ) where {T <: ReactantPrimitive} - return TracedRandom.$(overload_randfun)(rng, T) + ) where {T} + if T <: ReactantPrimitive + return TracedRandom.$(overload_randfun)(rng, T) + end + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Random.$(randfun)(rng, T) end # inplace @@ -76,15 +91,5 @@ for randfun in (:rand, :randn, :randexp) ) return TracedRandom.$(overload_randfun!)(rng, A) end - - # warn about direct writing to arrays - @reactant_overlay @noinline function Random.$(randfun!)( - rng::AbstractRNG, A::AbstractArray - ) - @warn "Directly writing to an array using Random.jl functions inside \ - ReactantInterpreter will generate a constant array in the IR. Use with \ - caution." maxlog = 1 - return Random.$(randfun!)(rng, A) - end end end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index fcc8a8d5e..dd2eb82fb 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -16,13 +16,7 @@ using ..Reactant: ConcreteRArray using Random: Random, AbstractRNG -function make_seed(rng::AbstractRNG = Random.RandomDevice()) - seed_uint64 = Array{UInt64}(undef, 2) - sampler = Random.Sampler(rng, UInt64, Val(1)) - seed_uint64[1] = rand(rng, sampler) - seed_uint64[2] = rand(rng, sampler) - return seed_uint64 -end +make_seed(rng::AbstractRNG=Random.RandomDevice()) = rand(rng, UInt64, 2) function Random.seed!(rng::TracedRNG, seed::Number) if seed isa TracedRNumber @@ -133,9 +127,6 @@ for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!) internal_overload_randfun = Symbol(:internal_overload_, randfun) @eval begin function $(overload_randfun)(rng::AbstractRNG, args...) - # XXX: Ideally the following should just work but currently it gives an illegal - # instruction error. Maybe an issue with Julia's AbsInt? - # seed_uint64 = rand(rng, UInt64, 2) rng = TracedRNG( TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed(rng)), rng_algorithm(rng), diff --git a/test/integration/random.jl b/test/integration/random.jl index 128b05954..3c1759617 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -32,18 +32,33 @@ using StatsBase, Statistics, HypothesisTests, Distributions hlo = @code_hlo fn(Reactant.to_rarray(rand(Float64, 2, 3))) @test contains(repr(hlo), "stablehlo.rng_bit_generator") - # XXX: This crashes with Unreachable reached at 0x7a64877b6b16 - # fn2() = begin - # rng = MersenneTwister() - # x = zeros(Float64, 2, 3) - # Random.rand!(rng, x) - # return x - # end - # hlo = @code_hlo fn2() - # @test !contains(repr(hlo), "stablehlo.rng_bit_generator") + fn2() = begin + rng = MersenneTwister() + x = zeros(Float64, 2, 3) + Random.rand!(rng, x) + return x + end + hlo = @code_hlo fn2() + @test !contains(repr(hlo), "stablehlo.rng_bit_generator") end -@testset "Random123" begin end +@testset "Random123" begin + hlo = @code_hlo rand(Random123.Threefry4x(), Float32, 2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + @test contains(repr(hlo), "THREE_FRY") + + hlo = @code_hlo rand(Random123.Threefry2x(), Float64, 2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + @test contains(repr(hlo), "THREE_FRY") + + hlo = @code_hlo rand(Random123.Philox4x(), Float64, 2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + @test contains(repr(hlo), "PHILOX") + + hlo = @code_hlo rand(Random123.Philox2x(), Float64, 2, 3) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + @test contains(repr(hlo), "PHILOX") +end # Next we test that the random number generators actually generate data from the correct # distributions From 66b16c4b40500717da4d9bf51643e8976fc02d14 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Dec 2024 23:35:28 +0530 Subject: [PATCH 20/22] fix: try marking TracedRandom in whitelist --- src/stdlibs/Random.jl | 48 ++++++++++++++++++++++++++++--------------- src/utils.jl | 7 ++++--- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index dd2eb82fb..e11904f9b 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -16,7 +16,9 @@ using ..Reactant: ConcreteRArray using Random: Random, AbstractRNG -make_seed(rng::AbstractRNG=Random.RandomDevice()) = rand(rng, UInt64, 2) +@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice()) + return Base.@invoke rand(rng::AbstractRNG, UInt64, 2) +end function Random.seed!(rng::TracedRNG, seed::Number) if seed isa TracedRNumber @@ -48,18 +50,20 @@ function Random.seed!( return rng end -TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) -TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") +@noinline TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) +@noinline TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") -function default_rng() +@noinline function default_rng() Reactant.within_reactant_interpreter() || return TracedRNG() return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") end -rng_algorithm(rng::TracedRNG) = rng.algorithm -rng_algorithm(::AbstractRNG) = "DEFAULT" +@noinline rng_algorithm(rng::TracedRNG) = rng.algorithm +@noinline rng_algorithm(::AbstractRNG) = "DEFAULT" -function internal_overload_rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} +@noinline function internal_overload_rand!( + rng::TracedRNG, A::AnyTracedRArray{T,N} +) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) rng.seed = res.output_state @@ -67,7 +71,9 @@ function internal_overload_rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where return A end -function internal_overload_randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} +@noinline function internal_overload_randn!( + rng::TracedRNG, A::AnyTracedRArray{T,N} +) where {T,N} length(A) == 0 && return A res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm) rng.seed = res.output_state @@ -75,7 +81,9 @@ function internal_overload_randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where return A end -function internal_overload_randexp!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} +@noinline function internal_overload_randexp!( + rng::TracedRNG, A::AnyTracedRArray{T,N} +) where {T,N} length(A) == 0 && return A res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm) rng.seed = res.output_state @@ -89,30 +97,36 @@ for randfun in (:rand, :randn, :randexp) overload_randfun! = Symbol(:internal_overload_, randfun!) @eval begin - function $(overload_randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} + @noinline function $(overload_randfun)( + rng::TracedRNG, ::Type{T}, dims::Dims + ) where {T} return $(overload_randfun!)( rng, TracedRArray{T,length(dims)}((), nothing, dims) ) end - function $(overload_randfun)(rng::TracedRNG, dims::Dims) + @noinline function $(overload_randfun)(rng::TracedRNG, dims::Dims) return $(overload_randfun)(rng, Float64, dims) end - function $(overload_randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...) + @noinline function $(overload_randfun)( + rng::TracedRNG, dim1::Integer, dims::Integer... + ) return $(overload_randfun)(rng, Dims((dim1, dims...))) end - function $(overload_randfun)( + @noinline function $(overload_randfun)( rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer... ) where {T} return $(overload_randfun)(rng, T, Dims((dim1, dims...))) end - $(overload_randfun!)(A::AnyTracedRArray) = $(overload_randfun!)(default_rng(), A) + @noinline function $(overload_randfun!)(A::AnyTracedRArray) + return $(overload_randfun!)(default_rng(), A) + end # scalars - function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} + @noinline function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} A = TracedUtils.promote_to(TracedRArray{T,0}, fill(T(0))) $(overload_randfun!)(rng, A) return TracedRNumber{T}((), A.mlir_data) @@ -126,7 +140,7 @@ for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!) overload_randfun = Symbol(:overload_, randfun) internal_overload_randfun = Symbol(:internal_overload_, randfun) @eval begin - function $(overload_randfun)(rng::AbstractRNG, args...) + @noinline function $(overload_randfun)(rng::AbstractRNG, args...) rng = TracedRNG( TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed(rng)), rng_algorithm(rng), @@ -134,7 +148,7 @@ for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!) return $(internal_overload_randfun)(rng, args...) end - function $(overload_randfun)(rng::TracedRNG, args...) + @noinline function $(overload_randfun)(rng::TracedRNG, args...) return $(internal_overload_randfun)(rng, args...) end end diff --git a/src/utils.jl b/src/utils.jl index 16b784d58..b8eb02849 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -99,7 +99,8 @@ function should_rewrite_ft(@nospecialize(ft)) # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions if has_ancestor(mod, Reactant.Ops) || has_ancestor(mod, Reactant.TracedUtils) || - has_ancestor(mod, Reactant.MLIR) + has_ancestor(mod, Reactant.MLIR) || + has_ancestor(mod, Reactant.TracedRandom) return false end end @@ -305,7 +306,7 @@ function call_with_reactant_generator( overdubbed_codelocs = Int32[] # No method could be found (including in our method table), bail with an error - if lookup_result == nothing + if lookup_result === nothing return stub(world, source, method_error) end @@ -501,7 +502,7 @@ function call_with_reactant_generator( # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). - # Opaque closures also require takign the function argument. We can work around the latter + # Opaque closures also require taking the function argument. We can work around the latter # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure oc = if false && Base.issingletontype(args[1]) res = Core._call_in_world_total( From e370200bfdab4bb6505386916953c82aa973ed92 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Dec 2024 08:43:07 +0530 Subject: [PATCH 21/22] fix: workaround the AbsInt issues for now --- src/Overlay.jl | 10 ++++++++++ src/stdlibs/Random.jl | 7 ++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 15cb5ba74..f566101ff 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -91,5 +91,15 @@ for randfun in (:rand, :randn, :randexp) ) return TracedRandom.$(overload_randfun!)(rng, A) end + + # XXX: Uncomment once AbsInt issues with recursive calls are resolved + # @reactant_overlay @noinline function Random.$(randfun!)( + # rng::AbstractRNG, A::AbstractArray + # ) + # @warn "Directly writing to an array using Random.jl functions inside \ + # ReactantInterpreter will generate a constant array in the IR. Use with \ + # caution." maxlog = 1 + # return Random.$(randfun!)(rng, A) + # end end end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index e11904f9b..271b78f80 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -17,7 +17,12 @@ using ..Reactant: using Random: Random, AbstractRNG @noinline function make_seed(rng::AbstractRNG=Random.RandomDevice()) - return Base.@invoke rand(rng::AbstractRNG, UInt64, 2) + # XXX: We should really be able to call this here. But with our AbsInt it leads to a + # segfault. So we'll just call it in the rand! method. + # return rand(rng, UInt64, 2) + seed = Array{UInt64}(undef, 2) + Random.rand!(rng, seed) + return seed end function Random.seed!(rng::TracedRNG, seed::Number) From eaa1f303c4101721490d2c59554666a13a5db0cb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Dec 2024 08:50:52 +0530 Subject: [PATCH 22/22] fix: throw errors for now instead of crashing --- src/Overlay.jl | 30 +++++++++++++++++++++--------- test/integration/random.jl | 6 ++++-- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index f566101ff..b9785b7fa 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -51,9 +51,13 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dims) end - @warn "Reactant doesn't support sampling of $(T) with the current \ - interpreter. Falling back to native interpreter." maxlog = 1 - return Random.$(randfun)(rng, T, dims) + return error( + "Reactant doesn't support sampling of $(T) with the current interpreter." + ) + # XXX: The following will lead to illegal instruction + # @warn "Reactant doesn't support sampling of $(T) with the current \ + # interpreter. Falling back to native interpreter." maxlog = 1 + # return Random.$(randfun)(rng, T, dims) end @reactant_overlay @noinline function Random.$(randfun)( @@ -68,9 +72,13 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end - @warn "Reactant doesn't support sampling of $(T) with the current \ - interpreter. Falling back to native interpreter." maxlog = 1 - return Random.$(randfun)(rng, T, dim1, dims...) + return error( + "Reactant doesn't support sampling of $(T) with the current interpreter." + ) + # XXX: The following will lead to illegal instruction + # @warn "Reactant doesn't support sampling of $(T) with the current \ + # interpreter. Falling back to native interpreter." maxlog = 1 + # return Random.$(randfun)(rng, T, dim1, dims...) end # scalars @@ -80,9 +88,13 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T) end - @warn "Reactant doesn't support sampling of $(T) with the current \ - interpreter. Falling back to native interpreter." maxlog = 1 - return Random.$(randfun)(rng, T) + return error( + "Reactant doesn't support sampling of $(T) with the current interpreter." + ) + # XXX: The following will lead to illegal instruction + # @warn "Reactant doesn't support sampling of $(T) with the current \ + # interpreter. Falling back to native interpreter." maxlog = 1 + # return Random.$(randfun)(rng, T) end # inplace diff --git a/test/integration/random.jl b/test/integration/random.jl index 3c1759617..275e0e244 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -25,7 +25,8 @@ using StatsBase, Statistics, HypothesisTests, Distributions @test contains(repr(hlo), "stablehlo.rng_bit_generator") fn(x) = begin - rng = MersenneTwister() + # XXX: MersenneTwister without seed leads to illegal instructions + rng = MersenneTwister(0) Random.rand!(rng, x) return x end @@ -33,7 +34,8 @@ using StatsBase, Statistics, HypothesisTests, Distributions @test contains(repr(hlo), "stablehlo.rng_bit_generator") fn2() = begin - rng = MersenneTwister() + # XXX: MersenneTwister without seed leads to illegal instructions + rng = MersenneTwister(0) x = zeros(Float64, 2, 3) Random.rand!(rng, x) return x