From 6e48268a59aa7cfa8dc7eb4c1c31ae3c65320f78 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Mon, 30 Sep 2024 17:23:01 +0200 Subject: [PATCH] Set up tests for dtmap and dtforeach, reorganize test folder a bit --- src/utility/util.jl | 35 +++++++++++++++++------ test/runtests.jl | 17 +++++++---- test/{utility.jl => test_utils.jl} | 0 test/utility/diff_maps.jl | 9 ++++++ test/{ctmrg => utility}/svd_wrapper.jl | 0 test/{ctmrg => utility}/symmetrization.jl | 0 6 files changed, 46 insertions(+), 15 deletions(-) rename test/{utility.jl => test_utils.jl} (100%) create mode 100644 test/utility/diff_maps.jl rename test/{ctmrg => utility}/svd_wrapper.jl (100%) rename test/{ctmrg => utility}/symmetrization.jl (100%) diff --git a/src/utility/util.jl b/src/utility/util.jl index 6fe63bc6..47f0c44c 100644 --- a/src/utility/util.jl +++ b/src/utility/util.jl @@ -167,28 +167,45 @@ dtforeach(args...; kwargs...) = tforeach(args...; kwargs...) # Follows the `map` rrule from ChainRules.jl but specified for the case of one AbstractArray that is being mapped # https://github.com/JuliaDiff/ChainRules.jl/blob/e245d50a1ae56ce46fc8c1f0fe9b925964f1146e/src/rulesets/Base/base.jl#L243 function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, - ::typeof(dtmap), - f::F, - A::AbstractArray; - kwargs..., -) where {F} + config::RuleConfig{>:HasReverseMode}, ::typeof(dtmap), f, A::AbstractArray; kwargs... +) el_rrules = tmap(A; kwargs...) do a rrule_via_ad(config, f, a) end y = map(first, el_rrules) - function map_pullback(dy_raw) + function dtmap_pullback(dy_raw) + dy = unthunk(dy_raw) + backevals = tmap(CartesianIndices(A); kwargs...) do idx + last(el_rrules[idx])(dy[idx]) + end + df = ProjectTo(f)(sum(first, backevals)) + dA = tmap(CartesianIndices(A); kwargs...) do idx + ProjectTo(A[idx])(last(backevals[idx])) + end + return (NoTangent(), df, dA) + end + return y, dtmap_pullback +end + +# TODO: fix this +function ChainRulesCore.rrule( + config::RuleConfig{>:HasReverseMode}, ::typeof(dtforeach), f, A::AbstractArray; kwargs... +) + el_rrules = tmap(A; kwargs...) do a + rrule_via_ad(config, f, a) + end + function dtforeach_pullback(dy_raw) dy = unthunk(dy_raw) backevals = tmap(CartesianIndices(A); kwargs...) do idx last(el_rrules[idx])(dy[idx]) end df = ProjectTo(f)(sum(first, backevals)) - dA = map(CartesianIndices(A)) do idx + dA = tmap(CartesianIndices(A); kwargs...) do idx ProjectTo(A[idx])(last(backevals[idx])) end return (NoTangent(), df, dA) end - return y, map_pullback + return nothing, dtforeach_pullback end """ diff --git a/test/runtests.jl b/test/runtests.jl index 93bafe43..c80f8283 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,9 +20,6 @@ end @time @safetestset "Unit cell" begin include("ctmrg/unitcell.jl") end - @time @safetestset "SVD wrapper" begin - include("ctmrg/svd_wrapper.jl") - end @time @safetestset ":fixed CTMRG iteration scheme" begin include("ctmrg/fixed_iterscheme.jl") end @@ -32,15 +29,23 @@ end @time @safetestset "CTMRG schemes" begin include("ctmrg/ctmrgschemes.jl") end - @time @safetestset "CTMRG schemes" begin - include("ctmrg/symmetrization.jl") - end end if GROUP == "ALL" || GROUP == "MPS" @time @safetestset "VUMPS" begin include("boundarymps/vumps.jl") end end + if GROUP == "ALL" || GROUP == "UTILITY" + @time @safetestset "SVD wrapper" begin + include("utility/svd_wrapper.jl") + end + @time @safetestset "Symmetrization" begin + include("utility/symmetrization.jl") + end + @time @safetestset "Differentiable tmap and tforeach" begin + include("utility/diff_maps.jl") + end + end if GROUP == "ALL" || GROUP == "EXAMPLES" @time @safetestset "Transverse Field Ising model" begin include("tf_ising.jl") diff --git a/test/utility.jl b/test/test_utils.jl similarity index 100% rename from test/utility.jl rename to test/test_utils.jl diff --git a/test/utility/diff_maps.jl b/test/utility/diff_maps.jl new file mode 100644 index 00000000..ac45af37 --- /dev/null +++ b/test/utility/diff_maps.jl @@ -0,0 +1,9 @@ +using ChainRulesTestUtils +using PEPSKit: dtmap, dtforeach + +# Can the rrule of dtmap be made inferable? (if check_inferred=true, tests error at the moment) +@testset "Differentiable tmap" begin + test_rrule(dtmap, x -> x^3, randn(5, 5); check_inferred=false) +end + +@testset "Differentiable tforeach" begin end diff --git a/test/ctmrg/svd_wrapper.jl b/test/utility/svd_wrapper.jl similarity index 100% rename from test/ctmrg/svd_wrapper.jl rename to test/utility/svd_wrapper.jl diff --git a/test/ctmrg/symmetrization.jl b/test/utility/symmetrization.jl similarity index 100% rename from test/ctmrg/symmetrization.jl rename to test/utility/symmetrization.jl