From 416b28f11e0f4b509e6efca1a06babcd4997d45b Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 11 Apr 2025 17:00:48 -0300 Subject: [PATCH] Faster matmul --- src/host/linalg.jl | 95 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 88 insertions(+), 7 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 2c928747..b59598f6 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -325,11 +325,92 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat, B end +# XXX: figure out how to do dynamically +MAX_TILE_DIM = 16 ## matrix multiplication # legacy method generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) = generic_matmatmul!(C, A, B, MulAddMul(a, b)) +function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number} + N = size(A,1) + Q = size(A,2) + M = size(B,2) + if Q != size(B,1) + throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != N || size(C,2) != M + throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + + @kernel unsafe_indices=true function coalesced_matmul_kernel!( + output, @Const(input1), @Const(input2), N, Q, M, + ::Val{BANK} = Val(1), + ) where {BANK} + grow, gcol = @index(Group, NTuple) + tile_row, tile_col = @index(Local, NTuple) + + TILE_DIM = @uniform @groupsize()[1] + + # +1 to avoid bank conflicts on shared memory + tile1 = @localmem(R, (TILE_DIM + BANK, TILE_DIM)) + tile2 = @localmem(R, (TILE_DIM + BANK, TILE_DIM)) + + # private variable for tile output + outval = @private R 1 + @inbounds outval[1] = -zero(R) + + # number of tiles depends on inner dimension + @uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM) + + # loop over all tiles needed for this calculation + for t in 0:(NUM_TILES - 1) + I = (grow - 1) * TILE_DIM + tile_row + J = (gcol - 1) * TILE_DIM + tile_col + + # load inputs into tiles, with bounds checking for non-square matrices + if I <= N && t * TILE_DIM + tile_col <= Q + @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col] + else + @inbounds tile1[tile_row, tile_col] = zero(R) + end + if J <= M && t * TILE_DIM + tile_row <= Q + @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J] + else + @inbounds tile2[tile_row, tile_col] = zero(R) + end + + # wait for all tiles to be loaded + @synchronize + + I = (grow - 1) * TILE_DIM + tile_row + J = (gcol - 1) * TILE_DIM + tile_col + + # calculate value of spot in output, use temporary value to allow for vectorization + out = zero(R) + @simd for k in 1:TILE_DIM + @inbounds out += tile1[tile_row, k] * tile2[k, tile_col] + end + outval[1] += out + + @synchronize + end + + I = (grow - 1) * TILE_DIM + tile_row + J = (gcol - 1) * TILE_DIM + tile_col + + # save if inbounds + if I <= N && J <= M + @inbounds output[I, J] = add(outval[1], output[I, J]) + end + end + + coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C))) + C +end function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R} if size(A,2) != size(B,1) throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) @@ -744,7 +825,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2}, @kernel function kron_kernel!(z, @Const(x), @Const(y)) i, j = @index(Global, NTuple) - + @inbounds z[(i - 1) * length(y) + j] = x[i] * y[j] end @@ -777,13 +858,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in ta = $transa(T1) tb = $transb(T2) - + @kernel function kron_kernel!(C, @Const(A), @Const(B)) ai, aj = @index(Global, NTuple) # Indices in the result matrix - + # lb1, lb2 = size(B) # Dimensions of B lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B)) - + # Map global indices (ai, aj) to submatrices of the Kronecker product i_a = (ai - 1) รท lb1 + 1 # Corresponding row index in A i_b = (ai - 1) % lb1 + 1 # Corresponding row index in B @@ -797,12 +878,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in C[ai, aj] = a_ij * b_ij end end - + backend = KernelAbstractions.get_backend(C) kernel = kron_kernel!(backend) - + kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2))) - + return C end