Skip to content

Commit

Permalink
feat: add dispatch for KA get_backend (#645)
Browse files Browse the repository at this point in the history
* feat: add dispatch for KA get_backend

* fix: missing test dep

* test: run when cuda available
  • Loading branch information
avik-pal authored Jan 28, 2025
1 parent fd60aad commit db2aa15
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 15 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Expand All @@ -40,7 +41,7 @@ ReactantCore = {path = "lib/ReactantCore"}
[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantCUDAExt = ["CUDA", "KernelAbstractions"]
ReactantNNlibExt = "NNlib"
ReactantOffsetArraysExt = "OffsetArrays"
ReactantPythonCallExt = "PythonCall"
Expand All @@ -60,6 +61,7 @@ Enzyme = "0.13.28"
EnzymeCore = "0.8.8"
Functors = "0.5"
GPUArraysCore = "0.1.6, 0.2"
KernelAbstractions = "0.9.30"
LinearAlgebra = "1.10"
NNlib = "0.9.26"
OffsetArrays = "1"
Expand Down
3 changes: 3 additions & 0 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ module ReactantCUDAExt
using CUDA
using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
using ReactantCore: @trace
using KernelAbstractions: KernelAbstractions
using Libdl

using Adapt

# Dispatch KernelAbstractions' backend query for Reactant traced arrays.
# NOTE(review): this lives in the CUDA extension and unconditionally reports
# CUDABackend() for every AnyTracedRArray — assumes traced arrays are always
# lowered to CUDA when this extension is loaded; confirm for other targets.
KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend()

struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
ptr::Core.LLVMPtr{T,A}

Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Expand Down Expand Up @@ -43,6 +44,7 @@ Flux = "0.15, 0.16"
Functors = "0.5"
HypothesisTests = "0.11"
InteractiveUtils = "1.10"
KernelAbstractions = "0.9.30"
LinearAlgebra = "1.10"
Lux = "1.4.1"
LuxLib = "1.3"
Expand Down
65 changes: 51 additions & 14 deletions test/integration/kernelabstractions.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
using CUDA
using KernelAbstractions
using Reactant

using CUDA: CuArray
using CUDA, KernelAbstractions, Reactant

# Simple kernel for matrix multiplication
@kernel function matmul_kernel!(output, a)
Expand All @@ -11,23 +7,64 @@ using CUDA: CuArray

tmp_sum = zero(eltype(output))
for k in 1:size(a)[2]
tmp_sum += a[i, k] * a[k, j]
@inbounds tmp_sum += a[i, k] * a[k, j]
end

return output[i, j] = tmp_sum
@inbounds output[i, j] = tmp_sum
end

# Creating a wrapper kernel for launching with error checks
function matmul!(output, a, backend)
# Launch `matmul_kernel!` over the full `size(output)` range on whichever
# KernelAbstractions backend hosts `output`, then block until the backend
# reports completion. Mutates `output` in place.
function matmul!(output, a)
    dev = KernelAbstractions.get_backend(output)
    launch! = matmul_kernel!(dev)
    launch!(output, a; ndrange=size(output))
    return KernelAbstractions.synchronize(dev)
end

@testset "KernelAbstractions Call" begin
backend = KernelAbstractions.get_backend(CuArray(ones(1)))
A = Reactant.to_rarray(CuArray(ones(100, 100)))
out = Reactant.to_rarray(CuArray(ones(100, 100)))
@jit matmul!(out, A, backend)
@test all(Array(out) .≈ 100)
# Non-CUDA runs currently cannot exercise these kernels; tracked upstream:
# https://github.com/EnzymeAD/Reactant.jl/issues/614
const skip_non_cuda_tests = true

# Skipped entirely on macOS at parse time (no CUDA toolchain there).
@static if !Sys.isapple()
    @testset "KernelAbstractions Matmul" begin
        A = Reactant.to_rarray(ones(100, 100))
        out = Reactant.to_rarray(ones(100, 100))
        if CUDA.functional()
            # ones(100,100) * ones(100,100) is uniformly 100; currently marked
            # broken under @jit compilation (see issue 614 above).
            @test all(Array(@jit(matmul!(out, A))) .≈ 100) broken = true
        else
            @static if skip_non_cuda_tests
                # Record the skipped case as broken rather than silently passing.
                @test false broken = true
            else
                # Without a functional GPU, at least check the kernel traces to HLO.
                @code_hlo optimize = :before_kernel matmul!(out, A)
            end
        end
    end
end

# Element-wise square kernel: writes y[i] = x[i] * x[i] for each global index.
# `@Const` marks `x` as read-only for the backend compiler.
@kernel function square_kernel!(y, @Const(x))
    # One work-item per element; linear global index.
    i = @index(Global)
    # `@inbounds` relies on the launcher using ndrange = length(x) and
    # `y = similar(x)` (same length) — see `square` below.
    @inbounds y[i] = x[i] * x[i]
end

# Allocate an output array shaped like `x`, launch `square_kernel!` over every
# element on the backend that hosts `x`, and return the (unsynchronized) result.
function square(x)
    out = similar(x)
    dev = KernelAbstractions.get_backend(x)
    launch! = square_kernel!(dev)
    launch!(out, x; ndrange=length(x))
    return out
end

# Skipped entirely on macOS at parse time (no CUDA toolchain there).
@static if !Sys.isapple()
    @testset "KernelAbstractions Square" begin
        # 64 evenly spaced values in (0, 1]; elementwise square compared with ≈.
        x = Reactant.to_rarray(collect(1:1:64) ./ 64)
        if CUDA.functional()
            @test all(Array(@jit(square(x))) .≈ Array(x) .* Array(x))
        else
            @static if skip_non_cuda_tests
                # Record the skipped case as broken rather than silently passing
                # (https://github.com/EnzymeAD/Reactant.jl/issues/614).
                @test false broken = true
            else
                # Without a functional GPU, at least check the kernel traces to HLO.
                @code_hlo optimize = :before_kernel square(x)
            end
        end
    end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "CUDA" include("integration/cuda.jl")
@safetestset "KernelAbstractions" include("integration/kernelabstractions.jl")
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "OffsetArrays" include("integration/offsetarrays.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
Expand Down

0 comments on commit db2aa15

Please sign in to comment.