@@ -325,92 +325,11 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
325
325
B
326
326
end
327
327
328
- # XXX : figure out how to do dynamically
329
- MAX_TILE_DIM = 16
330
328
331
329
# # matrix multiplication
332
330
# legacy method
333
331
generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) =
334
332
generic_matmatmul! (C, A, B, MulAddMul (a, b))
335
- function generic_matmatmul! (C:: AbstractGPUMatrix{R} , A:: AbstractGPUMatrix{T} , B:: AbstractGPUMatrix{S} , add:: MulAddMul ) where {T<: Number ,S<: Number ,R<: Number }
336
- N = size (A,1 )
337
- Q = size (A,2 )
338
- M = size (B,2 )
339
- if Q != size (B,1 )
340
- throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
341
- end
342
- if size (C,1 ) != N || size (C,2 ) != M
343
- throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((N,M)) " ))
344
- end
345
- if isempty (A) || isempty (B)
346
- return fill! (C, zero (R))
347
- end
348
-
349
- @kernel unsafe_indices= true function coalesced_matmul_kernel! (
350
- output, @Const (input1), @Const (input2), N, Q, M,
351
- :: Val{BANK} = Val (1 ),
352
- ) where {BANK}
353
- grow, gcol = @index (Group, NTuple)
354
- tile_row, tile_col = @index (Local, NTuple)
355
-
356
- TILE_DIM = @uniform @groupsize ()[1 ]
357
-
358
- # +1 to avoid bank conflicts on shared memory
359
- tile1 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
360
- tile2 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
361
-
362
- # private variable for tile output
363
- outval = @private R 1
364
- @inbounds outval[1 ] = - zero (R)
365
-
366
- # number of tiles depends on inner dimension
367
- @uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
368
-
369
- # loop over all tiles needed for this calculation
370
- for t in 0 : (NUM_TILES - 1 )
371
- I = (grow - 1 ) * TILE_DIM + tile_row
372
- J = (gcol - 1 ) * TILE_DIM + tile_col
373
-
374
- # load inputs into tiles, with bounds checking for non-square matrices
375
- if I <= N && t * TILE_DIM + tile_col <= Q
376
- @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
377
- else
378
- @inbounds tile1[tile_row, tile_col] = zero (R)
379
- end
380
- if J <= M && t * TILE_DIM + tile_row <= Q
381
- @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
382
- else
383
- @inbounds tile2[tile_row, tile_col] = zero (R)
384
- end
385
-
386
- # wait for all tiles to be loaded
387
- @synchronize
388
-
389
- I = (grow - 1 ) * TILE_DIM + tile_row
390
- J = (gcol - 1 ) * TILE_DIM + tile_col
391
-
392
- # calculate value of spot in output, use temporary value to allow for vectorization
393
- out = zero (R)
394
- @simd for k in 1 : TILE_DIM
395
- @inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
396
- end
397
- outval[1 ] += out
398
-
399
- @synchronize
400
- end
401
-
402
- I = (grow - 1 ) * TILE_DIM + tile_row
403
- J = (gcol - 1 ) * TILE_DIM + tile_col
404
-
405
- # save if inbounds
406
- if I <= N && J <= M
407
- @inbounds output[I, J] = add (outval[1 ], output[I, J])
408
- end
409
- end
410
-
411
- coalesced_matmul_kernel! (get_backend (C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange= map (x -> ceil (Int,x/ MAX_TILE_DIM)* MAX_TILE_DIM, size (C)))
412
- C
413
- end
414
333
function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
415
334
if size (A,2 ) != size (B,1 )
416
335
throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
@@ -825,7 +744,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
825
744
826
745
@kernel function kron_kernel! (z, @Const (x), @Const (y))
827
746
i, j = @index (Global, NTuple)
828
-
747
+
829
748
@inbounds z[(i - 1 ) * length (y) + j] = x[i] * y[j]
830
749
end
831
750
@@ -858,13 +777,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
858
777
859
778
ta = $ transa (T1)
860
779
tb = $ transb (T2)
861
-
780
+
862
781
@kernel function kron_kernel! (C, @Const (A), @Const (B))
863
782
ai, aj = @index (Global, NTuple) # Indices in the result matrix
864
-
783
+
865
784
# lb1, lb2 = size(B) # Dimensions of B
866
785
lb1, lb2 = tb == ' N' ? size (B) : reverse (size (B))
867
-
786
+
868
787
# Map global indices (ai, aj) to submatrices of the Kronecker product
869
788
i_a = (ai - 1 ) ÷ lb1 + 1 # Corresponding row index in A
870
789
i_b = (ai - 1 ) % lb1 + 1 # Corresponding row index in B
@@ -878,12 +797,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
878
797
C[ai, aj] = a_ij * b_ij
879
798
end
880
799
end
881
-
800
+
882
801
backend = KernelAbstractions. get_backend (C)
883
802
kernel = kron_kernel! (backend)
884
-
803
+
885
804
kernel (C, $ (unwrapa (:A )), $ (unwrapb (:B )), ndrange= (size (C, 1 ), size (C, 2 )))
886
-
805
+
887
806
return C
888
807
end
889
808
0 commit comments