Skip to content

Commit 94bdb43

Browse files
dkarraschmaleadt
andauthored
Adapt to new LinearAlgebra.generic_*mul! interface (#519)
Co-authored-by: Tim Besard <[email protected]>
1 parent 353d2df commit 94bdb43

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

src/host/linalg.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,10 @@ end
338338

339339

340340
## matrix multiplication
341-
342-
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, a::Number, b::Number) where {T,S,R}
341+
# legacy method
342+
generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) where {T,S,R} =
343+
generic_matmatmul!(C, A, B, MulAddMul(a, b))
344+
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
343345
if size(A,2) != size(B,1)
344346
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
345347
end
@@ -350,8 +352,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
350352
return fill!(C, zero(R))
351353
end
352354

353-
add = MulAddMul(a, b)
354-
355355
gpu_call(C, A, B; name="matmatmul!") do ctx, C, A, B
356356
idx = @linearidx C
357357
assume.(size(C) .> 0)
@@ -372,42 +372,52 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
372372
C
373373
end
374374

375+
@static if VERSION < v"1.12.0-"
375376
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul())
376-
generic_matmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta)
377+
generic_matmatmul!(C, wrap(A, tA), B, _add)
377378
end
378379

379380
function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul())
380-
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
381+
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
382+
end
383+
else
384+
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number)
385+
LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
386+
end
387+
388+
function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number)
389+
LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b))
381390
end
391+
end
382392

383393
if VERSION < v"1.10.0-DEV.1365"
384394
# catch other functions that are called by LinearAlgebra's mul!
385395
function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number)
386-
generic_matmatmul!(C, wrap(A, tA), B, a, b)
396+
generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
387397
end
388398
# disambiguation
389399
function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T}, a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat}
390-
generic_matmatmul!(C, wrap(A, tA), B, a, b)
400+
generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
391401
end
392402

393403
LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul) =
394-
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
404+
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
395405
# disambiguation
396406
LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat} =
397-
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
407+
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
398408

399409
function LinearAlgebra.syrk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul())
400410
if tA == 'T'
401-
LinearAlgebra.generic_matmatmul!(C, 'T', 'N', A, A, _add)
411+
generic_matmatmul!(C, wrap(A, 'T'), A, _add)
402412
else # tA == 'N'
403-
LinearAlgebra.generic_matmatmul!(C, 'N', 'T', A, A, _add)
413+
generic_matmatmul!(C, A, wrap(A, 'T'), _add)
404414
end
405415
end
406416
function LinearAlgebra.herk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul())
407417
if tA == 'C'
408-
LinearAlgebra.generic_matmatmul!(C, 'C', 'N', A, A, _add)
418+
generic_matmatmul!(C, wrap(A, 'C'), A, _add)
409419
else # tA == 'N'
410-
LinearAlgebra.generic_matmatmul!(C, 'N', 'C', A, A, _add)
420+
generic_matmatmul!(C, A, wrap(A, 'C'), _add)
411421
end
412422
end
413423
end # VERSION

0 commit comments

Comments
 (0)