Skip to content

Commit 895cfa5

Browse files
authored
Added wrapped C cuda code and runable examples (#1)
initial
1 parent 7f35b0b commit 895cfa5

13 files changed

+1508
-23
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
*.jl.cov
33
*.jl.mem
44
/Manifest.toml
5+
lib
6+
.vscode

Artifacts.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[CUDA_lib]
2+
git-tree-sha1 = "2918fba865582556e219191a7f393c47c2e822e0"
3+
4+
[[CUDA_lib.download]]
5+
sha256 = "751bf9d1f2921d4176ffb8ed1ddbd59bb60d6a517e6784bb71d61b62357c0007"
6+
url = "https://gist.github.com/ArrogantGao/c38791f143d36d4b2481ac7e4aa4ecce/raw/2918fba865582556e219191a7f393c47c2e822e0.tar.gz"

LICENSE

Lines changed: 674 additions & 21 deletions
Large diffs are not rendered by default.

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@ uuid = "c2b282c3-c9c2-431d-80f7-a1a0561ebe55"
33
authors = ["Xuanzhao Gao <[email protected]> and contributors"]
44
version = "1.0.0-DEV"
55

6+
[deps]
7+
ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce"
8+
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
9+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
10+
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
11+
612
[compat]
713
julia = "1"
814

benchmark/benchmark_CUDA_mapreduce.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using TropicalNumbers, CUDA, BenchmarkTools
2+
3+
function map_reduce_benchmark(m::T, n::T, k::T) where{T}
4+
A = Tropical.(CUDA.randn(Float32, (m, k)))
5+
B = Tropical.(CUDA.randn(Float32, (k, n)))
6+
C = Tropical.(CUDA.randn(Float32, (k, n)))
7+
8+
elapsed_time = @belapsed CUDA.@sync begin
9+
$C = $A * $B
10+
end
11+
12+
work_load = 2 * m * n * k
13+
flops = work_load / elapsed_time / 1e9
14+
@show m, n, k, elapsed_time, flops
15+
return nothing
16+
end
17+
18+
map_reduce_benchmark(2560, 2048, 2048)
19+
map_reduce_benchmark(2 * 2560, 2 * 2048, 2 * 2048)
20+
map_reduce_benchmark(4 * 2560, 4 * 2048, 4 * 2048)

benchmark/benchmark_CuTropicalGemm.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using CUDA
2+
using BenchmarkTools
3+
using CuTropicalGEMM
4+
5+
function benchmakr_CuTropicalGemmFP32(m::T, n::T, k::T) where{T}
6+
A = rand(Float32, (m * k))
7+
B = rand(Float32, (k * n))
8+
C = rand(Float32, (m * n))
9+
10+
CuA = CuArray(A)
11+
CuB = CuArray(B)
12+
CuC = CuArray(C)
13+
14+
# I found hat @belapsed and CUDA.@sync can not properly benchmark the function from .so lib
15+
elapsed_time = @belapsed CUDA.@sync begin
16+
1 + 1
17+
CuTropicalGemmMatmulFP32!($m, $n, $k, $CuA, $CuB, $CuC)
18+
1 + 1
19+
end
20+
21+
work_load = 2 * m * n * k
22+
flops = work_load / elapsed_time / 1e9
23+
@show m, n, k, elapsed_time, flops
24+
return nothing
25+
end
26+
27+
benchmakr_CuTropicalGemmFP32(2560, 2048, 2048)
28+
benchmakr_CuTropicalGemmFP32(2 * 2560, 2 * 2048, 2 * 2048)
29+
benchmakr_CuTropicalGemmFP32(4 * 2560, 4 * 2048, 4 * 2048)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#! /bin/bash
2+
3+
nvcc -arch=sm_80 ../src/TropicalSGemmFP32.cu
4+
./a.out
5+
6+
rm a.out
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# notice: this code is used to benchmark the Tropical matmul in GemmKernels.jl, which is not yet released in the latest version and only supported in [email protected]
2+
# to run the code, you need to manually download the latest version repo of GemmKernels.jl and activate the enviroment
3+
4+
import CUDA
5+
import InteractiveUtils
6+
7+
using CUDA
8+
using GemmKernels
9+
using LinearAlgebra
10+
using BenchmarkTools
11+
using Test
12+
13+
CUDA.allowscalar(false)
14+
15+
function try_tropical(M, N, K)
16+
for (A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128)],
17+
transpose_a = [true, false],
18+
transpose_b = [true, false],
19+
(OP_M, OP_N, OP_K) in [(8, 16, 2)]
20+
21+
a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
22+
b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
23+
c_h = rand(CD_type, (M, N))
24+
d_h = similar(c_h)
25+
26+
27+
# Transpose input if necessary
28+
a_h = transpose_a ? transpose(a_h) : a_h
29+
b_h = transpose_b ? transpose(b_h) : b_h
30+
31+
a = CuArray(a_h)
32+
b = CuArray(b_h)
33+
c = CuArray(c_h)
34+
d = similar(c)
35+
36+
conf = GemmKernels.get_config(
37+
gemm_shape = (M = M, N = N, K = K),
38+
block_shape = (M = 64, N = 64, K = 32),
39+
operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, CD_type, A_type},
40+
global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
41+
global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},
42+
43+
global_c_layout = Layout.AlignedColMajor{CD_type},
44+
global_d_layout = Layout.AlignedColMajor{CD_type},
45+
46+
is_a_col_major = !transpose_a,
47+
is_b_col_major = !transpose_b,
48+
)
49+
50+
n_iter = 1
51+
elapsed_time = @belapsed CUDA.@sync begin
52+
GemmKernels.matmul($a, $b, $c, $d, $conf; kernel = Kernel.matmul_pipelined)
53+
end
54+
TFlops = (n_iter * M * N * K * 2 / elapsed_time) / 1e9
55+
@show TFlops, elapsed_time, transpose_a, transpose_b, M, N, K
56+
57+
58+
d_c = Array(d)
59+
60+
# random 1600 points took to test
61+
if transpose_a == transpose_b == false
62+
@testset begin
63+
for _ in 1 : 40
64+
for _ in 1 : 40
65+
i = rand(1:M)
66+
j = rand(1:N)
67+
d_h[i, j] = c_h[i, j]
68+
for k in 1 : K
69+
d_h[i, j] = max(a_h[i, k] + b_h[k, j], d_h[i, j])
70+
end
71+
@test isapprox(d_h[i, j], d_c[i, j]; rtol = sqrt(eps(A_type)))
72+
end
73+
end
74+
end
75+
end
76+
end
77+
return nothing
78+
end
79+
80+
81+
try_tropical(2560, 2048, 2048)
82+
try_tropical(2 * 2560, 2 * 2048, 2 * 2048)
83+
try_tropical(4 * 2560, 4 * 2048, 4 * 2048)

src/CuTropicalGEMM.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
module CuTropicalGEMM
22

3-
# Write your package code here.
3+
export CuTropicalGemmMatmulFP32!
4+
5+
using CUDA
6+
using Artifacts
7+
8+
include("TropicalGemm_Cuda_wrapper.jl")
49

510
end

0 commit comments

Comments
 (0)