Skip to content

Commit

Permalink
feat: catch scalar indexing failures early
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 16, 2024
1 parent a4111c1 commit af56aea
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.9.20"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Expand Down Expand Up @@ -31,6 +32,7 @@ NNlibFFTWExt = "FFTW"
[compat]
AMDGPU = "0.9.4"
Adapt = "3.2, 4"
ArrayInterface = "7.12.0"
Atomix = "0.1"
CUDA = "4, 5"
ChainRulesCore = "1.13"
Expand Down
7 changes: 7 additions & 0 deletions ext/NNlibCUDAExt/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@ function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
NNlib.reverse_indices!(rev, idx)
return map(cu, rev)
end

for op in (:conv!, :∇conv_data!, :∇conv_filter!, :depthwiseconv!, :∇depthwiseconv_data!, :∇depthwiseconv_filter!)
error_msg = "`$(op)` requires all arguments to support fast scalar indexing. You might be missing an `using cuDNN` or `import cuDNN` statement."
@eval function NNlib.special_scalar_indexing_error(::Val{$(Meta.quot(op))}, ::CUDA.AnyCuArray)
throw(AssertionError($(error_msg)))
end
end
1 change: 1 addition & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module NNlib
import Atomix
import ChainRulesCore: rrule

using ArrayInterface: ArrayInterface
using Base.Broadcast: broadcasted
using Base.Threads
using ChainRulesCore
Expand Down
6 changes: 5 additions & 1 deletion src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ for (front_name, backend, signature) in (
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end
assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)

x_cs = Iterators.partition(1:size(in1, 4),
channels_in(cdims) ÷ groupcount(cdims))
Expand Down Expand Up @@ -233,7 +234,7 @@ for (front_name, backend, signature) in (
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end

assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)

dx_cs = Iterators.partition(1:size(out, 4),
channels_in(cdims) ÷ groupcount(cdims))
Expand Down Expand Up @@ -276,6 +277,7 @@ for (front_name, backend, signature) in (
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end
assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)

dw_cs = Iterators.partition(1:size(out, 5),
channels_out(cdims) ÷ groupcount(cdims))
Expand Down Expand Up @@ -327,6 +329,8 @@ for (front_name, backend, signature) in (
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end
assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)

$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
end
end
Expand Down
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,11 @@ if VERSION < v"1.7.0-DEV.793"
end
end

function assert_all_fast_scalar_indexing(call::Val{C}, args::AbstractArray...) where {C}
if !all(ArrayInterface.fast_scalar_indexing, args)
foreach(Base.Fix1(special_scalar_indexing_error, call), args)
throw(AssertionError("`$(C)` requires all arguments to support fast scalar indexing"))
end
end

special_scalar_indexing_error(::Val, ::AbstractArray) = nothing

0 comments on commit af56aea

Please sign in to comment.