Skip to content

Commit

Permalink
Update XAIBase to v4
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Oct 10, 2024
1 parent 9088127 commit 238c65b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/RelevancePropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module RelevancePropagation

using Reexport
@reexport using XAIBase
import XAIBase: call_analyzer

using XAIBase: AbstractFeatureSelector, number_of_features
using Base.Iterators
Expand Down
6 changes: 4 additions & 2 deletions src/crp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ end
# Call to CRP analyzer #
#======================#

function (crp::CRP)(input::AbstractArray{T,N}, ns::AbstractOutputSelector) where {T,N}
function call_analyzer(
input::AbstractArray{T,N}, crp::CRP, ns::AbstractOutputSelector
) where {T,N}
rules = crp.lrp.rules
layers = crp.lrp.model.layers
modified_layers = crp.lrp.modified_layers
Expand Down Expand Up @@ -88,5 +90,5 @@ function (crp::CRP)(input::AbstractArray{T,N}, ns::AbstractOutputSelector) where
end
end
end
return Explanation(R_return, last(as), ns(last(as)), :CRP, :attribution, nothing)
return Explanation(R_return, input, last(as), ns(last(as)), :CRP, :attribution, nothing)
end
6 changes: 3 additions & 3 deletions src/lrp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwa
# Call to the LRP analyzer #
#==========================#

function (lrp::LRP)(
input::AbstractArray, ns::AbstractOutputSelector; layerwise_relevances=false
# NOTE(review): this span is a rendered diff, not clean source. Two `return`
# lines appear below: the first is the PRE-commit version (old Explanation
# signature without the analyzer `input`), the second is the POST-commit
# version added by "Update XAIBase to v4" (Explanation now takes `input` as
# its second positional argument). Only the second return reflects the
# committed code; as written here the first return would execute and the
# second is unreachable. — confirm against the actual src/lrp.jl in the repo.
#
# Purpose (as visible here): run a full LRP analysis — forward pass to collect
# activations, seed the output-layer relevance via the output selector `ns`,
# then propagate relevances backward through all layers.
function call_analyzer(
input::AbstractArray, lrp::LRP, ns::AbstractOutputSelector; layerwise_relevances=false
)
as = get_activations(lrp.model, input) # compute activations aᵏ for all layers k
Rs = similar.(as) # allocate relevances Rᵏ for all layers k
mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N

# In-place backward pass: fills Rs from the output layer down to the input.
lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers)
# Optionally expose all per-layer relevances via the Explanation `extras` field.
extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
# Pre-commit return (diff residue — see NOTE above):
return Explanation(first(Rs), last(as), ns(last(as)), :LRP, :attribution, extras)
# Post-commit return (XAIBase v4: `input` added as second argument):
return Explanation(first(Rs), input, last(as), ns(last(as)), :LRP, :attribution, extras)
end

get_activations(model, input) = (input, Flux.activations(model, input)...)
Expand Down

0 comments on commit 238c65b

Please sign in to comment.