Skip to content

Commit

Permalink
corrected rand! for mutable static arrays and moved them to an extension
Browse files Browse the repository at this point in the history
  • Loading branch information
schrimpf committed Oct 3, 2023
1 parent c34cf44 commit 13741c2
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 90 deletions.
17 changes: 11 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,28 +1,33 @@
name = "VectorizedRNG"
uuid = "33b4df10-0173-11e9-2a0c-851a7edac40e"
authors = ["Chris Elrod <[email protected]>"]
version = "0.2.24"
version = "0.2.25"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"

[weakdeps]
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
VectorizedRNGStaticArraysExt = ["StaticArraysCore"]

[compat]
Requires = "1"
SLEEFPirates = "0.6.29"
StaticArraysCore = "1"
StrideArraysCore = "0.3, 0.4"
UnPack = "1"
VectorizationBase = "0.19.38, 0.20.1, 0.21"
julia = "1.6"

[extras]
RNGTest = "97cc5700-e6cb-5ca1-8fb2-7f6b45264ecd"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "RNGTest"]
test = ["Test", "RNGTest", "StaticArrays"]
55 changes: 55 additions & 0 deletions ext/VectorizedRNGStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
module VectorizedRNGStaticArraysExt


using StaticArraysCore
if isdefined(Base, :get_extension)
using VectorizedRNG: samplevector!, random_uniform, random_normal, AbstractVRNG, random_unsigned
else
using ..VectorizedRNG: samplevector!, random_uniform, random_normal, AbstractVRNG, random_unsigned
end
using VectorizationBase: StaticInt
import Random

function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T},
α::Number = StaticInt{0}(),
β = StaticInt{0}(),
γ = StaticInt{1}()
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_uniform, rng, x, α, β, γ, identity)
end
return x
end

function Random.randn!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T},
α::Number = StaticInt{0}(),
β = StaticInt{0}(),
γ = StaticInt{1}()
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_normal, rng, x, α, β, γ, identity)
end
return x
end

function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,UInt64}
)
samplevector!(
random_unsigned,
rng,
x,
StaticInt{0}(),
StaticInt{0}(),
StaticInt{1}(),
identity
)
end


end
9 changes: 9 additions & 0 deletions src/VectorizedRNG.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ include("api.jl")
include("special_approximations.jl")
include("xoshiro.jl")
# const GLOBAL_vPCGs = Ref{Ptr{UInt64}}()
if !isdefined(Base, :get_extension)
include("../ext/VectorizedRNGStaticArraysExt.jl")
end


const GLOBAL_vRNGs = Ref{Ptr{UInt64}}()

Expand All @@ -119,6 +123,11 @@ function __init()
end
end
function __init__()
@static if !isdefined(Base, :get_extension)
@require StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin
include("../ext/VectorizedRNGStaticArraysExt.jl")
end
end
ccall(:jl_generating_output, Cint, ()) == 1 && return
__init()
end
Expand Down
83 changes: 2 additions & 81 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,77 +375,16 @@ function Random.randn!(
) where {T<:Union{Float32,Float64}}
samplevector!(random_normal, rng, x, α, β, γ, identity)
end

@inline function random_unsigned(
state::AbstractState,
::Val{N},
::Type{T}
) where {N,T}
nextstate(state, Val{N}())
end
function Random.rand!(rng::AbstractVRNG, x::AbstractArray{UInt64})
samplevector!(
random_unsigned,
rng,
x,
StaticInt{0}(),
StaticInt{0}(),
StaticInt{1}(),
identity
)
end

using StaticArraysCore, StrideArraysCore
function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T}
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_uniform, rng, PtrArray(x), α, β, γ, identity)
end
return x
end
function Random.rand!(
rng::AbstractVRNG,
x::SA
) where {
S<:Tuple,
T<:Union{Float32,Float64},
SA<:StaticArraysCore.StaticArray{S,T}
}
a = MArray{S,UInt64}(undef)
GC.@preserve a begin
samplevector!(random_uniform, rng, PtrArray(a), α, β, γ, identity)
end
x .= a
end
function Random.randn!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T}
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_normal, rng, PtrArray(x), α, β, γ, identity)
end
return x
end
function Random.randn!(
rng::AbstractVRNG,
x::SA
) where {
S<:Tuple,
T<:Union{Float32,Float64},
SA<:StaticArraysCore.StaticArray{S,T}
}
a = MArray{S,UInt64}(undef)
GC.@preserve a begin
samplevector!(random_normal, rng, PtrArray(a), α, β, γ, identity)
end
x .= a
end

function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,UInt64}
)
function Random.rand!(rng::AbstractVRNG, x::AbstractArray{UInt64})
samplevector!(
random_unsigned,
rng,
Expand All @@ -456,24 +395,6 @@ function Random.rand!(
identity
)
end
function Random.rand!(
rng::AbstractVRNG,
x::SA
) where {S<:Tuple,SA<:StaticArraysCore.StaticArray{S,UInt64}}
a = MArray{S,UInt64}(undef)
GC.@preserve a begin
samplevector!(
random_unsigned,
rng,
PtrArray(a),
StaticInt{0}(),
StaticInt{0}(),
StaticInt{1}(),
identity
)
end
x .= a
end

Random.rand(rng::AbstractVRNG, d1::Integer, dims::Vararg{Integer,N}) where {N} =
rand!(rng, Array{Float64}(undef, d1, dims...))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
RNGTest = "97cc5700-e6cb-5ca1-8fb2-7f6b45264ecd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
31 changes: 28 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Test
using InteractiveUtils: versioninfo
versioninfo(; verbose = true)

using RNGTest, Random, SpecialFunctions, Aqua, Distributions
using RNGTest, Random, SpecialFunctions, Aqua, Distributions, StaticArrays

const α = 1e-4

Expand Down Expand Up @@ -171,11 +171,36 @@ end
vrng = local_rng()
σ = 0.5
for i = 1:N
randn!(vrng, x, VectorizedRNG.static(0), VectorizedRNG.static(0), σ)
randn!(vrng, x, VectorizedRNG.StaticInt(0), VectorizedRNG.StaticInt(0), σ)
s += std(x)
end
s /= N
@test s σ rtol = 1e-1
end
end
end

@testset "StaticArrays" begin
seed = 1234
rng = local_rng()
for T in (Float32, Float64, UInt64, Int)
for dim in ((10),(10,10), (10,10,10))
A = zeros(T, dim)
mA = MArray{Tuple{dim...}}(A)
VectorizedRNG.seed!(seed)
rand!(rng, A)
VectorizedRNG.seed!(seed)
rand!(rng, mA)
@test all(A .== mA)

if T <: AbstractFloat
VectorizedRNG.seed!(seed)
randn!(rng, A)
VectorizedRNG.seed!(seed)
randn!(rng, mA)
@test all(A .== mA)
end
end
end
end

end

0 comments on commit 13741c2

Please sign in to comment.