Skip to content

Commit

Permalink
Try to fix tests using JLArray
Browse files Browse the repository at this point in the history
  • Loading branch information
jipolanco committed Jun 28, 2024
1 parent e7e6dbc commit 8734c80
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
Compat = "4.5"
HDF5 = "0.17"
JLArrays = "0.1.2"
JLArrays = "0.1.5"
OrdinaryDiffEq = ">= 6.45"
37 changes: 29 additions & 8 deletions test/array_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,23 @@ using Test

## ================================================================================ ##

using JLArrays: JLArray, DenseJLVector, DataRef
using JLArrays: JLArray, DenseJLVector, JLVector, DataRef

# A bit of type piracy to help tests pass (the following functions seem to be defined for
# CuArray).
function Base.resize!(u::DenseJLVector, n)
T = eltype(u)
obj = u.data.rc.obj :: Vector{UInt8}
u.dims = (n,)
resize!(obj, n * sizeof(T))
@assert sizeof(obj) === sizeof(u)
u

# This is a modified version of the resize! function defined in JLArrays.jl 0.1.5, which
# avoids freeing memory that will be used in the future.
function Base.resize!(a::DenseJLVector{T}, nl::Integer) where {T}
a_resized = JLVector{T}(undef, nl)
copyto!(a_resized, 1, a, 1, min(length(a), nl))
finalize(a) # free previous memory
a.data = copy(a_resized.data) # this simply increments the reference count by 1
a.offset = 0
a.dims = size(a_resized)
return a
end

function Base.unsafe_wrap(::Type{JLArray}, p::Ptr, dims::Dims; kws...)
T = eltype(p)
N = length(dims)
Expand All @@ -33,8 +38,10 @@ function Base.unsafe_wrap(::Type{JLArray}, p::Ptr, dims::Dims; kws...)
@assert pointer(x) === p
x
end

Base.unsafe_wrap(::Type{JLArray}, p::Ptr, n::Integer; kws...) =
unsafe_wrap(JLArray, p, (n,); kws...)

# Random.rand!(rng::AbstractRNG, u::JLArray, ::Type{X}) where {X} = (rand!(rng, u.data, X); u)

# For some reason this kind of view doesn't work correctly in the original implementation,
Expand Down Expand Up @@ -124,16 +131,30 @@ MPI.Comm_rank(comm) == 0 || redirect_stdout(devnull)
@assert inv(perm) != perm
end
py = @inferred Pencil(px; decomp_dims = (2,), permute = perm)
@test px.send_buf === py.send_buf
@test permutation(py) == perm
@test @inferred(typeof_array(px)) === A
@test @inferred(typeof_array(py)) === A

if A === JLArray
GC.gc()
@test px.send_buf.data.rc.count[] == 1
end

@testset "Transpositions" begin
ux = @test_nowarn rand!(rng, PencilArray{Float64}(undef, px))
uy = @inferred similar(ux, py)
@test pencil(uy) === py
tr = @inferred Transpositions.Transposition(uy, ux)
if A === JLArray
GC.gc()
@test px.send_buf.data.rc.count[] == 1
end
transpose!(tr)
if A === JLArray
GC.gc()
@test px.send_buf.data.rc.count[] == 1
end
@test_logs (:warn, r"is deprecated") MPI.Waitall!(tr)

# Verify transposition
Expand Down

0 comments on commit 8734c80

Please sign in to comment.