diff --git a/Project.toml b/Project.toml
index 94083ad..46bd8bb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,8 @@ authors = ["Adrian Hill "]
 version = "0.8.1"
 
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -12,6 +14,8 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+ADTypes = "1"
+DifferentiationInterface = "0.5"
 Distributions = "0.25"
 Random = "<0.0.1, 1"
 Reexport = "1"
diff --git a/src/ExplainableAI.jl b/src/ExplainableAI.jl
index 2b33379..f999b15 100644
--- a/src/ExplainableAI.jl
+++ b/src/ExplainableAI.jl
@@ -7,7 +7,12 @@ import XAIBase: call_analyzer
 using Base.Iterators
 using Distributions: Distribution, Sampleable, Normal
 using Random: AbstractRNG, GLOBAL_RNG
+
+# Automatic differentiation
+using ADTypes: AbstractADType, AutoZygote
+using DifferentiationInterface: value_and_pullback
 using Zygote
+const DEFAULT_AD_BACKEND = AutoZygote()
 
 include("compat.jl")
 include("bibliography.jl")
diff --git a/src/gradcam.jl b/src/gradcam.jl
index 421ff5f..2658a63 100644
--- a/src/gradcam.jl
+++ b/src/gradcam.jl
@@ -15,16 +15,25 @@ GradCAM is compatible with a wide variety of CNN model-families.
 # References
 - $REF_SELVARAJU_GRADCAM
 """
-struct GradCAM{F,A} <: AbstractXAIMethod
+struct GradCAM{F,A,B<:AbstractADType} <: AbstractXAIMethod
     feature_layers::F
     adaptation_layers::A
+    backend::B
+
+    function GradCAM(
+        feature_layers::F, adaptation_layers::A, backend::B=DEFAULT_AD_BACKEND
+    ) where {F,A,B<:AbstractADType}
+        new{F,A,B}(feature_layers, adaptation_layers, backend)
+    end
 end
 function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
     A = analyzer.feature_layers(input)  # feature map
     feature_map_size = size(A, 1) * size(A, 2)
 
     # Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
-    grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
+    grad, output, output_indices = gradient_wrt_input(
+        analyzer.adaptation_layers, A, ns, analyzer.backend
+    )
     αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
     Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
     return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
diff --git a/src/gradient.jl b/src/gradient.jl
index 5fea07b..58577f0 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -1,12 +1,25 @@
-function gradient_wrt_input(model, input, ns::AbstractOutputSelector)
-    output, back = Zygote.pullback(model, input)
-    output_indices = ns(output)
-
-    # Compute VJP w.r.t. full model output, selecting vector s.t. it masks output neurons
-    v = zero(output)
-    v[output_indices] .= 1
-    grad = only(back(v))
-    return grad, output, output_indices
+function forward_with_output_selection(model, input, selector::AbstractOutputSelector)
+    output = model(input)
+    sel = selector(output)
+    return output[sel]
+end
+
+function gradient_wrt_input(
+    model, input, output_selector::AbstractOutputSelector, backend::AbstractADType
+)
+    output = model(input)
+    return gradient_wrt_input(model, input, output, output_selector, backend)
+end
+
+function gradient_wrt_input(
+    model, input, output, output_selector::AbstractOutputSelector, backend::AbstractADType
+)
+    output_selection = output_selector(output)
+    dy = zero(output)
+    dy[output_selection] .= 1
+
+    output, grad = value_and_pullback(model, backend, input, dy)
+    return grad, output, output_selection
 end
 
 """
@@ -14,13 +27,19 @@ end
 
 Analyze model by calculating the gradient of a neuron activation with respect to the input.
 """
-struct Gradient{M} <: AbstractXAIMethod
+struct Gradient{M,B<:AbstractADType} <: AbstractXAIMethod
     model::M
-    Gradient(model) = new{typeof(model)}(model)
+    backend::B
+
+    function Gradient(model::M, backend::B=DEFAULT_AD_BACKEND) where {M,B<:AbstractADType}
+        new{M,B}(model, backend)
+    end
 end
 
 function call_analyzer(input, analyzer::Gradient, ns::AbstractOutputSelector; kwargs...)
-    grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
+    grad, output, output_indices = gradient_wrt_input(
+        analyzer.model, input, ns, analyzer.backend
+    )
     return Explanation(
         grad, input, output, output_indices, :Gradient, :sensitivity, nothing
     )
@@ -32,15 +51,23 @@ end
 Analyze model by calculating the gradient of a neuron activation with respect to the input.
 This gradient is then multiplied element-wise with the input.
 """
-struct InputTimesGradient{M} <: AbstractXAIMethod
+struct InputTimesGradient{M,B<:AbstractADType} <: AbstractXAIMethod
     model::M
-    InputTimesGradient(model) = new{typeof(model)}(model)
+    backend::B
+
+    function InputTimesGradient(
+        model::M, backend::B=DEFAULT_AD_BACKEND
+    ) where {M,B<:AbstractADType}
+        new{M,B}(model, backend)
+    end
 end
 
 function call_analyzer(
     input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
 )
-    grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
+    grad, output, output_indices = gradient_wrt_input(
+        analyzer.model, input, ns, analyzer.backend
+    )
     attr = input .* grad
     return Explanation(
         attr, input, output, output_indices, :InputTimesGradient, :attribution, nothing
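
For reference, a minimal usage sketch of the new optional backend argument, as the constructors in this diff suggest it would be called. The toy Flux model and input here are hypothetical stand-ins (any callable mapping input arrays to output scores works), and `analyze` is the XAIBase entry point that ExplainableAI re-exports; backends other than the `AutoZygote()` default are assumed to work via DifferentiationInterface but are not exercised in this patch.

# Minimal sketch, assuming Flux is available alongside ExplainableAI.
using ExplainableAI
using ADTypes: AutoZygote
using Flux: Chain, Dense, relu

model = Chain(Dense(10 => 32, relu), Dense(32 => 5))  # hypothetical toy model
input = rand(Float32, 10, 1)                          # one sample, 10 features

# No backend given: falls back to DEFAULT_AD_BACKEND = AutoZygote().
expl = analyze(input, Gradient(model))

# Backend passed explicitly; the new constructors accept any
# ADTypes.AbstractADType in the trailing position.
expl = analyze(input, InputTimesGradient(model, AutoZygote()))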