From 3ff1de0bf4439508705a9081f47e6c5ff8fd8c00 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Dec 2024 03:09:55 -0500 Subject: [PATCH] Enzyme: add make_zero of cuarrays (#2600) --- .buildkite/pipeline.yml | 5 ++-- ext/EnzymeCoreExt.jl | 53 +++++++++++++++++++++++++++++++++++++++ test/extensions/enzyme.jl | 9 +++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 1db6e05248..9196a4ef5f 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -246,8 +246,9 @@ steps: build.message !~ /\[only/ && !build.pull_request.draft && build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 30 - soft_fail: true + timeout_in_minutes: 60 + soft_fail: + - exit_status: 3 - group: ":eyes: Special" depends_on: "cuda" diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index fb1003eaf9..c5bdadc4cd 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -555,6 +555,59 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...) return nothing end +@inline function EnzymeCore.make_zero( + x::DenseCuArray{FT}, +) where {FT<:AbstractFloat} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero( + x::DenseCuArray{Complex{FT}}, +) where {FT<:AbstractFloat} + return Base.zero(x) +end + +@inline function EnzymeCore.make_zero( + ::Type{CT}, + seen::IdDict, + prev::CT, + ::Val{copy_if_inactive} = Val(false), +)::CT where {copy_if_inactive, FT<:AbstractFloat, CT <: Union{DenseCuArray{FT},DenseCuArray{Complex{FT}}}} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end + +@inline function EnzymeCore.make_zero!( + prev::DenseCuArray{FT}, + seen::ST, +)::Nothing where {FT<:AbstractFloat,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev, zero(FT)) + return nothing +end + +@inline function EnzymeCore.make_zero!( + prev::DenseCuArray{Complex{FT}}, + seen::ST, +)::Nothing where {FT<:AbstractFloat,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev, zero(Complex{FT})) + return nothing +end + function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays.mapreducedim!)}, ::Type{RT}, f::EnzymeCore.Const{typeof(Base.identity)}, diff --git a/test/extensions/enzyme.jl b/test/extensions/enzyme.jl index 3e34af70a5..75f452d36b 100644 --- a/test/extensions/enzyme.jl +++ b/test/extensions/enzyme.jl @@ -7,6 +7,15 @@ using CUDA @test EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(()->nothing), Tuple{}) isa GPUCompiler.CompilerJob end +@testset "Make_zero" begin + A = CUDA.ones(64) + dA = Enzyme.make_zero(A) + @test all(dA .≈ 0) + dA = CUDA.ones(64) + Enzyme.make_zero!(dA) + @test all(dA .≈ 0) +end + function square_kernel!(x) i = threadIdx().x x[i] *= x[i]