338
338
339
339
340
340
# # matrix multiplication
341
-
342
- function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , a:: Number , b:: Number ) where {T,S,R}
341
+ # legacy method
342
+ generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) where {T,S,R} =
343
+ generic_matmatmul! (C, A, B, MulAddMul (a, b))
344
+ function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
343
345
if size (A,2 ) != size (B,1 )
344
346
throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
345
347
end
@@ -350,8 +352,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
350
352
return fill! (C, zero (R))
351
353
end
352
354
353
- add = MulAddMul (a, b)
354
-
355
355
gpu_call (C, A, B; name= " matmatmul!" ) do ctx, C, A, B
356
356
idx = @linearidx C
357
357
assume .(size (C) .> 0 )
@@ -372,42 +372,52 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
372
372
C
373
373
end
374
374
375
+ @static if VERSION < v " 1.12.0-"
375
376
function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , _add:: MulAddMul = MulAddMul ())
376
- generic_matmatmul! (C, wrap (A, tA), B, _add. alpha, _add . beta )
377
+ generic_matmatmul! (C, wrap (A, tA), B, _add)
377
378
end
378
379
379
380
function LinearAlgebra. generic_matmatmul! (C:: AbstractGPUVecOrMat , tA, tB, A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
380
- generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add. alpha, _add. beta)
381
+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
382
+ end
383
+ else
384
+ function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , a:: Number , b:: Number )
385
+ LinearAlgebra. @stable_muladdmul generic_matmatmul! (C, wrap (A, tA), B, MulAddMul (a, b))
386
+ end
387
+
388
+ function LinearAlgebra. generic_matmatmul! (C:: AbstractGPUVecOrMat , tA, tB, A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , a:: Number , b:: Number )
389
+ LinearAlgebra. @stable_muladdmul generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (a, b))
381
390
end
391
+ end
382
392
383
393
if VERSION < v " 1.10.0-DEV.1365"
384
394
# catch other functions that are called by LinearAlgebra's mul!
385
395
function LinearAlgebra. gemv! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , a:: Number , b:: Number )
386
- generic_matmatmul! (C, wrap (A, tA), B, a, b)
396
+ generic_matmatmul! (C, wrap (A, tA), B, MulAddMul ( a, b) )
387
397
end
388
398
# disambiguation
389
399
function LinearAlgebra. gemv! (C:: AbstractGPUVector{T} , tA:: AbstractChar , A:: AbstractGPUMatrix{T} , B:: AbstractGPUVector{T} , a:: Number , b:: Number ) where {T<: LinearAlgebra.BlasFloat }
390
- generic_matmatmul! (C, wrap (A, tA), B, a, b)
400
+ generic_matmatmul! (C, wrap (A, tA), B, MulAddMul ( a, b) )
391
401
end
392
402
393
403
LinearAlgebra. gemm_wrapper! (C:: AbstractGPUVecOrMat , tA:: AbstractChar , tB:: AbstractChar , A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , _add:: MulAddMul ) =
394
- LinearAlgebra . generic_matmatmul! (C, tA, tB, A, B , _add)
404
+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
395
405
# disambiguation
396
406
LinearAlgebra. gemm_wrapper! (C:: AbstractGPUVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar , A:: AbstractGPUVecOrMat{T} , B:: AbstractGPUVecOrMat{T} , _add:: MulAddMul ) where {T<: LinearAlgebra.BlasFloat } =
397
- LinearAlgebra . generic_matmatmul! (C, tA, tB, A, B , _add)
407
+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
398
408
399
409
function LinearAlgebra. syrk_wrapper! (C:: AbstractGPUMatrix , tA:: AbstractChar , A:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
400
410
if tA == ' T'
401
- LinearAlgebra . generic_matmatmul! (C, ' T ' , ' N ' , A , A, _add)
411
+ generic_matmatmul! (C, wrap (A , ' T ' ) , A, _add)
402
412
else # tA == 'N'
403
- LinearAlgebra . generic_matmatmul! (C, ' N ' , ' T ' , A, A , _add)
413
+ generic_matmatmul! (C, A, wrap ( A, ' T ' ) , _add)
404
414
end
405
415
end
406
416
function LinearAlgebra. herk_wrapper! (C:: AbstractGPUMatrix , tA:: AbstractChar , A:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
407
417
if tA == ' C'
408
- LinearAlgebra . generic_matmatmul! (C, ' C ' , ' N ' , A , A, _add)
418
+ generic_matmatmul! (C, wrap (A , ' C ' ) , A, _add)
409
419
else # tA == 'N'
410
- LinearAlgebra . generic_matmatmul! (C, ' N ' , ' C ' , A, A , _add)
420
+ generic_matmatmul! (C, A, wrap ( A, ' C ' ) , _add)
411
421
end
412
422
end
413
423
end # VERSION
0 commit comments