Skip to content

Commit 031d7b9

Browse files
authored
Remove kron methods and use those in GPUArrays (#2643)
1 parent deb38b1 commit 031d7b9

File tree

1 file changed

+0
-86
lines changed

1 file changed

+0
-86
lines changed

lib/cublas/linalg.jl

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -396,89 +396,3 @@ for op in (:(+), :(-))
396396
@eval Base.$op(A::$TypeA, B::$TypeB) where {T <: CublasFloat} = geam($transa(T), $transb(T), one(T), $(unwrapa(:A)), $(op)(one(T)), $(unwrapb(:B)))
397397
end
398398
end
399-
400-
# Kronecker product
401-
function LinearAlgebra.kron!(C::CuMatrix{TC}, A::CuMatrix{TA}, B::CuMatrix{TB}) where {TA,TB,TC}
402-
403-
function _kron_mat_kernelA!(C, A, B, m, n, p, q)
404-
index_i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
405-
index_j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
406-
407-
stride_i = blockDim().x * gridDim().x
408-
stride_j = blockDim().y * gridDim().y
409-
410-
index_i > m && return
411-
index_j > n && return
412-
413-
for i in index_i:stride_i:m
414-
for j in index_j:stride_j:n
415-
for k in 1:p
416-
for l in 1:q
417-
@inbounds C[(i-1)*p+k, (j-1)*q+l] = A[i,j] * B[k,l]
418-
end
419-
end
420-
end
421-
end
422-
return nothing
423-
end
424-
425-
function _kron_mat_kernelB!(C, A, B, m, n, p, q)
426-
index_p = (blockIdx().x - 1) * blockDim().x + threadIdx().x
427-
index_q = (blockIdx().y - 1) * blockDim().y + threadIdx().y
428-
429-
stride_p = blockDim().x * gridDim().x
430-
stride_q = blockDim().y * gridDim().y
431-
432-
index_p > p && return
433-
index_q > q && return
434-
435-
for i in 1:m
436-
for j in 1:n
437-
for k in index_p:stride_p:p
438-
for l in index_q:stride_q:q
439-
@inbounds C[(i-1)*p+k, (j-1)*q+l] = A[i,j] * B[k,l]
440-
end
441-
end
442-
end
443-
end
444-
return nothing
445-
end
446-
447-
m, n = size(A)
448-
p, q = size(B)
449-
450-
# Use different kernels depending on the size of the matrices
451-
# choosing to parallelize the matrix with the largest number of elements
452-
m*n >= p*q ? (kernel = @cuda launch=false _kron_mat_kernelA!(C, A, B, m, n, p, q)) :
453-
(kernel = @cuda launch=false _kron_mat_kernelB!(C, A, B, m, n, p, q))
454-
455-
m*n >= p*q ? (sizes = (m, n)) : (sizes = (p, q))
456-
457-
config = launch_configuration(kernel.fun)
458-
dim_ratio = sizes[1] / sizes[2]
459-
max_threads_i = max(1, floor(Int, sqrt(config.threads * dim_ratio)))
460-
max_threads_j = max(1, floor(Int, sqrt(config.threads / dim_ratio)))
461-
max_blocks_i = max(1, floor(Int, sqrt(config.blocks * dim_ratio)))
462-
max_blocks_j = max(1, floor(Int, sqrt(config.blocks / dim_ratio)))
463-
464-
threads_i = min(sizes[1], max_threads_i)
465-
threads_j = min(sizes[2], max_threads_j)
466-
threads = (threads_i, threads_j)
467-
blocks_i = min(cld(sizes[1], threads_i), max_blocks_i)
468-
blocks_j = min(cld(sizes[2], threads_j), max_blocks_j)
469-
blocks = (blocks_i, blocks_j)
470-
471-
kernel(C, A, B, m, n, p, q; threads=threads, blocks=blocks)
472-
473-
return C
474-
end
475-
476-
function LinearAlgebra.kron(A::CuMatrix{TA}, B::CuMatrix{TB}) where {TA,TB}
477-
m, n = size(A)
478-
p, q = size(B)
479-
480-
T = promote_type(TA, TB)
481-
C = similar(A, T, m*p, n*q)
482-
483-
kron!(C, A, B)
484-
end

0 commit comments

Comments
 (0)