Add Concept Relevance Propagation (#146)
* add `CRP` analyzer that wraps `LRP`

* add two concept selectors: `TopNConcepts` and `IndexedConcepts`

* add `process_batch` argument to `heatmap`
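Taken together, these additions could be used roughly as follows. This is a minimal sketch, not taken from the package docs: the toy model is made up, and the `CRP` and selector constructors are assumed to take the wrapped `LRP` analyzer, a layer index, and a concept selector.

```julia
using ExplainableAI
using Flux

# Toy convolutional model standing in for e.g. VGG (hypothetical)
model = Chain(
    Conv((3, 3), 3 => 8, relu; pad=1),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(8 * 16 * 16, 10),
)
input = rand(Float32, 32, 32, 3, 1)  # WHCN convention

# Wrap an LRP analyzer in CRP and condition on concepts in the first layer:
crp = CRP(LRP(model), 1, TopNConcepts(3))        # three most relevant concepts
# crp = CRP(LRP(model), 1, IndexedConcepts(1, 5))  # or hand-picked channel indices

expl = analyze(input, crp)

# One heatmap per selected concept; `process_batch=true` (added in this commit)
# normalizes the color scale over all of them jointly:
heatmap(expl; process_batch=true)
```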
adrhill authored Sep 29, 2023
1 parent cbb2906 commit 12b3118
Showing 44 changed files with 528 additions and 187 deletions.
48 changes: 24 additions & 24 deletions README.md
@@ -33,6 +33,7 @@ using ImageInTerminal # show heatmap in terminal
# Load model
model = VGG(16, pretrain=true).layers
model = strip_softmax(model)
model = canonize(model)

# Load input
url = HTTP.URI("https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
@@ -87,30 +88,29 @@ Check out our talk at JuliaCon 2022 for a demonstration of the package.
## Methods
Currently, the following analyzers are implemented:

```
├── Gradient
├── InputTimesGradient
├── SmoothGrad
├── IntegratedGradients
└── LRP
├── Rules
│ ├── ZeroRule
│ ├── EpsilonRule
│ ├── GammaRule
│ ├── GeneralizedGammaRule
│ ├── WSquareRule
│ ├── FlatRule
│ ├── ZBoxRule
│ ├── ZPlusRule
│ ├── AlphaBetaRule
│ └── PassRule
└── Composite
├── EpsilonGammaBox
├── EpsilonPlus
├── EpsilonPlusFlat
├── EpsilonAlpha2Beta1
└── EpsilonAlpha2Beta1Flat
```
* `Gradient`
* `InputTimesGradient`
* `SmoothGrad`
* `IntegratedGradients`
* `LRP`
* Rules
* `ZeroRule`
* `EpsilonRule`
* `GammaRule`
* `GeneralizedGammaRule`
* `WSquareRule`
* `FlatRule`
* `ZBoxRule`
* `ZPlusRule`
* `AlphaBetaRule`
* `PassRule`
* Composites
* `EpsilonGammaBox`
* `EpsilonPlus`
* `EpsilonPlusFlat`
* `EpsilonAlpha2Beta1`
* `EpsilonAlpha2Beta1Flat`
* `CRP`

One of the design goals of ExplainableAI.jl is extensibility.
Custom [composites][docs-composites] are easily defined
5 changes: 3 additions & 2 deletions docs/src/literate/augmentations.jl
@@ -71,13 +71,14 @@ heatmap(input, analyzer)
analyzer = IntegratedGradients(model, 50)
heatmap(input, analyzer)

# To select a different reference input, pass it to the `analyze` or `heatmap` function
# To select a different reference input, pass it to the `analyze` function
# using the keyword argument `input_ref`.
# Note that this is an arbitrary example for the sake of demonstration.
matrix_of_ones = ones(Float32, size(input))

analyzer = InterpolationAugmentation(Gradient(model), 50)
heatmap(input, analyzer; input_ref=matrix_of_ones)
expl = analyzer(input; input_ref=matrix_of_ones)
heatmap(expl)

# Once again, `InterpolationAugmentation` can be combined with any analyzer type,
# for example [`LRP`](@ref):
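# A hedged sketch of such a combination, following the pattern above
# (the file's actual continuation is collapsed in this diff view):
analyzer = InterpolationAugmentation(LRP(model), 50)
expl = analyzer(input; input_ref=matrix_of_ones)
heatmap(expl)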
7 changes: 7 additions & 0 deletions docs/src/lrp/api.md
@@ -94,6 +94,13 @@ LRP_CONFIG.supports_layer
LRP_CONFIG.supports_activation
```

# CRP
```@docs
CRP
TopNConcepts
IndexedConcepts
```

# Index
```@index
```
7 changes: 3 additions & 4 deletions ext/TullioLRPRulesExt.jl
@@ -6,11 +6,10 @@ import ExplainableAI: ZeroRule, EpsilonRule, GammaRule, WSquareRule

# Fast implementation for Dense layer using Tullio.jl's einsum notation:
for R in (ZeroRule, EpsilonRule, GammaRule)
@eval function lrp!(Rᵏ, rule::$R, layer::Dense, modified_layer, aᵏ, Rᵏ⁺¹)
layer = isnothing(modified_layer) ? layer : modified_layer
@eval function lrp!(Rᵏ, rule::$R, _layer::Dense, modified_layer, aᵏ, Rᵏ⁺¹)
ãᵏ = modify_input(rule, aᵏ)
z = modify_denominator(rule, layer(ãᵏ))
@tullio Rᵏ[j, b] = layer.weight[i, j] * ãᵏ[j, b] / z[i, b] * Rᵏ⁺¹[i, b]
z = modify_denominator(rule, modified_layer(ãᵏ))
@tullio Rᵏ[j, b] = modified_layer.weight[i, j] * ãᵏ[j, b] / z[i, b] * Rᵏ⁺¹[i, b]
end
end
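For reference, the `@tullio` expression implements the generic LRP propagation rule for a `Dense` layer, with `i` indexing outputs, `j` inputs, and `b` the batch (summation over `i` is implicit in the einsum):

$$R^{k}_{j,b} = \sum_i \frac{W_{ij}\,\tilde a^{k}_{j,b}}{z_{i,b}}\, R^{k+1}_{i,b}$$

where $\tilde a^{k}$ is the rule-modified input and $z$ the rule-modified pre-activation of the modified layer, exactly as computed in the two lines above.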

4 changes: 4 additions & 0 deletions src/ExplainableAI.jl
@@ -30,6 +30,7 @@ include("lrp/composite.jl")
include("lrp/lrp.jl")
include("lrp/show.jl")
include("lrp/composite_presets.jl") # uses lrp/show.jl
include("lrp/crp.jl")
include("heatmap.jl")
include("preprocessing.jl")
export analyze
@@ -61,6 +62,9 @@ export EpsilonAlpha2Beta1Flat
# Useful type unions
export ConvLayer, PoolingLayer, DropoutLayer, ReshapingLayer, NormalizationLayer

# CRP
export CRP, TopNConcepts, IndexedConcepts

# heatmapping
export heatmap

64 changes: 35 additions & 29 deletions src/heatmap.jl
@@ -3,6 +3,7 @@
const HEATMAPPING_PRESETS = Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
# Analyzer => (colorscheme, reduce, rangescale)
:LRP => (ColorSchemes.seismic, :sum, :centered), # attribution
:CRP => (ColorSchemes.seismic, :sum, :centered), # attribution
:InputTimesGradient => (ColorSchemes.seismic, :sum, :centered), # attribution
:Gradient => (ColorSchemes.grays, :norm, :extrema), # gradient
)
@@ -36,70 +37,75 @@ See also [`analyze`](@ref).
- `permute::Bool`: Whether to flip the W and H dimensions of the input. Default is `true`.
- `unpack_singleton::Bool`: When heatmapping a batch with a single sample, setting `unpack_singleton=true`
  will return an image instead of a Vector containing a single image.
**Note:** keyword arguments can't be used when calling `heatmap` with an analyzer.
- `process_batch::Bool`: When heatmapping a batch, setting `process_batch=true`
will apply the color channel reduction and normalization to the entire batch
instead of computing it individually for each sample. Defaults to `false`.
"""
function heatmap(
attr::AbstractArray{T,N};
val::AbstractArray{T,N};
cs::ColorScheme=ColorSchemes.seismic,
reduce::Symbol=:sum,
rangescale::Symbol=:centered,
permute::Bool=true,
unpack_singleton::Bool=true,
process_batch::Bool=false,
) where {T,N}
N != 4 && throw(
DomainError(
N,
"""heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
Please reshape your explanation to match this format if your model doesn't adhere to this convention.""",
ArgumentError(
"heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
Please reshape your explanation to match this format if your model doesn't adhere to this convention.",
),
)
if unpack_singleton && size(attr, 4) == 1
return _heatmap(attr[:, :, :, 1], cs, reduce, rangescale, permute)
if unpack_singleton && size(val, 4) == 1
return _heatmap(val[:, :, :, 1], cs, reduce, rangescale, permute)
end
if process_batch
hs = _heatmap(val, cs, reduce, rangescale, permute)
return [hs[:, :, i] for i in axes(hs, 3)]
end
return map(a -> _heatmap(a, cs, reduce, rangescale, permute), eachslice(attr; dims=4))
return [_heatmap(v, cs, reduce, rangescale, permute) for v in eachslice(val; dims=4)]
end
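A hedged usage sketch of the new `process_batch` keyword (the toy model and data below are made up for illustration):

```julia
using ExplainableAI, Flux

model = Chain(Flux.flatten, Dense(32 * 32 * 3, 10))  # toy model
batch = rand(Float32, 32, 32, 3, 8)                  # 8 samples, WHCN convention

expl = analyze(batch, Gradient(model))
heatmap(expl)                      # color scale computed per sample (default)
heatmap(expl; process_batch=true)  # one shared color scale across the batch
```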

# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
function heatmap(expl::Explanation; permute::Bool=true, kwargs...)
function heatmap(expl::Explanation; kwargs...)
_cs, _reduce, _rangescale = HEATMAPPING_PRESETS[expl.analyzer]
return heatmap(
expl.val;
reduce=get(kwargs, :reduce, _reduce),
rangescale=get(kwargs, :rangescale, _rangescale),
cs=get(kwargs, :cs, _cs),
permute=permute,
kwargs...,
)
end
# Analyze & heatmap in one go
function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
return heatmap(analyze(input, analyzer, args...; kwargs...))
expl = analyze(input, analyzer, args...)
return heatmap(expl; kwargs...)
end

# Lower level function that is mapped along batch dimension
function _heatmap(
attr::AbstractArray{T,3},
cs::ColorScheme,
reduce::Symbol,
rangescale::Symbol,
permute::Bool,
) where {T<:Real}
img = dropdims(_reduce(attr, reduce); dims=3)
permute && (img = permutedims(img))
# Lower level function that can be mapped along batch dimension
function _heatmap(val, cs::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool)
img = dropdims(reduce_color_channel(val, reduce); dims=3)
permute && (img = flip_wh(img))
return ColorSchemes.get(cs, img, rangescale)
end

flip_wh(img::AbstractArray{T,2}) where {T} = permutedims(img, (2, 1))
flip_wh(img::AbstractArray{T,3}) where {T} = permutedims(img, (2, 1, 3))

# Reduce explanations across color channels into a single scalar – assumes WHCN convention
function _reduce(attr::AbstractArray{T,3}, method::Symbol) where {T}
if size(attr, 3) == 1 # nothing to reduce
return attr
function reduce_color_channel(val::AbstractArray, method::Symbol)
init = zero(eltype(val))
if size(val, 3) == 1 # nothing to reduce
return val
elseif method == :sum
return reduce(+, attr; dims=3)
return reduce(+, val; dims=3)
elseif method == :maxabs
return reduce((c...) -> maximum(abs.(c)), attr; dims=3, init=zero(T))
return reduce((c...) -> maximum(abs.(c)), val; dims=3, init=init)
elseif method == :norm
return reduce((c...) -> sqrt(sum(c .^ 2)), attr; dims=3, init=zero(T))
return reduce((c...) -> sqrt(sum(c .^ 2)), val; dims=3, init=init)
end

throw(
ArgumentError(
"Color channel reducer :$method not supported, `reduce` should be :maxabs, :sum or :norm",
(The remaining changed files are not shown in this view.)
