diff --git a/src/batched/batchedadjtrans.jl b/src/batched/batchedadjtrans.jl index 130f039bc..9423a6fb6 100644 --- a/src/batched/batchedadjtrans.jl +++ b/src/batched/batchedadjtrans.jl @@ -87,6 +87,7 @@ function Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Inte Base.stride(A.parent, d) end +Base.pointer(A::BatchedAdjOrTrans) = pointer(parent(A)) Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} = Base.unsafe_convert(Ptr{T}, parent(A)) diff --git a/src/gemm.jl b/src/gemm.jl index 051508750..3552f3124 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -80,9 +80,9 @@ for (gemm, elt) in gemm_datatype_mappings LinearAlgebra.BLAS.chkstride1(B) LinearAlgebra.BLAS.chkstride1(C) - ptrA = Base.unsafe_convert(Ptr{$elt}, A) - ptrB = Base.unsafe_convert(Ptr{$elt}, B) - ptrC = Base.unsafe_convert(Ptr{$elt}, C) + ptrA = pointer(A) + ptrB = pointer(B) + ptrC = pointer(C) strA = size(A, 3) == 1 ? 0 : Base.stride(A, 3) strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3)