From 974a934f56bb651b8b0fd3f3c1d71ab03047d278 Mon Sep 17 00:00:00 2001 From: Janes Sanne <59392839+JeanAnNess@users.noreply.github.com> Date: Thu, 8 Feb 2024 17:43:20 +0100 Subject: [PATCH] Add GradCAM analyzer (#155) --- README.md | 1 + docs/src/api.md | 1 + src/ExplainableAI.jl | 2 ++ src/bibliography.jl | 1 + src/gradcam.jl | 31 +++++++++++++++++++++++++++ test/references/cnn/GradCAM_max.jld2 | Bin 0 -> 1057 bytes test/references/cnn/GradCAM_ns1.jld2 | Bin 0 -> 1057 bytes test/test_batches.jl | 3 ++- test/test_cnn.jl | 5 +---- 9 files changed, 39 insertions(+), 5 deletions(-) create mode 100644 src/gradcam.jl create mode 100644 test/references/cnn/GradCAM_max.jld2 create mode 100644 test/references/cnn/GradCAM_ns1.jld2 diff --git a/README.md b/README.md index d7863d9..9a04f2c 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ Currently, the following analyzers are implemented: * `InputTimesGradient` * `SmoothGrad` * `IntegratedGradients` +* `GradCAM` One of the design goals of the [Julia-XAI ecosystem][juliaxai-docs] is extensibility. To implement an XAI method, take a look at the [common interface diff --git a/docs/src/api.md b/docs/src/api.md index 31269f5..dad40fe 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -12,6 +12,7 @@ Gradient InputTimesGradient SmoothGrad IntegratedGradients +GradCAM ``` # Input augmentations diff --git a/src/ExplainableAI.jl b/src/ExplainableAI.jl index 039593e..19e3149 100644 --- a/src/ExplainableAI.jl +++ b/src/ExplainableAI.jl @@ -12,9 +12,11 @@ include("compat.jl") include("bibliography.jl") include("input_augmentation.jl") include("gradient.jl") +include("gradcam.jl") export Gradient, InputTimesGradient export NoiseAugmentation, SmoothGrad export InterpolationAugmentation, IntegratedGradients +export GradCAM end # module diff --git a/src/bibliography.jl b/src/bibliography.jl index 4e88e35..83919da 100644 --- a/src/bibliography.jl +++ b/src/bibliography.jl @@ -1,3 +1,4 @@ # Gradient methods: const REF_SMILKOV_SMOOTHGRAD = "Smilkov et al., *SmoothGrad: removing noise by adding noise*" const REF_SUNDARARAJAN_AXIOMATIC = "Sundararajan et al., *Axiomatic Attribution for Deep Networks*" +const REF_SELVARAJU_GRADCAM = "Selvaraju et al., *Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization*" \ No newline at end of file diff --git a/src/gradcam.jl b/src/gradcam.jl new file mode 100644 index 0000000..c44ac86 --- /dev/null +++ b/src/gradcam.jl @@ -0,0 +1,31 @@ +""" + GradCAM(feature_layers, adaptation_layers) + +Calculates the Gradient-weighted Class Activation Map (GradCAM). +GradCAM provides a visual explanation of the regions with significant neuron importance for the model's classification decision. + +# Parameters +- `feature_layers`: The layers of a convolutional neural network (CNN) responsible for extracting feature maps. +- `adaptation_layers`: The layers of the CNN used for adaptation and classification. + +# Note +Flux is not required for GradCAM. +GradCAM is compatible with a wide variety of CNN model-families. + +# References +- $REF_SELVARAJU_GRADCAM +""" +struct GradCAM{F,A} <: AbstractXAIMethod + feature_layers::F + adaptation_layers::A +end +function (analyzer::GradCAM)(input, ns::AbstractNeuronSelector) + A = analyzer.feature_layers(input) # feature map + feature_map_size = size(A, 1) * size(A, 2) + + # Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ + grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns) + αᶜ = sum(grad; dims=(1, 2)) / feature_map_size + Lᶜ = max.(sum(αᶜ .* A; dims=3), 0) + return Explanation(Lᶜ, output, output_indices, :GradCAM, :cam, nothing) +end diff --git a/test/references/cnn/GradCAM_max.jld2 b/test/references/cnn/GradCAM_max.jld2 new file mode 100644 index 0000000000000000000000000000000000000000..5031171ba544469025d62cb5495ddaba7f2d8d05 GIT binary patch literal 1057 zcmeZpaWmCTN-R!IQSd6w$xKvmNi0cJaLX^sO)Sw-C`&CW&dkqKFwis9Gh|TEfG7d7 z4fG5Y%uIBXGD{SETs0X+!B7o>P7fD1UM?vvCJqh;1}2Cv{zHKx3xwar2%(E>57zm6 zxCAjV`Z2Ns)iQH3ssJ^yFfc+DFu-V_SzrL8nMD{F4He`WAOurPwe z7!!46t+spWSe;1vVwsnmWW`a!YNhqZ*lNy~vzCkOHdt0Y@wIANmTblMEyRkuti(#6 zsm;pWw8%>P&UMS%rNUMPZIi6FNb6dKuyt8^?{KqH>G!r`u>WWIS0wz{yg#_IF-1gm*_ zQmt+?v{)?&UT>N3f4SAXb7@uuWfv{$WJ0X;BNtiNZ&tUOv#rxA>+m$o0+X3moYy0* za?;zaG{VkWHnT0UGIG(d(mLB^b^FyKiyodvs|)@WR$iHJEVa9~TCudBv)p?@(aM#1 zpOw~EA*)v_&%OYMEJF{I5GdqeQ2?X4z)`~{01PVuMwZlyf*b>I00RjZ7=uC@Xf`^i Im?Gc~00jr6KL7v# literal 0 HcmV?d00001 diff --git a/test/references/cnn/GradCAM_ns1.jld2 b/test/references/cnn/GradCAM_ns1.jld2 new file mode 100644 index 0000000000000000000000000000000000000000..d42a67b002e04887124d6369f6eb00e4cae481d0 GIT binary patch literal 1057 zcmeZpaWmCTN-R!IQSd6w$xKvmNi0cJaLX^sO)Sw-C`&CW&dkqKFwis9Gh|TEfG7d7 z4fG5Y%uIBXGD{SETs0X+!B7o>P7fD1UM?vvCJqh;1}2Cv{zHKx3xwar2%(E>57zm6 zxCAjV`Z2Ns)iQH3ssJ^yFfc+DFu-V_SzrL8nMD{F4He`WAOurPwe zVA`O%Mjq8Hazh>*vJ5>;LLl$Mq5wv7fun{^02o#Rj4Y`Y1vv(=P;!AWD5QaAql1bm G0`36lm^%Of literal 0 HcmV?d00001 diff --git a/test/test_batches.jl b/test/test_batches.jl index 7922c8c..b6d6f3a 100644 --- a/test/test_batches.jl +++ b/test/test_batches.jl @@ -8,7 +8,7 @@ ins = 20 outs = 10 batchsize = 15 -model = Chain(Dense(ins, outs, relu; init=pseudorand)) +model = Chain(Dense(ins, 15, relu; init=pseudorand), Dense(15, outs, relu; init=pseudorand)) # Input 1 w/o batch dimension input1_no_bd = rand(MersenneTwister(1), Float32, ins) @@ -24,6 +24,7 @@ ANALYZERS = Dict( "InputTimesGradient" => InputTimesGradient, "SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)), "IntegratedGradients" => m -> IntegratedGradients(m, 5), + "GradCAM" => m -> GradCAM(m[1], m[2]), ) for (name, method) in ANALYZERS diff --git a/test/test_cnn.jl b/test/test_cnn.jl index ad42d98..3dc5602 100644 --- a/test/test_cnn.jl +++ b/test/test_cnn.jl @@ -6,6 +6,7 @@ const GRADIENT_ANALYZERS = Dict( "InputTimesGradient" => InputTimesGradient, "SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)), "IntegratedGradients" => m -> IntegratedGradients(m, 5), + "GradCAM" => m -> GradCAM(m[1], m[2]), ) input_size = (32, 32, 3, 1) @@ -67,8 +68,6 @@ function test_cnn(name, method) println("Timing $name...") print("cold:") @time expl = analyze(input, analyzer) - - @test size(expl.val) == size(input) @test_reference "references/cnn/$(name)_max.jld2" Dict("expl" => expl.val) by = (r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05) end @@ -76,8 +75,6 @@ function test_cnn(name, method) analyzer = method(model) print("warm:") @time expl = analyze(input, analyzer, 1) - - @test size(expl.val) == size(input) @test_reference "references/cnn/$(name)_ns1.jld2" Dict("expl" => expl.val) by = (r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05) end