@@ -396,89 +396,3 @@ for op in (:(+), :(-))
396
396
@eval Base.$ op (A:: $TypeA , B:: $TypeB ) where {T <: CublasFloat } = geam ($ transa (T), $ transb (T), one (T), $ (unwrapa (:A )), $ (op)(one (T)), $ (unwrapb (:B )))
397
397
end
398
398
end
399
-
400
- # Kronecker product
401
- function LinearAlgebra. kron! (C:: CuMatrix{TC} , A:: CuMatrix{TA} , B:: CuMatrix{TB} ) where {TA,TB,TC}
402
-
403
- function _kron_mat_kernelA! (C, A, B, m, n, p, q)
404
- index_i = (blockIdx (). x - 1 ) * blockDim (). x + threadIdx (). x
405
- index_j = (blockIdx (). y - 1 ) * blockDim (). y + threadIdx (). y
406
-
407
- stride_i = blockDim (). x * gridDim (). x
408
- stride_j = blockDim (). y * gridDim (). y
409
-
410
- index_i > m && return
411
- index_j > n && return
412
-
413
- for i in index_i: stride_i: m
414
- for j in index_j: stride_j: n
415
- for k in 1 : p
416
- for l in 1 : q
417
- @inbounds C[(i- 1 )* p+ k, (j- 1 )* q+ l] = A[i,j] * B[k,l]
418
- end
419
- end
420
- end
421
- end
422
- return nothing
423
- end
424
-
425
- function _kron_mat_kernelB! (C, A, B, m, n, p, q)
426
- index_p = (blockIdx (). x - 1 ) * blockDim (). x + threadIdx (). x
427
- index_q = (blockIdx (). y - 1 ) * blockDim (). y + threadIdx (). y
428
-
429
- stride_p = blockDim (). x * gridDim (). x
430
- stride_q = blockDim (). y * gridDim (). y
431
-
432
- index_p > p && return
433
- index_q > q && return
434
-
435
- for i in 1 : m
436
- for j in 1 : n
437
- for k in index_p: stride_p: p
438
- for l in index_q: stride_q: q
439
- @inbounds C[(i- 1 )* p+ k, (j- 1 )* q+ l] = A[i,j] * B[k,l]
440
- end
441
- end
442
- end
443
- end
444
- return nothing
445
- end
446
-
447
- m, n = size (A)
448
- p, q = size (B)
449
-
450
- # Use different kernels depending on the size of the matrices
451
- # choosing to parallelize the matrix with the largest number of elements
452
- m* n >= p* q ? (kernel = @cuda launch= false _kron_mat_kernelA! (C, A, B, m, n, p, q)) :
453
- (kernel = @cuda launch= false _kron_mat_kernelB! (C, A, B, m, n, p, q))
454
-
455
- m* n >= p* q ? (sizes = (m, n)) : (sizes = (p, q))
456
-
457
- config = launch_configuration (kernel. fun)
458
- dim_ratio = sizes[1 ] / sizes[2 ]
459
- max_threads_i = max (1 , floor (Int, sqrt (config. threads * dim_ratio)))
460
- max_threads_j = max (1 , floor (Int, sqrt (config. threads / dim_ratio)))
461
- max_blocks_i = max (1 , floor (Int, sqrt (config. blocks * dim_ratio)))
462
- max_blocks_j = max (1 , floor (Int, sqrt (config. blocks / dim_ratio)))
463
-
464
- threads_i = min (sizes[1 ], max_threads_i)
465
- threads_j = min (sizes[2 ], max_threads_j)
466
- threads = (threads_i, threads_j)
467
- blocks_i = min (cld (sizes[1 ], threads_i), max_blocks_i)
468
- blocks_j = min (cld (sizes[2 ], threads_j), max_blocks_j)
469
- blocks = (blocks_i, blocks_j)
470
-
471
- kernel (C, A, B, m, n, p, q; threads= threads, blocks= blocks)
472
-
473
- return C
474
- end
475
-
476
- function LinearAlgebra. kron (A:: CuMatrix{TA} , B:: CuMatrix{TB} ) where {TA,TB}
477
- m, n = size (A)
478
- p, q = size (B)
479
-
480
- T = promote_type (TA, TB)
481
- C = similar (A, T, m* p, n* q)
482
-
483
- kron! (C, A, B)
484
- end
0 commit comments