Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Add chain-rules for Quantum constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 5, 2024
1 parent c4bb805 commit a6a7dc6
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 1 deletion.
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@ Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec"
ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"

[extensions]
QrochetChainRulesCoreExt = "ChainRulesCore"
QrochetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"]
QrochetQuacExt = "Quac"

[compat]
ChainRulesCore = "1.0"
Muscle = "0.1"
Quac = "0.3"
Tenet = "0.5"
Expand Down
27 changes: 27 additions & 0 deletions ext/QrochetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
module QrochetChainRulesCoreExt

using Qrochet
using ChainRulesCore
using ChainRulesCore: AbstractTangent
using Tenet

@non_differentiable Qrochet.currindex()
@non_differentiable Qrochet.nextindex()

# WARN type-piracy
@non_differentiable Base.setdiff(::Vector{Symbol}, ::Base.ValueIterator)

ChainRulesCore.ProjectTo(qtn::Quantum) = ProjectTo{Quantum}(; tn = ProjectTo(qtn.tn))
(projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn.tn), Δ.sites)

function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
y = Quantum(x, sites)
= Tangent{Quantum}(; tn = ẋ)
y, ẏ
end

Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

end
16 changes: 16 additions & 0 deletions ext/QrochetChainRulesTestUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module QrochetChainRulesTestUtilsExt

using Qrochet
using ChainRulesCore
using ChainRulesTestUtils
using Random

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum)
return Tangent{Quantum}(; tn = rand_tangent(rng, x.tn), sites = NoTangent())
end

# WARN type-piracy
# NOTE used in `Quantum` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent()

end
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
10 changes: 10 additions & 0 deletions test/integration/ChainRulesCore_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@testset "ChainRulesCore" begin
using Qrochet
using Tenet
using ChainRulesTestUtils

@testset "Quantum" begin
test_frule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i))
test_rrule(Quantum, TensorNetwork([Tensor(fill(1.0, 2), [:i])]), Dict{Site,Symbol}(site"1" => :i))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ end

@testset "Integration tests" verbose = true begin
include("integration/Quac_test.jl")
include("integration/ChainRulesCore_test.jl")
end

if haskey(ENV, "ENABLE_AQUA_TESTS")
Expand Down

0 comments on commit a6a7dc6

Please sign in to comment.