Skip to content

Commit

Permalink
Set up tests for dtmap and dtforeach, reorganize test folder a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrehmer committed Sep 30, 2024
1 parent 0d7b807 commit 6e48268
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 15 deletions.
35 changes: 26 additions & 9 deletions src/utility/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,28 +167,45 @@ dtforeach(args...; kwargs...) = tforeach(args...; kwargs...)
# Follows the `map` rrule from ChainRules.jl but specified for the case of one AbstractArray that is being mapped
# https://github.com/JuliaDiff/ChainRules.jl/blob/e245d50a1ae56ce46fc8c1f0fe9b925964f1146e/src/rulesets/Base/base.jl#L243
function ChainRulesCore.rrule(
config::RuleConfig{>:HasReverseMode},
::typeof(dtmap),
f::F,
A::AbstractArray;
kwargs...,
) where {F}
config::RuleConfig{>:HasReverseMode}, ::typeof(dtmap), f, A::AbstractArray; kwargs...
)
el_rrules = tmap(A; kwargs...) do a
rrule_via_ad(config, f, a)
end
y = map(first, el_rrules)
function map_pullback(dy_raw)
function dtmap_pullback(dy_raw)
dy = unthunk(dy_raw)
backevals = tmap(CartesianIndices(A); kwargs...) do idx
last(el_rrules[idx])(dy[idx])
end
df = ProjectTo(f)(sum(first, backevals))
dA = tmap(CartesianIndices(A); kwargs...) do idx
ProjectTo(A[idx])(last(backevals[idx]))
end
return (NoTangent(), df, dA)
end
return y, dtmap_pullback
end

# TODO: fix this
function ChainRulesCore.rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(dtforeach), f, A::AbstractArray; kwargs...
)
el_rrules = tmap(A; kwargs...) do a
rrule_via_ad(config, f, a)
end
function dtforeach_pullback(dy_raw)
dy = unthunk(dy_raw)
backevals = tmap(CartesianIndices(A); kwargs...) do idx
last(el_rrules[idx])(dy[idx])
end
df = ProjectTo(f)(sum(first, backevals))
dA = map(CartesianIndices(A)) do idx
dA = tmap(CartesianIndices(A); kwargs...) do idx
ProjectTo(A[idx])(last(backevals[idx]))
end
return (NoTangent(), df, dA)
end
return y, map_pullback
return nothing, dtforeach_pullback
end

"""
Expand Down
17 changes: 11 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ end
@time @safetestset "Unit cell" begin
include("ctmrg/unitcell.jl")
end
@time @safetestset "SVD wrapper" begin
include("ctmrg/svd_wrapper.jl")
end
@time @safetestset ":fixed CTMRG iteration scheme" begin
include("ctmrg/fixed_iterscheme.jl")
end
Expand All @@ -32,15 +29,23 @@ end
@time @safetestset "CTMRG schemes" begin
include("ctmrg/ctmrgschemes.jl")
end
@time @safetestset "CTMRG schemes" begin
include("ctmrg/symmetrization.jl")
end
end
if GROUP == "ALL" || GROUP == "MPS"
@time @safetestset "VUMPS" begin
include("boundarymps/vumps.jl")
end
end
if GROUP == "ALL" || GROUP == "UTILITY"
@time @safetestset "SVD wrapper" begin
include("utility/svd_wrapper.jl")
end
@time @safetestset "Symmetrization" begin
include("utility/symmetrization.jl")
end
@time @safetestset "Differentiable tmap and tforeach" begin
include("utility/diff_maps.jl")
end
end
if GROUP == "ALL" || GROUP == "EXAMPLES"
@time @safetestset "Transverse Field Ising model" begin
include("tf_ising.jl")
Expand Down
File renamed without changes.
9 changes: 9 additions & 0 deletions test/utility/diff_maps.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using ChainRulesTestUtils
using PEPSKit: dtmap, dtforeach

# Can the rrule of dtmap be made inferable? (if check_inferred=true, tests error at the moment)
@testset "Differentiable tmap" begin
test_rrule(dtmap, x -> x^3, randn(5, 5); check_inferred=false)
end

@testset "Differentiable tforeach" begin end
File renamed without changes.
File renamed without changes.

0 comments on commit 6e48268

Please sign in to comment.