Modularize package tests (#17)
* Modularize tests

* Fix type inferrability

* Fix Flux GPU URL
adrhill authored Jun 24, 2024
1 parent 195da29 commit a86c962
Showing 13 changed files with 66 additions and 27 deletions.
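Most of the per-file diffs below follow the same pattern: each test file now declares its own imports and defines the pseudorand helper locally, instead of relying on globals set up in runtests.jl. A minimal sketch of such a self-contained test file; the file name, the small model, and the final check are illustrative assumptions, not code from this commit:

# test/test_example.jl (hypothetical file name, shown only to illustrate the pattern)
using RelevancePropagation
using Test

using Flux
using Random: rand, MersenneTwister

# Deterministic pseudo-random inputs, defined per file as in the diffs below
pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)

# Illustrative two-layer model without an output softmax
model = Chain(Dense(10 => 5, relu), Dense(5 => 2))
input = pseudorand(10, 1)

analyzer = LRP(model)
expl = analyze(input, analyzer)
@test size(expl.val) == size(input)  # LRP relevances share the input's shape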
2 changes: 1 addition & 1 deletion docs/src/literate/basics.jl
@@ -142,7 +142,7 @@ expl.extras.layerwise_relevances
# ## Performance tips
# ### Using LRP with a GPU
# All LRP analyzers support GPU backends,
-# building on top of [Flux.jl's GPU support](https://fluxml.ai/Flux.jl/stable/gpu/).
+# building on top of [Flux.jl's GPU support](https://fluxml.ai/Flux.jl/stable/guide/gpu/).
# Using a GPU only requires moving the input array and model weights to the GPU.
#
# For example, using [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl):
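The paragraph above ends with "For example, using CUDA.jl:", but the example itself lies outside the hunk shown. A hedged sketch of what moving the model and input to the GPU looks like with Flux's gpu helper (illustrative model, not the literal code from basics.jl):

using CUDA, Flux
using RelevancePropagation

# Move the model weights and the input array to the GPU, then analyze as usual.
model = Chain(Dense(100 => 50, relu), Dense(50 => 10)) |> gpu
input = rand(Float32, 100, 1) |> gpu

analyzer = LRP(model)
expl = analyze(input, analyzer)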
1 change: 1 addition & 0 deletions src/rules.jl
@@ -16,6 +16,7 @@ function lrp!(Rᵏ, rule::AbstractLRPRule, layer, modified_layer, aᵏ, Rᵏ⁺
s = Rᵏ⁺¹ ./ modify_denominator(rule, z)
c = only(back(s))
Rᵏ .= ãᵏ .* c
+return Rᵏ
end

#===================================#
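The "Fix type inferrability" bullet in the commit message corresponds to the explicit return Rᵏ added above. A small, generic sketch of how return-type inferrability can be checked with Test.@inferred (a standalone toy function, not taken from the package's test suite):

using Test

# A mutating helper in the spirit of lrp!: it writes into its first argument
# and returns it explicitly, so the return type is concrete.
scale!(y, x, a) = (y .= a .* x; return y)

y = zeros(Float32, 4)
x = ones(Float32, 4)

# @inferred errors if the call's return type cannot be inferred concretely,
# which surfaces as a failure of the surrounding @test.
@test @inferred(scale!(y, x, 2.0f0)) == fill(2.0f0, 4)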
1 change: 0 additions & 1 deletion test/Project.toml
@@ -11,4 +11,3 @@ ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

31 changes: 15 additions & 16 deletions test/runtests.jl
@@ -1,24 +1,23 @@
using Test
using ReferenceTests
using Aqua
using JuliaFormatter
using Random

using RelevancePropagation
using Flux
import Flux: Scale

pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
using Test
using JuliaFormatter
using Aqua

@testset "RelevancePropagation.jl" begin
@testset "Aqua.jl" begin
@info "Running Aqua.jl's auto quality assurance tests. These might print warnings from dependencies."
Aqua.test_all(RelevancePropagation; ambiguities=false)
end
@testset "JuliaFormatter.jl" begin
@info "Running JuliaFormatter's code formatting tests."
@test format(RelevancePropagation; verbose=false, overwrite=false)
if VERSION >= v"1.10"
@testset "Code formatting" begin
@info "- Testing code formatting with JuliaFormatter..."
@test JuliaFormatter.format(
RelevancePropagation; verbose=false, overwrite=false
)
end
@testset "Aqua.jl" begin
@info "- Running Aqua.jl tests. These might print warnings from dependencies..."
Aqua.test_all(RelevancePropagation; ambiguities=false)
end
end

@testset "Utilities" begin
@info "Testing utilities..."
include("test_utils.jl")
8 changes: 6 additions & 2 deletions test/test_batches.jl
@@ -1,6 +1,10 @@
-using Flux
using RelevancePropagation
-using Random
+using Test

+using Flux
+using Random: rand, MersenneTwister

+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)

## Test `fuse_batchnorm` on Dense and Conv layers
ins = 20
6 changes: 5 additions & 1 deletion test/test_canonize.jl
@@ -1,9 +1,13 @@
+using RelevancePropagation
+using Test

using Flux
using Flux: flatten, Scale
-using RelevancePropagation
using RelevancePropagation: canonize_fuse
using Random

+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)

batchsize = 50

##=====================================#
3 changes: 3 additions & 0 deletions test/test_chain_utils.jl
@@ -1,3 +1,6 @@
+using RelevancePropagation
+using Test

using RelevancePropagation: ChainTuple, ParallelTuple, SkipConnectionTuple
using RelevancePropagation: ModelIndex, chainmap, chainindices, chainzip
using RelevancePropagation: activation_fn
4 changes: 4 additions & 0 deletions test/test_checks.jl
@@ -1,6 +1,10 @@
+using RelevancePropagation
+using Test
+using ReferenceTests

using RelevancePropagation: check_lrp_compat, print_lrp_model_check
using Suppressor

err = ErrorException("Unknown layer or activation function found in model")

# Flux layers
6 changes: 6 additions & 0 deletions test/test_cnn.jl
@@ -1,13 +1,19 @@
+using RelevancePropagation
+using Test
+using ReferenceTests

using Flux
using JLD2
using Random: rand, MersenneTwister

const LRP_ANALYZERS = Dict(
"LRPZero" => LRP,
"LRPZero_COC" => m -> LRP(m; flatten=false), # chain of chains
"LRPEpsilonAlpha2Beta1Flat" => m -> LRP(m, EpsilonAlpha2Beta1Flat()),
)

+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)

input_size = (32, 32, 3, 1)
input = pseudorand(input_size)

7 changes: 6 additions & 1 deletion test/test_composite.jl
@@ -1,6 +1,11 @@
+using RelevancePropagation
+using Test
+using ReferenceTests

+using NNlib
+using Flux
using Flux: flatten, Scale
using Metalhead
-using Flux, NNlib

model = VGG(11; pretrain=false).layers
model_flat = flatten_model(model)
5 changes: 5 additions & 0 deletions test/test_crp.jl
@@ -1,3 +1,8 @@
+using RelevancePropagation
+using Test

+using Flux

@testset "CRP analytic" begin
W1 = [1.0 3.0; 4.0 2.0]
b1 = [0.0, 1.0]
7 changes: 5 additions & 2 deletions test/test_rules.jl
@@ -1,11 +1,14 @@
+using RelevancePropagation
+using Test
+using ReferenceTests

using RelevancePropagation: lrp!, modify_input, modify_denominator, is_compatible
using RelevancePropagation: modify_parameters, modify_weight, modify_bias, modify_layer
using RelevancePropagation: stabilize_denom
using Flux
using Flux: flatten, Scale
using LinearAlgebra: I
-using ReferenceTests
-using Random
+using Random: randn, MersenneTwister

# Fixed pseudo-random numbers
T = Float32
12 changes: 9 additions & 3 deletions test/test_utils.jl
@@ -1,9 +1,15 @@
-using Flux
-using Flux: flatten
using RelevancePropagation
+using Test

using RelevancePropagation: activation_fn, copy_layer, flatten_model
using RelevancePropagation: has_output_softmax, check_output_softmax
using RelevancePropagation: stabilize_denom, drop_batch_index, masked_copy
-using Random

+using Flux
+using Flux: flatten, Scale
+using Random: rand, MersenneTwister

+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)

# Test `activation_fn`
@test activation_fn(Dense(5, 2, gelu)) == gelu
