Commit

build based on 6e12cd2
Documenter.jl committed Aug 21, 2023
1 parent bc0ce1f commit 8fc717f
Showing 8 changed files with 103 additions and 101 deletions.
50 changes: 28 additions & 22 deletions dev/api/index.html

Large diffs are not rendered by default.

39 changes: 17 additions & 22 deletions dev/generated/advanced_lrp.ipynb
@@ -72,7 +72,7 @@
{
"output_type": "execute_result",
"data": {
"text/plain": "LRP(\n Conv((5, 5), 1 => 6, relu) => \u001b[33mZBoxRule{Float32}(0.0f0, 1.0f0)\u001b[39m,\n MaxPool((2, 2)) => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Conv((5, 5), 6 => 16, relu) => \u001b[33mGammaRule{Float32}(0.25f0)\u001b[39m,\n MaxPool((2, 2)) => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Flux.flatten => \u001b[33mZeroRule()\u001b[39m,\n Dense(256 => 120, relu) => \u001b[33mZeroRule()\u001b[39m,\n Dense(120 => 84, relu) => \u001b[33mZeroRule()\u001b[39m,\n Dense(84 => 10) => \u001b[33mZeroRule()\u001b[39m,\n)\n"
"text/plain": "LRP(\n Conv((5, 5), 1 => 6, relu) \u001b[90m => \u001b[39m\u001b[33mZBoxRule{Float32}(0.0f0, 1.0f0)\u001b[39m,\n MaxPool((2, 2)) \u001b[90m => \u001b[39m\u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Conv((5, 5), 6 => 16, relu)\u001b[90m => \u001b[39m\u001b[33mGammaRule{Float32}(0.25f0)\u001b[39m,\n MaxPool((2, 2)) \u001b[90m => \u001b[39m\u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Flux.flatten \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(256 => 120, relu) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(120 => 84, relu) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(84 => 10) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n)\n"
},
"metadata": {},
"execution_count": 3
@@ -908,14 +908,6 @@
"metadata": {},
"execution_count": 4
},
{
"cell_type": "markdown",
"source": [
"Since some Flux Chains contain other Flux Chains, ExplainableAI provides\n",
"a utility function called `flatten_model`."
],
"metadata": {}
},
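This markdown cell (removed by this commit) mentions `flatten_model` without showing it in use. As a minimal usage sketch, assuming a toy nested model with illustrative layer sizes:

```julia
using Flux, ExplainableAI

# Hypothetical nested model: `flatten_model` splices the inner Chain into
# the outer one, so LRP rules can later be assigned layer by layer.
nested = Chain(
    Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2))),
    Flux.flatten,
    Dense(864 => 10),  # 864 = 12 * 12 * 6 for a 28x28 grayscale input
)
model = flatten_model(nested)
# Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), Flux.flatten, Dense(864 => 10))
```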
{
"cell_type": "markdown",
"source": [
Expand Down Expand Up @@ -965,7 +957,7 @@
{
"output_type": "execute_result",
"data": {
"text/plain": "LRP(\n Conv((5, 5), 1 => 6, relu) => \u001b[33mZBoxRule{Float32}(0.0f0, 1.0f0)\u001b[39m,\n MaxPool((2, 2)) => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Conv((5, 5), 6 => 16, relu) => \u001b[33mGammaRule{Float32}(0.25f0)\u001b[39m,\n MaxPool((2, 2)) => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Flux.flatten => \u001b[33mZeroRule()\u001b[39m,\n Dense(256 => 120, relu) => \u001b[33mZeroRule()\u001b[39m,\n Dense(120 => 84, relu) => \u001b[33mZeroRule()\u001b[39m,\n Dense(84 => 10) => \u001b[33mZeroRule()\u001b[39m,\n)\n"
"text/plain": "LRP(\n Conv((5, 5), 1 => 6, relu) \u001b[90m => \u001b[39m\u001b[33mZBoxRule{Float32}(0.0f0, 1.0f0)\u001b[39m,\n MaxPool((2, 2)) \u001b[90m => \u001b[39m\u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Conv((5, 5), 6 => 16, relu)\u001b[90m => \u001b[39m\u001b[33mGammaRule{Float32}(0.25f0)\u001b[39m,\n MaxPool((2, 2)) \u001b[90m => \u001b[39m\u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Flux.flatten \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(256 => 120, relu) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(120 => 84, relu) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(84 => 10) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n)\n"
},
"metadata": {},
"execution_count": 6
@@ -1817,7 +1809,6 @@
"* `FirstLayerTypeRule` for a `TypeRule` on the first layer of a model\n",
"* `LastLayerTypeRule` for a `TypeRule` on the last layer\n",
"* `FirstNTypeRule` for a `TypeRule` on the first `n` layers\n",
"* `LastNTypeRule` for a `TypeRule` on the last `n` layers\n",
"\n",
"Primitives are called sequentially in the order the `Composite` was created with\n",
"and overwrite rules specified by previous primitives."
@@ -1838,7 +1829,7 @@
{
"output_type": "execute_result",
"data": {
"text/plain": "Composite(\n GlobalTypeRule( \u001b[90m# on all layers\u001b[39m\n \u001b[94mUnion{Flux.Conv, Flux.ConvTranspose, Flux.CrossCor}\u001b[39m => \u001b[33mZPlusRule()\u001b[39m,\n \u001b[94mFlux.Dense\u001b[39m => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n \u001b[94mUnion{typeof(NNlib.dropout), Flux.AlphaDropout, Flux.Dropout}\u001b[39m => \u001b[33mPassRule()\u001b[39m,\n \u001b[94mUnion{typeof(Flux.flatten), typeof(MLUtils.flatten)}\u001b[39m => \u001b[33mPassRule()\u001b[39m,\n ),\n FirstLayerTypeRule( \u001b[90m# on first layer\u001b[39m\n \u001b[94mUnion{Flux.Conv, Flux.ConvTranspose, Flux.CrossCor}\u001b[39m => \u001b[33mFlatRule()\u001b[39m,\n \u001b[94mFlux.Dense\u001b[39m => \u001b[33mFlatRule()\u001b[39m,\n ),\n)\n"
"text/plain": "Composite(\n GlobalTypeRule( \u001b[90m# on all layers\u001b[39m\n \u001b[94mUnion{Flux.Conv, Flux.ConvTranspose, Flux.CrossCor}\u001b[39m => \u001b[33mZPlusRule()\u001b[39m,\n \u001b[94mFlux.Dense\u001b[39m => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n \u001b[94mUnion{typeof(NNlib.dropout), Flux.AlphaDropout, Flux.Dropout}\u001b[39m => \u001b[33mPassRule()\u001b[39m,\n \u001b[94mUnion{typeof(Flux.flatten), typeof(MLUtils.flatten)}\u001b[39m => \u001b[33mPassRule()\u001b[39m,\n \u001b[94mtypeof(identity)\u001b[39m => \u001b[33mPassRule()\u001b[39m,\n ),\n FirstLayerTypeRule( \u001b[90m# on first layer\u001b[39m\n \u001b[94mUnion{Flux.Conv, Flux.ConvTranspose, Flux.CrossCor}\u001b[39m => \u001b[33mFlatRule()\u001b[39m,\n \u001b[94mFlux.Dense\u001b[39m => \u001b[33mFlatRule()\u001b[39m,\n ),\n)\n"
},
"metadata": {},
"execution_count": 8
@@ -1856,7 +1847,7 @@
{
"output_type": "execute_result",
"data": {
"text/plain": "LRP(\n Conv((5, 5), 1 => 6, relu) => \u001b[33mFlatRule()\u001b[39m,\n MaxPool((2, 2)) => \u001b[33mZeroRule()\u001b[39m,\n Conv((5, 5), 6 => 16, relu) => \u001b[33mZPlusRule()\u001b[39m,\n MaxPool((2, 2)) => \u001b[33mZeroRule()\u001b[39m,\n Flux.flatten => \u001b[33mPassRule()\u001b[39m,\n Dense(256 => 120, relu) => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Dense(120 => 84, relu) => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Dense(84 => 10) => \u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n)\n"
"text/plain": "LRP(\n Conv((5, 5), 1 => 6, relu) \u001b[90m => \u001b[39m\u001b[33mFlatRule()\u001b[39m,\n MaxPool((2, 2)) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Conv((5, 5), 6 => 16, relu)\u001b[90m => \u001b[39m\u001b[33mZPlusRule()\u001b[39m,\n MaxPool((2, 2)) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Flux.flatten \u001b[90m => \u001b[39m\u001b[33mPassRule()\u001b[39m,\n Dense(256 => 120, relu) \u001b[90m => \u001b[39m\u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Dense(120 => 84, relu) \u001b[90m => \u001b[39m\u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n Dense(84 => 10) \u001b[90m => \u001b[39m\u001b[33mEpsilonRule{Float32}(1.0f-6)\u001b[39m,\n)\n"
},
"metadata": {},
"execution_count": 9
@@ -1902,7 +1893,7 @@
{
"output_type": "execute_result",
"data": {
"text/plain": "modify_parameters (generic function with 7 methods)"
"text/plain": "modify_parameters (generic function with 6 methods)"
},
"metadata": {},
"execution_count": 11
@@ -3796,7 +3787,7 @@
{
"output_type": "execute_result",
"data": {
"text/plain": "LRP(\n Flux.flatten => \u001b[33mZeroRule()\u001b[39m,\n Dense(784 => 100, myrelu) => \u001b[33mZeroRule()\u001b[39m,\n Dense(100 => 10) => \u001b[33mZeroRule()\u001b[39m,\n)\n"
"text/plain": "LRP(\n Flux.flatten \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(784 => 100, myrelu)\u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n Dense(100 => 10) \u001b[90m => \u001b[39m\u001b[33mZeroRule()\u001b[39m,\n)\n"
},
"metadata": {},
"execution_count": 19
@@ -3824,7 +3815,7 @@
"\n",
"This is done by calling low level functions\n",
"```julia\n",
"lrp!(Rₖ, rule, modified_layer, aₖ, Rₖ₊₁)\n",
"lrp!(Rₖ, rule, layer, modified_layer, aₖ, Rₖ₊₁)\n",
" Rₖ .= ...\n",
"end\n",
"```\n",
@@ -3883,7 +3874,10 @@
"For `lrp!`, we implement the previous four step computation using `Zygote.pullback` to\n",
"compute $c$ from the previous equation as a VJP, pulling back $s=R/z$:\n",
"```julia\n",
"function lrp!(Rₖ, rule, modified_layer, aₖ, Rₖ₊₁)\n",
"function lrp!(Rₖ, rule, layer, modified_layer, aₖ, Rₖ₊₁)\n",
" # Use modified_layer if available, otherwise layer\n",
" layer = ifelse(isnothing(modified_layer), layer, modified_layer)\n",
"\n",
" ãₖ = modify_input(rule, aₖ)\n",
" z, back = Zygote.pullback(modified_layer, ãₖ)\n",
" s = Rₖ₊₁ ./ modify_denominator(rule, z)\n",
@@ -3906,7 +3900,7 @@
"Reshaping layers don't affect attributions. We can therefore avoid the computational\n",
"overhead of AD by writing a specialized implementation that simply reshapes back:\n",
"```julia\n",
"function lrp!(Rₖ, rule, ::ReshapingLayer, aₖ, Rₖ₊₁)\n",
"function lrp!(Rₖ, rule, _layer::ReshapingLayer, _modified_layer, aₖ, Rₖ₊₁)\n",
" Rₖ .= reshape(Rₖ₊₁, size(aₖ))\n",
"end\n",
"```\n",
@@ -3915,18 +3909,19 @@
"\n",
"We can even implement the generic rule as a specialized implementation for `Dense` layers:\n",
"```julia\n",
"function lrp!(Rₖ, rule, layer::Dense, aₖ, Rₖ₊₁)\n",
"function lrp!(Rₖ, rule, layer::Dense, modified_layer, aₖ, Rₖ₊₁)\n",
" layer = ifelse(isnothing(modified_layer), layer, modified_layer)\n",
" ãₖ = modify_input(rule, aₖ)\n",
" z = modify_denominator(rule, modified_layer(ãₖ))\n",
" @tullio Rₖ[j, b] = modified_layer.weight[i, j] * ãₖ[j, b] / z[i, b] * Rₖ₊₁[i, b]\n",
" z = modify_denominator(rule, layer(ãₖ))\n",
" @tullio Rₖ[j, b] = layer.weight[i, j] * ãₖ[j, b] / z[i, b] * Rₖ₊₁[i, b]\n",
"end\n",
"```\n",
"\n",
"For maximum low-level control beyond `modify_input` and `modify_denominator`,\n",
"you can also implement your own `lrp!` function and dispatch\n",
"on individual rule types `MyRule` and layer types `MyLayer`:\n",
"```julia\n",
"function lrp!(Rₖ, rule::MyRule, layer::MyLayer, aₖ, Rₖ₊₁)\n",
"function lrp!(Rₖ, rule::MyRule, layer::MyLayer, _modified_layer, aₖ, Rₖ₊₁)\n",
" Rₖ .= ...\n",
"end\n",
"```"