From 8734c809dee55c585402da3eb281950818a415ff Mon Sep 17 00:00:00 2001 From: Juan Ignacio Polanco Date: Fri, 28 Jun 2024 12:59:30 +0200 Subject: [PATCH] Try to fix tests using JLArray --- test/Project.toml | 2 +- test/array_types.jl | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 83d31b14..993294d4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,5 +22,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Compat = "4.5" HDF5 = "0.17" -JLArrays = "0.1.2" +JLArrays = "0.1.5" OrdinaryDiffEq = ">= 6.45" diff --git a/test/array_types.jl b/test/array_types.jl index e6e2c362..c9807009 100644 --- a/test/array_types.jl +++ b/test/array_types.jl @@ -10,18 +10,23 @@ using Test ## ================================================================================ ## -using JLArrays: JLArray, DenseJLVector, DataRef +using JLArrays: JLArray, DenseJLVector, JLVector, DataRef # A bit of type piracy to help tests pass (the following functions seem to be defined for # CuArray). -function Base.resize!(u::DenseJLVector, n) - T = eltype(u) - obj = u.data.rc.obj :: Vector{UInt8} - u.dims = (n,) - resize!(obj, n * sizeof(T)) - @assert sizeof(obj) === sizeof(u) - u + +# This is a modified version of the resize! function defined in JLArrays.jl 0.1.5, which +# avoids freeing memory that will be used in the future. +function Base.resize!(a::DenseJLVector{T}, nl::Integer) where {T} + a_resized = JLVector{T}(undef, nl) + copyto!(a_resized, 1, a, 1, min(length(a), nl)) + finalize(a) # free previous memory + a.data = copy(a_resized.data) # this simply increments the reference count by 1 + a.offset = 0 + a.dims = size(a_resized) + return a end + function Base.unsafe_wrap(::Type{JLArray}, p::Ptr, dims::Dims; kws...) T = eltype(p) N = length(dims) @@ -33,8 +38,10 @@ function Base.unsafe_wrap(::Type{JLArray}, p::Ptr, dims::Dims; kws...) @assert pointer(x) === p x end + Base.unsafe_wrap(::Type{JLArray}, p::Ptr, n::Integer; kws...) = unsafe_wrap(JLArray, p, (n,); kws...) + # Random.rand!(rng::AbstractRNG, u::JLArray, ::Type{X}) where {X} = (rand!(rng, u.data, X); u) # For some reason this kind of view doesn't work correctly in the original implementation, @@ -124,16 +131,30 @@ MPI.Comm_rank(comm) == 0 || redirect_stdout(devnull) @assert inv(perm) != perm end py = @inferred Pencil(px; decomp_dims = (2,), permute = perm) + @test px.send_buf === py.send_buf @test permutation(py) == perm @test @inferred(typeof_array(px)) === A @test @inferred(typeof_array(py)) === A + if A === JLArray + GC.gc() + @test px.send_buf.data.rc.count[] == 1 + end + @testset "Transpositions" begin ux = @test_nowarn rand!(rng, PencilArray{Float64}(undef, px)) uy = @inferred similar(ux, py) @test pencil(uy) === py tr = @inferred Transpositions.Transposition(uy, ux) + if A === JLArray + GC.gc() + @test px.send_buf.data.rc.count[] == 1 + end transpose!(tr) + if A === JLArray + GC.gc() + @test px.send_buf.data.rc.count[] == 1 + end @test_logs (:warn, r"is deprecated") MPI.Waitall!(tr) # Verify transposition