From 7db1fac089c4e46d638ca11515f045176e865904 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 18:37:39 -0700 Subject: [PATCH] feat: catch scalar indexing failures early --- Project.toml | 2 ++ ext/NNlibCUDAExt/utils.jl | 7 +++++++ src/NNlib.jl | 1 + src/conv.jl | 6 +++++- src/utils.jl | 8 ++++++++ 5 files changed, 23 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 60e2efada..ec0af5b48 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.9.20" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" @@ -31,6 +32,7 @@ NNlibFFTWExt = "FFTW" [compat] AMDGPU = "0.9.4" Adapt = "3.2, 4" +ArrayInterface = "7.10" Atomix = "0.1" CUDA = "4, 5" ChainRulesCore = "1.13" diff --git a/ext/NNlibCUDAExt/utils.jl b/ext/NNlibCUDAExt/utils.jl index a9eaa8dbe..49739fda2 100644 --- a/ext/NNlibCUDAExt/utils.jl +++ b/ext/NNlibCUDAExt/utils.jl @@ -34,3 +34,10 @@ function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N NNlib.reverse_indices!(rev, idx) return map(cu, rev) end + +for op in (:conv!, :∇conv_data!, :∇conv_filter!, :depthwiseconv!, :∇depthwiseconv_data!, :∇depthwiseconv_filter!) + error_msg = "`$(op)` requires all arguments to support fast scalar indexing. You might be missing an `using cuDNN` or `import cuDNN` statement." + @eval function NNlib.special_scalar_indexing_error(::Val{$(Meta.quot(op))}, ::CUDA.AnyCuArray) + throw(AssertionError($(error_msg))) + end +end diff --git a/src/NNlib.jl b/src/NNlib.jl index 8cf66370f..3244e7a8c 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -3,6 +3,7 @@ module NNlib import Atomix import ChainRulesCore: rrule +using ArrayInterface: ArrayInterface using Base.Broadcast: broadcasted using Base.Threads using ChainRulesCore diff --git a/src/conv.jl b/src/conv.jl index 3fecb9151..28a6a869e 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -192,6 +192,7 @@ for (front_name, backend, signature) in ( @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ", "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1 end + assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2) x_cs = Iterators.partition(1:size(in1, 4), channels_in(cdims) ÷ groupcount(cdims)) @@ -233,7 +234,7 @@ for (front_name, backend, signature) in ( @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ", "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1 end - + assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2) dx_cs = Iterators.partition(1:size(out, 4), channels_in(cdims) ÷ groupcount(cdims)) @@ -276,6 +277,7 @@ for (front_name, backend, signature) in ( @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ", "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1 end + assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2) dw_cs = Iterators.partition(1:size(out, 5), channels_out(cdims) ÷ groupcount(cdims)) @@ -327,6 +329,8 @@ for (front_name, backend, signature) in ( @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ", "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1 end + assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2) + $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...) end end diff --git a/src/utils.jl b/src/utils.jl index 3d23e7383..2be2d0f8e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -162,3 +162,11 @@ if VERSION < v"1.7.0-DEV.793" end end +function assert_all_fast_scalar_indexing(call::Val{C}, args::AbstractArray...) where {C} + if !all(ArrayInterface.fast_scalar_indexing, args) + foreach(Base.Fix1(special_scalar_indexing_error, call), args) + throw(AssertionError("`$(C)` requires all arguments to support fast scalar indexing")) + end +end + +special_scalar_indexing_error(::Val, ::AbstractArray) = nothing