Support nested Flux Chains in LRP (#119)
Authored by adrhill on Aug 21, 2023 · 1 parent 10a5e0b · commit 6e12cd2
Showing 37 changed files with 977 additions and 486 deletions.
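The headline change: `LRP` analyzers now accept nested `Flux.Chain`s directly, removing the need to flatten models first. A minimal sketch of the new behavior (the architecture and sizes below are illustrative, not taken from this commit):

```julia
using Flux
using ExplainableAI

# A nested Chain: both the feature extractor and the classifier head are Chains.
model = Chain(
    Chain(Conv((3, 3), 3 => 8, relu), MaxPool((2, 2))),
    Chain(Flux.flatten, Dense(8 * 111 * 111, 10)),
)

analyzer = LRP(model)                  # no flatten_model / flatten_chain needed
input = rand(Float32, 224, 224, 3, 1)  # WHCN layout
expl = analyze(input, analyzer)
```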
Project.toml: 5 changes (2 additions, 3 deletions)
@@ -9,9 +9,8 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
- LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
- PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
@@ -23,7 +22,7 @@ Distributions = "0.25"
Flux = "0.13, 0.14"
ImageCore = "0.9, 0.10"
ImageTransformations = "0.9, 0.10"
- PrettyTables = "1, 2"
+ MacroTools = "0.5"
Tullio = "0.3"
Zygote = "0.6"
julia = "1.6"
README.md: 2 changes (1 addition, 1 deletion)
@@ -28,7 +28,7 @@ using HTTP, FileIO, ImageMagick # load image from URL

# Load model
model = VGG(16, pretrain=true).layers
- model = strip_softmax(flatten_chain(model))
+ model = strip_softmax(model.layers)

# Load input
url = HTTP.URI("https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
benchmark/benchmarks.jl: 10 changes (6 additions, 4 deletions)
@@ -7,8 +7,8 @@ using ExplainableAI: lrp!, modify_layer
on_CI = haskey(ENV, "GITHUB_ACTIONS")

include("../test/vgg11.jl")
- vgg11 = VGG11(; pretrain=false)
- model = flatten_model(strip_softmax(vgg11.layers))
+ model = VGG11(; pretrain=false)
+ model = strip_softmax(model.layers)

T = Float32
img = rand(MersenneTwister(123), T, (224, 224, 3, 1))
@@ -41,6 +41,7 @@ in_dense = 500
out_dense = 100
aₖ = randn(T, insize)

+ #! format: off
layers = Dict(
"Conv" => (Conv((3, 3), 3 => 2), aₖ),
"Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
@@ -57,12 +58,13 @@ rules = Dict(
)
layernames = String.(keys(layers))
rulenames = String.(keys(rules))
+ #! format: on

SUITE["modify layer"] = BenchmarkGroup(rulenames)
SUITE["apply rule"] = BenchmarkGroup(rulenames)
for rname in rulenames
SUITE["modify layer"][rname] = BenchmarkGroup(layernames)
SUITE["apply rule"][rname] = BenchmarkGroup(layernames)
SUITE["apply rule"][rname] = BenchmarkGroup(layernames)
end

for (lname, (layer, aₖ)) in layers
@@ -72,7 +74,7 @@ for (lname, (layer, aₖ)) in layers
modified_layer = modify_layer(rule, layer)
SUITE["modify layer"][rname][lname] = @benchmarkable modify_layer($(rule), $(layer))
SUITE["apply rule"][rname][lname] = @benchmarkable lrp!(
- $(Rₖ), $(rule), $(modified_layer), $(aₖ), $(Rₖ₊₁)
+ $(Rₖ), $(rule), $(layer), $(modified_layer), $(aₖ), $(Rₖ₊₁)
)
end
end
docs/src/api.md: 1 change (0 additions, 1 deletion)
@@ -76,7 +76,6 @@ RangeTypeRule
FirstLayerTypeRule
LastLayerTypeRule
FirstNTypeRule
- LastNTypeRule
```

### [Default composites](@id default_composite_api)
docs/src/literate/advanced_lrp.jl: 26 changes (11 additions, 15 deletions)
@@ -40,13 +40,6 @@ analyzer = LRP(model, rules)
#
heatmap(input, analyzer)

- # Since some Flux Chains contain other Flux Chains, ExplainableAI provides
- # a utility function called [`flatten_model`](@ref).
- #
- #md # !!! warning "Flattening models"
- #md #     Not all models can be flattened, e.g. those using
- #md #     `Parallel` and `SkipConnection` layers.

# ### Custom composites
# Instead of manually defining a list of rules, we can also use a [`Composite`](@ref).
# A composite constructs a list of LRP-rules by sequentially applying
@@ -85,7 +78,6 @@ heatmap(input, analyzer)
# * [`FirstLayerTypeRule`](@ref) for a `TypeRule` on the first layer of a model
# * [`LastLayerTypeRule`](@ref) for a `TypeRule` on the last layer
# * [`FirstNTypeRule`](@ref) for a `TypeRule` on the first `n` layers
- # * [`LastNTypeRule`](@ref) for a `TypeRule` on the last `n` layers
#
# Primitives are applied sequentially, in the order they were passed to the `Composite`,
# and overwrite rules specified by previous primitives.
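To make the ordering concrete, here is a hedged sketch of a composite in which a later primitive overwrites the rule an earlier one assigned to the first layer (the rule choices are illustrative):

```julia
composite = Composite(
    # First primitive: assign rules by layer type across the whole model
    GlobalTypeRule(
        ConvLayer => ZPlusRule(),
        Dense     => EpsilonRule(),
    ),
    # Second primitive: overwrites whatever rule the first layer received above
    FirstLayerRule(ZBoxRule(-3.0f0, 3.0f0)),
)
analyzer = LRP(model, composite)
```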
@@ -282,7 +274,7 @@ analyzer = LRP(model)
#
# This is done by calling low-level functions
# ```julia
- # function lrp!(Rₖ, rule, modified_layer, aₖ, Rₖ₊₁)
+ # function lrp!(Rₖ, rule, layer, modified_layer, aₖ, Rₖ₊₁)
# Rₖ .= ...
# end
# ```
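Conceptually, the analyzer iterates the chain back to front, calling `lrp!` once per layer. A schematic of that loop, with assumed variable names:

```julia
# Schematic only, not the package's literal implementation: `activations`
# were recorded on the forward pass, R[end] is seeded at the output, and
# each lrp! call fills R[k] in-place from R[k + 1].
for k in length(layers):-1:1
    lrp!(R[k], rules[k], layers[k], modified_layers[k], activations[k], R[k + 1])
end
```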
@@ -331,7 +323,10 @@ analyzer = LRP(model)
# For `lrp!`, we implement the previous four-step computation using `Zygote.pullback` to
# compute ``c`` from the previous equation as a VJP, pulling back ``s=R/z``:
# ```julia
- # function lrp!(Rₖ, rule, modified_layer, aₖ, Rₖ₊₁)
+ # function lrp!(Rₖ, rule, layer, modified_layer, aₖ, Rₖ₊₁)
+ #     # Use modified_layer if available, otherwise layer
+ #     layer = ifelse(isnothing(modified_layer), layer, modified_layer)
+ #
# ãₖ = modify_input(rule, aₖ)
# z, back = Zygote.pullback(layer, ãₖ)
# s = Rₖ₊₁ ./ modify_denominator(rule, z)
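For reference, the four-step VJP computation can be written as a standalone sketch; the stabilized division below stands in for `modify_denominator`, and all names are assumed:

```julia
using Zygote

function lrp_vjp_sketch(layer, aₖ, Rₖ₊₁)
    z, back = Zygote.pullback(layer, aₖ)  # step 1: forward pass through the layer
    s = Rₖ₊₁ ./ (z .+ 1.0f-9)             # step 2: divide relevance by the denominator
    c = only(back(s))                     # step 3: pull back s, i.e. compute the VJP
    return aₖ .* c                        # step 4: multiply elementwise by the input
end
```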
@@ -349,7 +344,7 @@ analyzer = LRP(model)
# Reshaping layers don't affect attributions. We can therefore avoid the computational
# overhead of AD by writing a specialized implementation that simply reshapes back:
# ```julia
- # function lrp!(Rₖ, rule, ::ReshapingLayer, aₖ, Rₖ₊₁)
+ # function lrp!(Rₖ, rule, _layer::ReshapingLayer, _modified_layer, aₖ, Rₖ₊₁)
# Rₖ .= reshape(Rₖ₊₁, size(aₖ))
# end
# ```
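A quick sanity check of what this specialization does, with `Flux.flatten` as the reshaping layer and illustrative shapes:

```julia
aₖ   = rand(Float32, 4, 4, 2, 1)   # input to the reshaping layer (WHCN)
Rₖ₊₁ = rand(Float32, 32, 1)        # relevance of the flattened output
Rₖ   = reshape(Rₖ₊₁, size(aₖ))     # the rule just reshapes relevance back
size(Rₖ) == size(aₖ)               # true: attributions pass through unchanged
```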
@@ -358,18 +353,19 @@ analyzer = LRP(model)
#
# We can even implement the generic rule as a specialized implementation for `Dense` layers:
# ```julia
- # function lrp!(Rₖ, rule, layer::Dense, aₖ, Rₖ₊₁)
+ # function lrp!(Rₖ, rule, layer::Dense, modified_layer, aₖ, Rₖ₊₁)
+ #     layer = ifelse(isnothing(modified_layer), layer, modified_layer)
# ãₖ = modify_input(rule, aₖ)
- #     z = modify_denominator(rule, modified_layer(ãₖ))
- #     @tullio Rₖ[j, b] = modified_layer.weight[i, j] * ãₖ[j, b] / z[i, b] * Rₖ₊₁[i, b]
+ #     z = modify_denominator(rule, layer(ãₖ))
+ #     @tullio Rₖ[j, b] = layer.weight[i, j] * ãₖ[j, b] / z[i, b] * Rₖ₊₁[i, b]
# end
# ```
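For readers unfamiliar with Tullio: `@tullio` sums over the repeated right-hand-side index `i` (the output neurons). An equivalent explicit loop, shown purely for readability:

```julia
# Equivalent to the @tullio expression above (explicit, but slower):
for b in axes(Rₖ, 2), j in axes(Rₖ, 1)
    Rₖ[j, b] = sum(layer.weight[i, j] * ãₖ[j, b] / z[i, b] * Rₖ₊₁[i, b] for i in axes(z, 1))
end
```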
#
# For maximum low-level control beyond `modify_input` and `modify_denominator`,
# you can also implement your own `lrp!` function and dispatch
# on individual rule types `MyRule` and layer types `MyLayer`:
# ```julia
- # function lrp!(Rₖ, rule::MyRule, layer::MyLayer, aₖ, Rₖ₊₁)
+ # function lrp!(Rₖ, rule::MyRule, layer::MyLayer, _modified_layer, aₖ, Rₖ₊₁)
# Rₖ .= ...
# end
# ```
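As a concrete instance of this pattern, a hypothetical rule that passes relevance through a shape-preserving custom layer unchanged might look like the following; `MyPassthroughRule` and `MyDoublingLayer` are illustrative names, and `AbstractLRPRule` is assumed to be the package's abstract rule type:

```julia
struct MyPassthroughRule <: AbstractLRPRule end

# Assumes MyDoublingLayer preserves the shape of its input.
function lrp!(Rₖ, rule::MyPassthroughRule, layer::MyDoublingLayer, _modified_layer, aₖ, Rₖ₊₁)
    Rₖ .= Rₖ₊₁  # relevance flows through unchanged
end
```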
src/ExplainableAI.jl: 15 changes (7 additions, 8 deletions)
@@ -1,28 +1,26 @@
module ExplainableAI

using Base.Iterators
- using LinearAlgebra
+ using MacroTools: @forward
using Distributions: Distribution, Sampleable, Normal
using Random: AbstractRNG, GLOBAL_RNG
using Flux
using Zygote
using Tullio
+ using Markdown

# Heatmapping:
using ImageCore
using ImageTransformations: imresize
using ColorSchemes

- # Model checks:
- using Markdown
- using PrettyTables

include("compat.jl")
include("bibliography.jl")
include("neuron_selection.jl")
include("analyze_api.jl")
include("flux_types.jl")
include("flux_utils.jl")
include("flux_layer_utils.jl")
include("flux_chain_utils.jl")
include("utils.jl")
include("input_augmentation.jl")
include("gradient.jl")
@@ -54,7 +52,8 @@ export PassRule, ZBoxRule, ZPlusRule, AlphaBetaRule, GeneralizedGammaRule
export Composite, AbstractCompositePrimitive
export LayerRule, GlobalRule, RangeRule, FirstLayerRule, LastLayerRule
export GlobalTypeRule, RangeTypeRule, FirstLayerTypeRule, LastLayerTypeRule
- export FirstNTypeRule, LastNTypeRule
+ export FirstNTypeRule
+ export lrp_rules
# Default composites
export EpsilonGammaBox, EpsilonPlus, EpsilonAlpha2Beta1, EpsilonPlusFlat
export EpsilonAlpha2Beta1Flat
@@ -65,6 +64,6 @@ export ConvLayer, PoolingLayer, DropoutLayer, ReshapingLayer
export heatmap

# utils
- export strip_softmax, flatten_model, check_model, flatten_chain, canonize
+ export strip_softmax, flatten_model, canonize
export preprocess_imagenet
end # module
src/analyze_api.jl: 1 change (1 addition, 0 deletions)
@@ -71,5 +71,6 @@ struct Explanation{A,O,I,L}
output::O
neuron_selection::I
analyzer::Symbol
+ # TODO: turn into a field `extras` of type Union{Nothing, Dict}
layerwise_relevances::L
end
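A short sketch of how the struct is consumed downstream (assuming an analyzer and input are in scope):

```julia
expl = analyze(input, analyzer)
expl.output                  # model output for the analyzed input
expl.neuron_selection        # index of the explained output neuron
expl.layerwise_relevances    # slated to move into an `extras` field per the TODO above
```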