From db2aa157062c78672310330b9def606942591f93 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Jan 2025 11:37:44 -0500 Subject: [PATCH] feat: add dispatch for KA get_backend (#645) * feat: add dispatch for KA get_backend * fix: missing test dep * test: run when cuda available --- Project.toml | 4 +- ext/ReactantCUDAExt.jl | 3 ++ test/Project.toml | 2 + test/integration/kernelabstractions.jl | 65 ++++++++++++++++++++------ test/runtests.jl | 1 + 5 files changed, 60 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 4d01781ce..d6f5d06a7 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" @@ -40,7 +41,7 @@ ReactantCore = {path = "lib/ReactantCore"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" -ReactantCUDAExt = "CUDA" +ReactantCUDAExt = ["CUDA", "KernelAbstractions"] ReactantNNlibExt = "NNlib" ReactantOffsetArraysExt = "OffsetArrays" ReactantPythonCallExt = "PythonCall" @@ -60,6 +61,7 @@ Enzyme = "0.13.28" EnzymeCore = "0.8.8" Functors = "0.5" GPUArraysCore = "0.1.6, 0.2" +KernelAbstractions = "0.9.30" LinearAlgebra = "1.10" NNlib = "0.9.26" OffsetArrays = "1" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 6feaf2f2d..6d863fc53 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -3,10 +3,13 @@ module ReactantCUDAExt using CUDA using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber using ReactantCore: @trace +using KernelAbstractions: KernelAbstractions using Libdl using Adapt +KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend() + struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} ptr::Core.LLVMPtr{T,A} diff --git a/test/Project.toml b/test/Project.toml index cbdd32c90..a0ecb42ce 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" @@ -43,6 +44,7 @@ Flux = "0.15, 0.16" Functors = "0.5" HypothesisTests = "0.11" InteractiveUtils = "1.10" +KernelAbstractions = "0.9.30" LinearAlgebra = "1.10" Lux = "1.4.1" LuxLib = "1.3" diff --git a/test/integration/kernelabstractions.jl b/test/integration/kernelabstractions.jl index af6c0e80b..7fa3bd6d9 100644 --- a/test/integration/kernelabstractions.jl +++ b/test/integration/kernelabstractions.jl @@ -1,8 +1,4 @@ -using CUDA -using KernelAbstractions -using Reactant - -using CUDA: CuArray +using CUDA, KernelAbstractions, Reactant # Simple kernel for matrix multiplication @kernel function matmul_kernel!(output, a) @@ -11,23 +7,64 @@ using CUDA: CuArray tmp_sum = zero(eltype(output)) for k in 1:size(a)[2] - tmp_sum += a[i, k] * a[k, j] + @inbounds tmp_sum += a[i, k] * a[k, j] end - return output[i, j] = tmp_sum + @inbounds output[i, j] = tmp_sum end # Creating a wrapper kernel for launching with error checks -function matmul!(output, a, backend) +function matmul!(output, a) + backend = KernelAbstractions.get_backend(output) kernel! = matmul_kernel!(backend) kernel!(output, a; ndrange=size(output)) return KernelAbstractions.synchronize(backend) end -@testset "KernelAbstractions Call" begin - backend = KernelAbstractions.get_backend(CuArray(ones(1))) - A = Reactant.to_rarray(CuArray(ones(100, 100))) - out = Reactant.to_rarray(CuArray(ones(100, 100))) - @jit matmul!(out, A, backend) - @test all(Array(out) .≈ 100) +# https://github.com/EnzymeAD/Reactant.jl/issues/614 +const skip_non_cuda_tests = true + +@static if !Sys.isapple() + @testset "KernelAbstractions Matmul" begin + A = Reactant.to_rarray(ones(100, 100)) + out = Reactant.to_rarray(ones(100, 100)) + if CUDA.functional() + @test all(Array(@jit(matmul!(out, A))) .≈ 100) broken = true + else + @static if skip_non_cuda_tests + @test false broken = true + else + @code_hlo optimize = :before_kernel matmul!(out, A) + end + end + end +end + +# simple square kernel +@kernel function square_kernel!(y, @Const(x)) + i = @index(Global) + @inbounds y[i] = x[i] * x[i] +end + +function square(x) + y = similar(x) + backend = KernelAbstractions.get_backend(x) + kernel! = square_kernel!(backend) + kernel!(y, x; ndrange=length(x)) + return y +end + +@static if !Sys.isapple() + @testset "KernelAbstractions Square" begin + x = Reactant.to_rarray(collect(1:1:64) ./ 64) + if CUDA.functional() + @test all(Array(@jit(square(x))) .≈ Array(x) .* Array(x)) + else + @static if skip_non_cuda_tests + @test false broken = true + else + @code_hlo optimize = :before_kernel square(x) + end + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index a65a301d9..6866fed21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,6 +65,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @safetestset "CUDA" include("integration/cuda.jl") + @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "OffsetArrays" include("integration/offsetarrays.jl") @safetestset "AbstractFFTs" include("integration/fft.jl")