diff --git a/Project.toml b/Project.toml index db12c9e..10d302f 100644 --- a/Project.toml +++ b/Project.toml @@ -11,12 +11,17 @@ Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" ValSplit = "0625e100-946b-11ec-09cd-6328dd093154" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" [extensions] +QrochetChainRulesCoreExt = "ChainRulesCore" +QrochetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"] QrochetQuacExt = "Quac" [compat] +ChainRulesCore = "1.0" Muscle = "0.1" Quac = "0.3" Tenet = "0.5" diff --git a/ext/QrochetChainRulesCoreExt.jl b/ext/QrochetChainRulesCoreExt.jl new file mode 100644 index 0000000..b011d76 --- /dev/null +++ b/ext/QrochetChainRulesCoreExt.jl @@ -0,0 +1,27 @@ +module QrochetChainRulesCoreExt + +using Qrochet +using ChainRulesCore +using ChainRulesCore: AbstractTangent +using Tenet + +@non_differentiable Qrochet.currindex() +@non_differentiable Qrochet.nextindex() + +# WARN type-piracy +@non_differentiable Base.setdiff(::Vector{Symbol}, ::Base.ValueIterator) + +ChainRulesCore.ProjectTo(qtn::Quantum) = ProjectTo{Quantum}(; tn = ProjectTo(qtn.tn)) +(projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ.tn), Δ.sites) + +function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites) + y = Quantum(x, sites) + ẏ = Tangent{Quantum}(; tn = ẋ) + y, ẏ +end + +Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback + +end diff --git a/ext/QrochetChainRulesTestUtilsExt.jl b/ext/QrochetChainRulesTestUtilsExt.jl new file mode 100644 index 0000000..8d9c14f --- /dev/null +++ b/ext/QrochetChainRulesTestUtilsExt.jl @@ -0,0 +1,16 @@ +module QrochetChainRulesTestUtilsExt + +using Qrochet +using ChainRulesCore +using ChainRulesTestUtils +using Random + +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum) + return Tangent{Quantum}(; tn = rand_tangent(rng, x.tn), sites = NoTangent()) +end + +# WARN type-piracy +# NOTE used in `Quantum` constructor +ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent() + +end diff --git a/test/Project.toml b/test/Project.toml index dfed1ae..ff50946 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" diff --git a/test/integration/ChainRulesCore_test.jl b/test/integration/ChainRulesCore_test.jl new file mode 100644 index 0000000..dfe68fe --- /dev/null +++ b/test/integration/ChainRulesCore_test.jl @@ -0,0 +1,10 @@ +@testset "ChainRulesCore" begin + using Qrochet + using Tenet + using ChainRulesTestUtils + + @testset "Quantum" begin + test_frule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + test_rrule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 021ef25..8a13796 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ end @testset "Integration tests" verbose = true begin include("integration/Quac_test.jl") + include("integration/ChainRulesCore_test.jl") end if haskey(ENV, "ENABLE_AQUA_TESTS")