Skip to content

Commit

Permalink
Enzyme: add make_zero of cuarrays (JuliaGPU#2600)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored and THargreaves committed Jan 7, 2025
1 parent 319514a commit 3ff1de0
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
5 changes: 3 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ steps:
build.message !~ /\[only/ && !build.pull_request.draft &&
build.message !~ /\[skip tests\]/ &&
build.message !~ /\[skip downstream\]/
timeout_in_minutes: 30
soft_fail: true
timeout_in_minutes: 60
soft_fail:
- exit_status: 3

- group: ":eyes: Special"
depends_on: "cuda"
Expand Down
53 changes: 53 additions & 0 deletions ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,59 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...)
return nothing
end

# Single-argument `make_zero` for dense CUDA arrays with real floating-point
# elements: delegates to `Base.zero`, so the input `x` itself is not modified.
@inline EnzymeCore.make_zero(x::DenseCuArray{FT}) where {FT<:AbstractFloat} = Base.zero(x)
# Single-argument `make_zero` for dense CUDA arrays whose elements are complex
# numbers over a floating-point type: delegates to `Base.zero`, so the input `x`
# itself is not modified.
@inline EnzymeCore.make_zero(x::DenseCuArray{Complex{FT}}) where {FT<:AbstractFloat} =
    Base.zero(x)

# `make_zero` with an explicit `seen` cache, for dense CUDA arrays of real or
# complex floating-point elements.
#
# `seen` maps already-visited arrays to the zeroed array built for them. When
# `prev` is already a key, the cached value is returned; otherwise a new zeroed
# array is created with `Base.zero`, recorded under `prev`, and returned.
# The `copy_if_inactive` value parameter is accepted but not consulted here.
@inline function EnzymeCore.make_zero(
    ::Type{CT},
    seen::IdDict,
    prev::CT,
    ::Val{copy_if_inactive} = Val(false),
)::CT where {copy_if_inactive, FT<:AbstractFloat, CT <: Union{DenseCuArray{FT},DenseCuArray{Complex{FT}}}}
    # Guard clause: reuse the result recorded for this array, if any.
    haskey(seen, prev) && return seen[prev]

    zeroed = Base.zero(prev)
    seen[prev] = zeroed
    return zeroed
end

# In-place `make_zero!` for dense CUDA arrays with real floating-point elements.
#
# `seen` is either `nothing` (no tracking) or a collection supporting `in` and
# `push!`. If `prev` is already contained in `seen`, it is left untouched;
# otherwise it is recorded in `seen` and every element is overwritten with
# `zero(FT)`. Always returns `nothing`.
@inline function EnzymeCore.make_zero!(
    prev::DenseCuArray{FT},
    seen::ST,
)::Nothing where {FT<:AbstractFloat,ST}
    if seen !== nothing
        # Already handled: skip the fill.
        prev in seen && return nothing
        push!(seen, prev)
    end
    fill!(prev, zero(FT))
    return nothing
end

# In-place `make_zero!` for dense CUDA arrays whose elements are complex numbers
# over a floating-point type.
#
# `seen` is either `nothing` (no tracking) or a collection supporting `in` and
# `push!`. If `prev` is already contained in `seen`, it is left untouched;
# otherwise it is recorded in `seen` and every element is overwritten with
# `zero(Complex{FT})`. Always returns `nothing`.
@inline function EnzymeCore.make_zero!(
    prev::DenseCuArray{Complex{FT}},
    seen::ST,
)::Nothing where {FT<:AbstractFloat,ST}
    if seen !== nothing
        # Already handled: skip the fill.
        prev in seen && return nothing
        push!(seen, prev)
    end
    fill!(prev, zero(Complex{FT}))
    return nothing
end

function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
::Type{RT},
f::EnzymeCore.Const{typeof(Base.identity)},
Expand Down
9 changes: 9 additions & 0 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ using CUDA
@test EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(()->nothing), Tuple{}) isa GPUCompiler.CompilerJob
end

@testset "Make_zero" begin
    # Out-of-place: the array returned by `make_zero` must contain only zeros.
    primal = CUDA.ones(64)
    shadow = Enzyme.make_zero(primal)
    @test all(shadow .≈ 0)

    # In-place: `make_zero!` must overwrite an array of ones with zeros.
    buffer = CUDA.ones(64)
    Enzyme.make_zero!(buffer)
    @test all(buffer .≈ 0)
end

function square_kernel!(x)
i = threadIdx().x
x[i] *= x[i]
Expand Down

0 comments on commit 3ff1de0

Please sign in to comment.