diff --git a/Project.toml b/Project.toml index b28d12e..46ccaab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RelevancePropagation" uuid = "0be6dd02-ae9e-43eb-b318-c6e81d6890d8" authors = ["Adrian Hill "] -version = "2.0.3-DEV" +version = "3.0.0-DEV" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -14,12 +14,12 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Flux = "0.13, 0.14" +Flux = "0.14" MacroTools = "0.5" Markdown = "1" Random = "1" Reexport = "1" Statistics = "1" -XAIBase = "3" +XAIBase = "4" Zygote = "0.6" -julia = "1.6" +julia = "1.10" diff --git a/README.md b/README.md index e33953f..81434fd 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ This package is part of the [Julia-XAI ecosystem](https://github.com/Julia-XAI) [ExplainableAI.jl](https://github.com/Julia-XAI/ExplainableAI.jl). ## Installation -This package supports Julia ≥1.6. To install it, open the Julia REPL and run +This package supports Julia ≥1.10. To install it, open the Julia REPL and run ```julia-repl julia> ]add RelevancePropagation ``` diff --git a/src/RelevancePropagation.jl b/src/RelevancePropagation.jl index b914dc5..4930612 100644 --- a/src/RelevancePropagation.jl +++ b/src/RelevancePropagation.jl @@ -2,6 +2,7 @@ module RelevancePropagation using Reexport @reexport using XAIBase +import XAIBase: call_analyzer using XAIBase: AbstractFeatureSelector, number_of_features using Base.Iterators @@ -12,7 +13,6 @@ using Zygote using Markdown using Statistics: mean, std -include("compat.jl") include("bibliography.jl") include("layer_types.jl") include("layer_utils.jl") diff --git a/src/compat.jl b/src/compat.jl deleted file mode 100644 index 368cdca..0000000 --- a/src/compat.jl +++ /dev/null @@ -1,6 +0,0 @@ -if VERSION < v"1.8.0-DEV.1494" # 98e60ffb11ee431e462b092b48a31a1204bd263d - export allequal - allequal(itr) = isempty(itr) ? true : all(isequal(first(itr)), itr) - allequal(c::Union{AbstractSet,AbstractDict}) = length(c) <= 1 - allequal(r::AbstractRange) = iszero(step(r)) || length(r) <= 1 -end diff --git a/src/crp.jl b/src/crp.jl index 9d57790..6964265 100644 --- a/src/crp.jl +++ b/src/crp.jl @@ -32,7 +32,9 @@ end # Call to CRP analyzer # #======================# -function (crp::CRP)(input::AbstractArray{T,N}, ns::AbstractOutputSelector) where {T,N} +function call_analyzer( + input::AbstractArray{T,N}, crp::CRP, ns::AbstractOutputSelector +) where {T,N} rules = crp.lrp.rules layers = crp.lrp.model.layers modified_layers = crp.lrp.modified_layers @@ -88,5 +90,5 @@ function (crp::CRP)(input::AbstractArray{T,N}, ns::AbstractOutputSelector) where end end end - return Explanation(R_return, last(as), ns(last(as)), :CRP, :attribution, nothing) + return Explanation(R_return, input, last(as), ns(last(as)), :CRP, :attribution, nothing) end diff --git a/src/lrp.jl b/src/lrp.jl index 758264d..f3d0be1 100644 --- a/src/lrp.jl +++ b/src/lrp.jl @@ -55,8 +55,8 @@ LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwa # Call to the LRP analyzer # #==========================# -function (lrp::LRP)( - input::AbstractArray, ns::AbstractOutputSelector; layerwise_relevances=false +function call_analyzer( + input::AbstractArray, lrp::LRP, ns::AbstractOutputSelector; layerwise_relevances=false ) as = get_activations(lrp.model, input) # compute activations aᵏ for all layers k Rs = similar.(as) # allocate relevances Rᵏ for all layers k @@ -64,7 +64,7 @@ function (lrp::LRP)( lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers) extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing - return Explanation(first(Rs), last(as), ns(last(as)), :LRP, :attribution, extras) + return Explanation(first(Rs), input, last(as), ns(last(as)), :LRP, :attribution, extras) end get_activations(model, input) = (input, Flux.activations(model, input)...) diff --git a/test/test_batches.jl b/test/test_batches.jl index 74c063a..09c92dc 100644 --- a/test/test_batches.jl +++ b/test/test_batches.jl @@ -30,20 +30,12 @@ ANALYZERS = Dict( for (name, method) in ANALYZERS @testset "$name" begin - # Using `add_batch_dim=true` should result in same explanation - # as input reshaped to have a batch dimension - analyzer = method(model) - expl1_no_bd = analyzer(input1_no_bd; add_batch_dim=true) - analyzer = method(model) - expl1_bd = analyzer(input1_bd) - @test expl1_bd.val ≈ expl1_no_bd.val - # Analyzing a batch should have the same result # as analyzing inputs in batch individually analyzer = method(model) expl2_bd = analyzer(input2_bd) analyzer = method(model) expl_batch = analyzer(input_batch) - @test expl1_bd.val ≈ expl_batch.val[:, 1] + @test expl2_bd.val ≈ expl_batch.val[:, 2] end end