From 2230b412e4b2e7187f6515a602a63ff1d968a2a5 Mon Sep 17 00:00:00 2001 From: matinraayai <30674652+matinraayai@users.noreply.github.com> Date: Mon, 22 Jan 2024 13:06:56 -0500 Subject: [PATCH 1/3] Preliminary fix for Base.unsafe_warp --- src/array.jl | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/array.jl b/src/array.jl index b20402eb9..0971238be 100644 --- a/src/array.jl +++ b/src/array.jl @@ -189,15 +189,48 @@ end # TODO docs function Base.unsafe_wrap( - ::Type{<:ROCArray}, ptr::Ptr{T}, dims::NTuple{N, <:Integer}; - lock::Bool = true, + ::Union{Type{ROCArray},Type{ROCArray{T}},Type{ROCArray{T,1}}}, ptr::Ptr{T}, dims::NTuple{N, <:Integer}; + own::Bool = false, ) where {T,N} @assert isbitstype(T) "Cannot wrap a non-bitstype pointer as a ROCArray" + buf = _unsafe_wrap_managed(ptr, dims) + ROCArray{T, N}(DataRef(own ? _free_buffer : (args...) -> (), buf), dims) +end + +function Base.unsafe_wrap(::Type{CuArray{T,N,B}}, ptr::Ptr{T}, dims::NTuple{N,<:Integer}; + own::Bool=false, +) where {T,N,B} + buf = _unsafe_wrap_identify_buffer(ptr, dims) + if typeof(buf) !== B + throw(ArgumentError("Declared buffer type does not match inferred buffer type.")) + end + ROCArray{T,N}(DataRef(own ? _free_buffer : (args...) -> (), buf), dims) +end + +function _unsafe_wrap_identify_buffer(ptr::Ptr{T}, dims::NTuple{N, <:Integer}) where {T, N} + (status, attrs) = Mem.attributes(ptr) + if status != HIP.hipSuccess # The pointer is unknown to the HIP runtime + error("The HIP runtime could not identify the attributes of the ptr $ptr.") + end + device = HIPDevice(attrs.device + 1) + context = HIPContext(device) sz = prod(dims) * sizeof(T) - buf = lock ? - Mem.HostBuffer(Ptr{Cvoid}(ptr), sz) : - Mem.HIPBuffer(Ptr{Cvoid}(ptr), sz) - ROCArray{T, N}(DataRef(_free_buf, buf), dims) + if attrs.memoryType == HIP.hipMemoryTypeHost + return Mem.HostBuffer(device, context, attrs.hostPointer, attrs.devicePointer, sz, own) + elseif attrs.memoryType == hipMemoryTypeDevice + return Mem.HIPBuffer(device, context, ptr, sz, own) + else + error("Memory type $(attrs.memoryType) is not supported by AMDGPU.jl") + end +end + +# integer size input +function Base.unsafe_wrap(::Union{Type{ROCArray},Type{ROCArray{T}},Type{ROCArray{T,1}}}, + p::Ptr{T}, dim::Int) where {T} + unsafe_wrap(ROCArray{T,1}, p, (dim,)) +end +function Base.unsafe_wrap(::Type{ROCArray{T,1,B}}, p::Ptr{T}, dim::Int) where {T,B} + unsafe_wrap(ROCArray{T,1,B}, p, (dim,)) end Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T = From 2617f7c1f6c40990ce3ad9494ca781c91fb97d73 Mon Sep 17 00:00:00 2001 From: matinraayai <30674652+matinraayai@users.noreply.github.com> Date: Mon, 22 Jan 2024 13:12:27 -0500 Subject: [PATCH 2/3] Added missing own keyword argument + removed CuArray --- src/array.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/array.jl b/src/array.jl index 0971238be..586d54789 100644 --- a/src/array.jl +++ b/src/array.jl @@ -197,7 +197,7 @@ function Base.unsafe_wrap( ROCArray{T, N}(DataRef(own ? _free_buffer : (args...) -> (), buf), dims) end -function Base.unsafe_wrap(::Type{CuArray{T,N,B}}, ptr::Ptr{T}, dims::NTuple{N,<:Integer}; +function Base.unsafe_wrap(::Type{ROCArray{T,N,B}}, ptr::Ptr{T}, dims::NTuple{N,<:Integer}; own::Bool=false, ) where {T,N,B} buf = _unsafe_wrap_identify_buffer(ptr, dims) @@ -226,11 +226,11 @@ end # integer size input function Base.unsafe_wrap(::Union{Type{ROCArray},Type{ROCArray{T}},Type{ROCArray{T,1}}}, - p::Ptr{T}, dim::Int) where {T} - unsafe_wrap(ROCArray{T,1}, p, (dim,)) + p::Ptr{T}, dim::Int; own=false) where {T} + unsafe_wrap(ROCArray{T,1}, p, (dim,); own=own) end -function Base.unsafe_wrap(::Type{ROCArray{T,1,B}}, p::Ptr{T}, dim::Int) where {T,B} - unsafe_wrap(ROCArray{T,1,B}, p, (dim,)) +function Base.unsafe_wrap(::Type{ROCArray{T,1,B}}, p::Ptr{T}, dim::Int; own=false) where {T,B} + unsafe_wrap(ROCArray{T,1,B}, p, (dim,); own=own) end Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T = From 2cd01057e70f38def0786ce4777cbb4d0c10f5f4 Mon Sep 17 00:00:00 2001 From: matinraayai <30674652+matinraayai@users.noreply.github.com> Date: Mon, 22 Jan 2024 13:50:33 -0500 Subject: [PATCH 3/3] Bug fix. --- src/array.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/array.jl b/src/array.jl index 586d54789..455065afa 100644 --- a/src/array.jl +++ b/src/array.jl @@ -193,21 +193,21 @@ function Base.unsafe_wrap( own::Bool = false, ) where {T,N} @assert isbitstype(T) "Cannot wrap a non-bitstype pointer as a ROCArray" - buf = _unsafe_wrap_managed(ptr, dims) + buf = _unsafe_wrap_identify_buffer(ptr, dims, own) ROCArray{T, N}(DataRef(own ? _free_buffer : (args...) -> (), buf), dims) end function Base.unsafe_wrap(::Type{ROCArray{T,N,B}}, ptr::Ptr{T}, dims::NTuple{N,<:Integer}; own::Bool=false, ) where {T,N,B} - buf = _unsafe_wrap_identify_buffer(ptr, dims) + buf = _unsafe_wrap_identify_buffer(ptr, dims, own) if typeof(buf) !== B throw(ArgumentError("Declared buffer type does not match inferred buffer type.")) end ROCArray{T,N}(DataRef(own ? _free_buffer : (args...) -> (), buf), dims) end -function _unsafe_wrap_identify_buffer(ptr::Ptr{T}, dims::NTuple{N, <:Integer}) where {T, N} +function _unsafe_wrap_identify_buffer(ptr::Ptr{T}, dims::NTuple{N, <:Integer}, own::Bool) where {T, N} (status, attrs) = Mem.attributes(ptr) if status != HIP.hipSuccess # The pointer is unknown to the HIP runtime error("The HIP runtime could not identify the attributes of the ptr $ptr.") @@ -217,7 +217,7 @@ function _unsafe_wrap_identify_buffer(ptr::Ptr{T}, dims::NTuple{N, <:Integer}) w sz = prod(dims) * sizeof(T) if attrs.memoryType == HIP.hipMemoryTypeHost return Mem.HostBuffer(device, context, attrs.hostPointer, attrs.devicePointer, sz, own) - elseif attrs.memoryType == hipMemoryTypeDevice + elseif attrs.memoryType == HIP.hipMemoryTypeDevice return Mem.HIPBuffer(device, context, ptr, sz, own) else error("Memory type $(attrs.memoryType) is not supported by AMDGPU.jl") @@ -233,7 +233,7 @@ function Base.unsafe_wrap(::Type{ROCArray{T,1,B}}, p::Ptr{T}, dim::Int; own=fals unsafe_wrap(ROCArray{T,1,B}, p, (dim,); own=own) end -Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T = +Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims::NTuple{N,<:Integer}; kwargs...) where {T, N} = unsafe_wrap(ROCArray, Base.unsafe_convert(Ptr{T}, ptr), dims; kwargs...) ## interop with CPU arrays