diff --git a/src/array.jl b/src/array.jl index b20402eb9..455065afa 100644 --- a/src/array.jl +++ b/src/array.jl @@ -189,18 +189,51 @@ end # TODO docs function Base.unsafe_wrap( - ::Type{<:ROCArray}, ptr::Ptr{T}, dims::NTuple{N, <:Integer}; - lock::Bool = true, + ::Union{Type{ROCArray},Type{ROCArray{T}},Type{ROCArray{T,1}}}, ptr::Ptr{T}, dims::NTuple{N, <:Integer}; + own::Bool = false, ) where {T,N} @assert isbitstype(T) "Cannot wrap a non-bitstype pointer as a ROCArray" + buf = _unsafe_wrap_identify_buffer(ptr, dims, own) + ROCArray{T, N}(DataRef(own ? _free_buffer : (args...) -> (), buf), dims) +end + +function Base.unsafe_wrap(::Type{ROCArray{T,N,B}}, ptr::Ptr{T}, dims::NTuple{N,<:Integer}; + own::Bool=false, +) where {T,N,B} + buf = _unsafe_wrap_identify_buffer(ptr, dims, own) + if typeof(buf) !== B + throw(ArgumentError("Declared buffer type does not match inferred buffer type.")) + end + ROCArray{T,N}(DataRef(own ? _free_buffer : (args...) -> (), buf), dims) +end + +function _unsafe_wrap_identify_buffer(ptr::Ptr{T}, dims::NTuple{N, <:Integer}, own::Bool) where {T, N} + (status, attrs) = Mem.attributes(ptr) + if status != HIP.hipSuccess # The pointer is unknown to the HIP runtime + error("The HIP runtime could not identify the attributes of the ptr $ptr.") + end + device = HIPDevice(attrs.device + 1) + context = HIPContext(device) sz = prod(dims) * sizeof(T) - buf = lock ? - Mem.HostBuffer(Ptr{Cvoid}(ptr), sz) : - Mem.HIPBuffer(Ptr{Cvoid}(ptr), sz) - ROCArray{T, N}(DataRef(_free_buf, buf), dims) + if attrs.memoryType == HIP.hipMemoryTypeHost + return Mem.HostBuffer(device, context, attrs.hostPointer, attrs.devicePointer, sz, own) + elseif attrs.memoryType == HIP.hipMemoryTypeDevice + return Mem.HIPBuffer(device, context, ptr, sz, own) + else + error("Memory type $(attrs.memoryType) is not supported by AMDGPU.jl") + end +end + +# integer size input +function Base.unsafe_wrap(::Union{Type{ROCArray},Type{ROCArray{T}},Type{ROCArray{T,1}}}, + p::Ptr{T}, dim::Int; own=false) where {T} + unsafe_wrap(ROCArray{T,1}, p, (dim,); own=own) +end +function Base.unsafe_wrap(::Type{ROCArray{T,1,B}}, p::Ptr{T}, dim::Int; own=false) where {T,B} + unsafe_wrap(ROCArray{T,1,B}, p, (dim,); own=own) end -Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T = +Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims::NTuple{N,<:Integer}; kwargs...) where {T, N} = unsafe_wrap(ROCArray, Base.unsafe_convert(Ptr{T}, ptr), dims; kwargs...) ## interop with CPU arrays