Add GradCAM analyzer #155

Merged 12 commits on Feb 8, 2024
1 change: 1 addition & 0 deletions README.md
@@ -101,6 +101,7 @@ Currently, the following analyzers are implemented:
* `InputTimesGradient`
* `SmoothGrad`
* `IntegratedGradients`
* `GradCAM`

One of the design goals of the [Julia-XAI ecosystem][juliaxai-docs] is extensibility.
To implement an XAI method, take a look at the [common interface
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -12,6 +12,7 @@ Gradient
InputTimesGradient
SmoothGrad
IntegratedGradients
GradCAM
```

# Input augmentations
2 changes: 2 additions & 0 deletions src/ExplainableAI.jl
@@ -12,9 +12,11 @@ include("compat.jl")
include("bibliography.jl")
include("input_augmentation.jl")
include("gradient.jl")
include("gradcam.jl")

export Gradient, InputTimesGradient
export NoiseAugmentation, SmoothGrad
export InterpolationAugmentation, IntegratedGradients
export GradCAM

end # module
1 change: 1 addition & 0 deletions src/bibliography.jl
@@ -1,3 +1,4 @@
# Gradient methods:
const REF_SMILKOV_SMOOTHGRAD = "Smilkov et al., *SmoothGrad: removing noise by adding noise*"
const REF_SUNDARARAJAN_AXIOMATIC = "Sundararajan et al., *Axiomatic Attribution for Deep Networks*"
const REF_SELVARAJU_GRADCAM = "Selvaraju et al., *Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization*"
31 changes: 31 additions & 0 deletions src/gradcam.jl
@@ -0,0 +1,31 @@
"""
GradCAM(feature_layers, adaptation_layers)

Computes the Gradient-weighted Class Activation Map (GradCAM).
GradCAM highlights the input regions whose feature-map activations contribute most to the model's classification decision.

# Parameters
- `feature_layers`: The layers of a convolutional neural network (CNN) responsible for extracting feature maps.
- `adaptation_layers`: The layers of the CNN used for adaptation and classification.

# Note
Flux is not required for GradCAM.
GradCAM is compatible with a wide variety of CNN model families.

# References
- $REF_SELVARAJU_GRADCAM
"""
struct GradCAM{F,A} <: AbstractXAIMethod
feature_layers::F
adaptation_layers::A
end
function (analyzer::GradCAM)(input, ns::AbstractNeuronSelector)
A = analyzer.feature_layers(input) # feature map
feature_map_size = size(A, 1) * size(A, 2)

# Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
return Explanation(Lᶜ, output, output_indices, :GradCAM, :cam, nothing)
end
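The neuron-importance pooling in the analyzer above can be checked on a toy feature map. A minimal sketch, assuming an all-ones gradient; the array sizes and values are purely illustrative:

```julia
# Toy feature map A with (width, height, channels, batch) = (2, 2, 3, 1),
# matching the WHCN layout used by Flux-style CNNs.
A = Float32.(reshape(1:12, 2, 2, 3, 1))

# Pretend the gradient of the target score w.r.t. A is all ones.
grad = ones(Float32, 2, 2, 3, 1)

Z = size(A, 1) * size(A, 2)          # spatial size of the feature map, Z = 4
αᶜ = sum(grad; dims=(1, 2)) / Z      # channel weights αₖᶜ, here all equal to 1
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)   # ReLU of the weighted sum over channels

Lᶜ[:, :, 1, 1]                        # 2×2 coarse localization map: [15 21; 18 24]
```

With unit weights, each entry of `Lᶜ` is just the sum of `A` across channels, clipped at zero, which makes the pooling step easy to verify by hand.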
Binary file added test/references/cnn/GradCAM_max.jld2
Binary file added test/references/cnn/GradCAM_ns1.jld2
3 changes: 2 additions & 1 deletion test/test_batches.jl
@@ -8,7 +8,7 @@ ins = 20
outs = 10
batchsize = 15

model = Chain(Dense(ins, outs, relu; init=pseudorand))
model = Chain(Dense(ins, 15, relu; init=pseudorand), Dense(15, outs, relu; init=pseudorand))

# Input 1 w/o batch dimension
input1_no_bd = rand(MersenneTwister(1), Float32, ins)
@@ -24,6 +24,7 @@ ANALYZERS = Dict(
"InputTimesGradient" => InputTimesGradient,
"SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)),
"IntegratedGradients" => m -> IntegratedGradients(m, 5),
"GradCAM" => m -> GradCAM(m[1], m[2]),
)

for (name, method) in ANALYZERS
5 changes: 1 addition & 4 deletions test/test_cnn.jl
@@ -6,6 +6,7 @@ const GRADIENT_ANALYZERS = Dict(
"InputTimesGradient" => InputTimesGradient,
"SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)),
"IntegratedGradients" => m -> IntegratedGradients(m, 5),
"GradCAM" => m -> GradCAM(m[1], m[2]),
)

input_size = (32, 32, 3, 1)
@@ -67,17 +68,13 @@ function test_cnn(name, method)
println("Timing $name...")
print("cold:")
@time expl = analyze(input, analyzer)

@test size(expl.val) == size(input)
@test_reference "references/cnn/$(name)_max.jld2" Dict("expl" => expl.val) by =
(r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
end
@testset "Neuron selection" begin
analyzer = method(model)
print("warm:")
@time expl = analyze(input, analyzer, 1)

@test size(expl.val) == size(input)
@test_reference "references/cnn/$(name)_ns1.jld2" Dict("expl" => expl.val) by =
(r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
end
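For reference, constructing the analyzer on a small Flux CNN looks roughly like this. A hedged sketch: the model architecture, layer sizes, and variable names are made up for illustration, mirroring the `m[1]`/`m[2]` split used in the test dictionaries above:

```julia
using Flux
using ExplainableAI

# Hypothetical two-part CNN: the outer Chain is split into a feature
# extractor and an adaptation/classifier head.
model = Chain(
    Chain(Conv((3, 3), 3 => 8, relu; pad=1), MaxPool((2, 2))),  # feature maps
    Chain(Flux.flatten, Dense(8 * 16 * 16, 10)),                # classification
)

analyzer = GradCAM(model[1], model[2])
input = rand(Float32, 32, 32, 3, 1)   # WHCN input, as in these tests
expl = analyze(input, analyzer)       # coarse map at the feature-map resolution
```

Note that `Lᶜ` lives on the feature map rather than the input, so `expl.val` has the feature-map size (here 16×16×1×1, not 32×32×3×1) — consistent with the `size(expl.val) == size(input)` checks removed from test_cnn.jl in this PR.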