Skip to content

Commit

Permalink
Add GradCAM analyzer (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanAnNess authored Feb 8, 2024
1 parent ac1356d commit 974a934
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Currently, the following analyzers are implemented:
* `InputTimesGradient`
* `SmoothGrad`
* `IntegratedGradients`
* `GradCAM`

One of the design goals of the [Julia-XAI ecosystem][juliaxai-docs] is extensibility.
To implement an XAI method, take a look at the [common interface
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Gradient
InputTimesGradient
SmoothGrad
IntegratedGradients
GradCAM
```

# Input augmentations
Expand Down
2 changes: 2 additions & 0 deletions src/ExplainableAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ include("compat.jl")
include("bibliography.jl")
include("input_augmentation.jl")
include("gradient.jl")
include("gradcam.jl")

export Gradient, InputTimesGradient
export NoiseAugmentation, SmoothGrad
export InterpolationAugmentation, IntegratedGradients
export GradCAM

end # module
1 change: 1 addition & 0 deletions src/bibliography.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Gradient methods:
const REF_SMILKOV_SMOOTHGRAD = "Smilkov et al., *SmoothGrad: removing noise by adding noise*"
const REF_SUNDARARAJAN_AXIOMATIC = "Sundararajan et al., *Axiomatic Attribution for Deep Networks*"
const REF_SELVARAJU_GRADCAM = "Selvaraju et al., *Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization*"
31 changes: 31 additions & 0 deletions src/gradcam.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
GradCAM(feature_layers, adaptation_layers)
Calculates the Gradient-weighted Class Activation Map (GradCAM).
GradCAM provides a visual explanation of the regions with significant neuron importance for the model's classification decision.
# Parameters
- `feature_layers`: The layers of a convolutional neural network (CNN) responsible for extracting feature maps.
- `adaptation_layers`: The layers of the CNN used for adaptation and classification.
# Note
Flux is not required for GradCAM.
GradCAM is compatible with a wide variety of CNN model-families.
# References
- $REF_SELVARAJU_GRADCAM
"""
struct GradCAM{F,A} <: AbstractXAIMethod
feature_layers::F
adaptation_layers::A
end
function (analyzer::GradCAM)(input, ns::AbstractNeuronSelector)
A = analyzer.feature_layers(input) # feature map
feature_map_size = size(A, 1) * size(A, 2)

# Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
return Explanation(Lᶜ, output, output_indices, :GradCAM, :cam, nothing)
end
Binary file added test/references/cnn/GradCAM_max.jld2
Binary file not shown.
Binary file added test/references/cnn/GradCAM_ns1.jld2
Binary file not shown.
3 changes: 2 additions & 1 deletion test/test_batches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ins = 20
outs = 10
batchsize = 15

model = Chain(Dense(ins, outs, relu; init=pseudorand))
model = Chain(Dense(ins, 15, relu; init=pseudorand), Dense(15, outs, relu; init=pseudorand))

# Input 1 w/o batch dimension
input1_no_bd = rand(MersenneTwister(1), Float32, ins)
Expand All @@ -24,6 +24,7 @@ ANALYZERS = Dict(
"InputTimesGradient" => InputTimesGradient,
"SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)),
"IntegratedGradients" => m -> IntegratedGradients(m, 5),
"GradCAM" => m -> GradCAM(m[1], m[2]),
)

for (name, method) in ANALYZERS
Expand Down
5 changes: 1 addition & 4 deletions test/test_cnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const GRADIENT_ANALYZERS = Dict(
"InputTimesGradient" => InputTimesGradient,
"SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)),
"IntegratedGradients" => m -> IntegratedGradients(m, 5),
"GradCAM" => m -> GradCAM(m[1], m[2]),
)

input_size = (32, 32, 3, 1)
Expand Down Expand Up @@ -67,17 +68,13 @@ function test_cnn(name, method)
println("Timing $name...")
print("cold:")
@time expl = analyze(input, analyzer)

@test size(expl.val) == size(input)
@test_reference "references/cnn/$(name)_max.jld2" Dict("expl" => expl.val) by =
(r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
end
@testset "Neuron selection" begin
analyzer = method(model)
print("warm:")
@time expl = analyze(input, analyzer, 1)

@test size(expl.val) == size(input)
@test_reference "references/cnn/$(name)_ns1.jld2" Dict("expl" => expl.val) by =
(r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
end
Expand Down

0 comments on commit 974a934

Please sign in to comment.