Skip to content

Commit

Permalink
Support Adapt conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 20, 2024
1 parent 005bc28 commit 470b4e2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Expand All @@ -26,6 +27,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"

[extensions]
TenetAdaptExt = "Adapt"
TenetChainRulesCoreExt = "ChainRulesCore"
TenetChainRulesExt = "ChainRules"
TenetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"]
Expand All @@ -35,6 +37,7 @@ TenetMakieExt = "Makie"

[compat]
AbstractTrees = "0.4"
Adapt = "4"
ChainRules = "1.0"
ChainRulesCore = "1.0"
Combinatorics = "1.0"
Expand Down
10 changes: 10 additions & 0 deletions ext/TenetAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module TenetAdaptExt

using Tenet
using Adapt

Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x))

Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x)))

end

0 comments on commit 470b4e2

Please sign in to comment.