Skip to content

Commit 63aa464

Browse files
committed
Revert "Faster matmul"
This reverts commit 416b28f.
1 parent 1b7728d commit 63aa464

File tree

1 file changed

+7
-88
lines changed

1 file changed

+7
-88
lines changed

src/host/linalg.jl

Lines changed: 7 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -325,92 +325,11 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
325325
B
326326
end
327327

328-
# XXX: figure out how to do dynamically
329-
# Edge length of the square workgroup used to launch the tiled matmul kernel.
# Declared `const` so the global has a concrete type wherever it is read.
const MAX_TILE_DIM = 16
330328

331329
## matrix multiplication
332330
# legacy method
333331
# Accepts the two scalars `a` and `b` separately and forwards to the
# `MulAddMul`-based method by wrapping them as `MulAddMul(a, b)`.
generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) =
334332
generic_matmatmul!(C, A, B, MulAddMul(a, b))
335-
# Tiled matrix multiplication for GPU matrices.
#
# Accumulates the product of `A` (N x Q) and `B` (Q x M) tile by tile in a
# KernelAbstractions kernel and stores `add(product, C[i, j])` into `C`.
#
# Arguments:
#   C   - output matrix; must be N x M
#   A,B - input matrices; size(A, 2) must equal size(B, 1)
#   add - MulAddMul callable applied as `add(accumulated, existing_output)`
#
# Throws `DimensionMismatch` when the inner dimensions of `A` and `B` differ or
# when `C` is not N x M. Returns `C` (or the result of `fill!` on the empty path).
function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number}
336-
N = size(A,1)
337-
Q = size(A,2)
338-
M = size(B,2)
339-
if Q != size(B,1)
340-
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
341-
end
342-
if size(C,1) != N || size(C,2) != M
343-
throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))"))
344-
end
345-
# NOTE(review): the empty-input path zero-fills C without consulting `add`;
# confirm that ignoring the existing contents of C here is intended.
if isempty(A) || isempty(B)
346-
return fill!(C, zero(R))
347-
end
348-
349-
# Kernel: each work-item produces one entry of `output`. The trailing `Val{BANK}`
# argument (default `Val(1)`) sets the padding added to the tile's first dimension.
@kernel unsafe_indices=true function coalesced_matmul_kernel!(
350-
output, @Const(input1), @Const(input2), N, Q, M,
351-
::Val{BANK} = Val(1),
352-
) where {BANK}
353-
# workgroup index and index of this work-item within its workgroup
grow, gcol = @index(Group, NTuple)
354-
tile_row, tile_col = @index(Local, NTuple)
355-
356-
# tile edge length is taken from the first dimension of the workgroup size
TILE_DIM = @uniform @groupsize()[1]
357-
358-
# pad the first tile dimension by BANK (default 1) to avoid bank conflicts on shared memory
359-
tile1 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))
360-
tile2 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))
361-
362-
# private variable for tile output
# NOTE(review): the accumulator starts at `-zero(R)` rather than `zero(R)`;
# presumably this is about signed-zero handling in the sum, confirm.
363-
outval = @private R 1
364-
@inbounds outval[1] = -zero(R)
365-
366-
# number of tiles depends on inner dimension (ceiling of Q / TILE_DIM)
367-
@uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM)
368-
369-
# loop over all tiles needed for this calculation
370-
for t in 0:(NUM_TILES - 1)
371-
# row/column of the output entry handled by this work-item
I = (grow - 1) * TILE_DIM + tile_row
372-
J = (gcol - 1) * TILE_DIM + tile_col
373-
374-
# load inputs into tiles, with bounds checking for non-square matrices;
# positions outside A or B are filled with zero(R)
375-
if I <= N && t * TILE_DIM + tile_col <= Q
376-
@inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
377-
else
378-
@inbounds tile1[tile_row, tile_col] = zero(R)
379-
end
380-
if J <= M && t * TILE_DIM + tile_row <= Q
381-
@inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
382-
else
383-
@inbounds tile2[tile_row, tile_col] = zero(R)
384-
end
385-
386-
# wait for all tiles to be loaded
387-
@synchronize
388-
389-
# NOTE(review): I and J are recomputed after the barrier, presumably because
# ordinary locals are not guaranteed to survive `@synchronize` on every
# backend; confirm before removing the apparent duplication.
I = (grow - 1) * TILE_DIM + tile_row
390-
J = (gcol - 1) * TILE_DIM + tile_col
391-
392-
# calculate value of spot in output, use temporary value to allow for vectorization
393-
out = zero(R)
394-
@simd for k in 1:TILE_DIM
395-
@inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
396-
end
397-
outval[1] += out
398-
399-
# second barrier precedes the next iteration's writes to tile1/tile2
@synchronize
400-
end
401-
402-
# recomputed once more after the final barrier (same expressions as in the loop)
I = (grow - 1) * TILE_DIM + tile_row
403-
J = (gcol - 1) * TILE_DIM + tile_col
404-
405-
# save if inbounds; `add` combines the accumulated value with the existing output entry
406-
if I <= N && J <= M
407-
@inbounds output[I, J] = add(outval[1], output[I, J])
408-
end
409-
end
410-
411-
# Launch with a MAX_TILE_DIM x MAX_TILE_DIM workgroup; each dimension of size(C)
# is rounded up to a multiple of MAX_TILE_DIM to form the ndrange.
coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C)))
412-
C
413-
end
414333
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
415334
if size(A,2) != size(B,1)
416335
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
@@ -825,7 +744,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
825744

826745
# Kernel for the Kronecker product of two vectors: work-item (i, j) writes
# x[i] * y[j] into z at linear position (i - 1) * length(y) + j.
# NOTE(review): the launch site is not visible here; this assumes the ndrange
# is (length(x), length(y)) so the `@inbounds` accesses stay valid. Confirm.
@kernel function kron_kernel!(z, @Const(x), @Const(y))
827746
i, j = @index(Global, NTuple)
828-
747+
829748
@inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
830749
end
831750

@@ -858,13 +777,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
858777

859778
ta = $transa(T1)
860779
tb = $transb(T2)
861-
780+
862781
@kernel function kron_kernel!(C, @Const(A), @Const(B))
863782
ai, aj = @index(Global, NTuple) # Indices in the result matrix
864-
783+
865784
# lb1, lb2 = size(B) # Dimensions of B
866785
lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B))
867-
786+
868787
# Map global indices (ai, aj) to submatrices of the Kronecker product
869788
i_a = (ai - 1) ÷ lb1 + 1 # Corresponding row index in A
870789
i_b = (ai - 1) % lb1 + 1 # Corresponding row index in B
@@ -878,12 +797,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
878797
C[ai, aj] = a_ij * b_ij
879798
end
880799
end
881-
800+
882801
backend = KernelAbstractions.get_backend(C)
883802
kernel = kron_kernel!(backend)
884-
803+
885804
kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2)))
886-
805+
887806
return C
888807
end
889808

0 commit comments

Comments
 (0)