Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multi-threading using OhMyThreads and make it differentiable #70

Open
wants to merge 16 commits into
base: master
Choose a base branch
from

Conversation

pbrehmer
Copy link
Collaborator

@pbrehmer pbrehmer commented Sep 27, 2024

Here we'll replace the @fwdthreads macro with tmap and foreach calls. Additionally, we will code up reverse rules such that the backwards pass also runs in parallel.

@pbrehmer
Copy link
Collaborator Author

@lkdvos I could use some help on this one... I still find writing rrules very confusing sometimes, so could you maybe take a look at the rrule for dtforeach, once you have time? And while the dtmap rrule already seems to work, it might not be perfect yet :-)

Other than that we have to think about how to pass along the threading kwargs to the dtmap and dtforeach calls. One option is to store those kwargs inside the CTMRG struct but that feels a bit wrong. (Also there are some calls which wouldn't have access to a CTMRG instance.) Not sure what would be the best solution for that. Would global variables work where the user can set that somehow?

I'll also add some rrule tests in the end.

@lkdvos
Copy link
Member

lkdvos commented Sep 30, 2024

I think tforeach is going to be a bit hard, because it has no outputs, and is thus necessarily in-place... I have no clue how the zygote buffer magic works, so I can't really say I know how to deal with that either.

I should have some more time to think this through next week though!

For the global variables, I would maybe suggest ScopedVariables.jl instead, this is a little more flexible and shouldn't incur too much runtime costs

@pbrehmer
Copy link
Collaborator Author

pbrehmer commented Oct 1, 2024

I think tforeach is going to be a bit hard, because it has no outputs, and is thus necessarily in-place... I have no clue how the zygote buffer magic works, so I can't really say I know how to deal with that either.

How about we just stick to tmap then? The only slightly annoying thing is having to separate multiple return values at different indices but that should not incur too much overhead.

For the global variables, I would maybe suggest ScopedVariables.jl instead, this is a little more flexible and shouldn't incur too much runtime costs

Wasn't aware of ScopedValues.jl yet, that looks like a great solution. But I don't quite understand the necessity of a scoped value here since we never need to access the threading settings inside a multi-threaded map, right? In any case, we can probably just have a global scoped Dict with the threading settings that are passed on to the dtmap calls, and that can be mutated by some set_thread_settings function.

Anyways, I will give these things a go and then we can review next week, when you have time :)

Copy link
Member

@lkdvos lkdvos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I should have elaborated a bit more on what I had in mind with the scoped values: I would keep the threading strategies in scoped values, such that they can always be accessed as if they are global values. (this prevents them from bloating all of our algorithms)
The benefit of having them as a scoped value however means that users could still change them by calling the peps function from within a scope with a modified scoped value, thus changing the scheduler.

src/utility/diffable_threads.jl Outdated Show resolved Hide resolved
backevals = tmap(CartesianIndices(A); kwargs...) do idx
last(el_rrules[idx])(dy[idx])
end
df = ProjectTo(f)(sum(first, backevals))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somehow it should be possible to do a tmapreduce and combine all of this, but I'll try and look into it a bit more

src/utility/diffable_threads.jl Outdated Show resolved Hide resolved
@pbrehmer
Copy link
Collaborator Author

I was trying to fix the Zygote error but with no luck, I really don't know how to handle the NoTangents (dA is an Array{NoTangent,3}) as they are being converted to Nothing. I'll copy the error here for future reference:

ERROR: MethodError: no method matching length(::Nothing)
The function `length` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  length(::Zygote.Grads, Any...; kwargs...)
   @ Zygote ~/.julia/packages/MacroTools/Cf2ok/src/examples/forward.jl:17
  length(::Combinatorics.Partition)
   @ Combinatorics ~/.julia/packages/Combinatorics/Udg6X/src/youngdiagrams.jl:8
  length(::Base.MethodSpecializations)
   @ Base reflection.jl:1317
  ...

Stacktrace:
  [1] productfunc(xs::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}, dy::Array{Nothing, 3})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/lib/array.jl:278
  [2] (::Zygote.var"#collect_product_pullback#744"{Base.Iterators.ProductIterator{Tuple{…}}})(dy::Array{Nothing, 3})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/lib/array.jl:300
  [3] (::Zygote.var"#2941#back#745"{Zygote.var"#collect_product_pullback#744"{…}})(Δ::Array{Nothing, 3})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [4] ctmrg_renormalize
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:383 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
  [6] ctmrg_iter
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:163 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CTMRGEnv{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
  [8] #221
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:118 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [10] #35
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:117 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#ad_pullback#61"{Tuple{…}, Zygote.Pullback{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/chainrules.jl:264
 [13] with_logger_pullback
    @ ~/.julia/packages/ChainRules/vdf7M/src/rulesets/Base/CoreLogging.jl:12 [inlined]
 [14] (::Zygote.ZBack{ChainRules.var"#with_logger_pullback#862"{…}})(dy::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/chainrules.jl:212
 [15] #withlevel#34
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:113 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [17] withlevel
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:107 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [19] withlevel
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:107 [inlined]
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [21] leading_boundary
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:115 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [23] #37
    @ ./REPL[24]:2 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:91
 [26] withgradient(f::Function, args::InfinitePEPS{TrivialTensorMap{ComplexSpace, 1, 4, Matrix{ComplexF64}}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:213
 [27] top-level scope
    @ REPL[24]:1

@lkdvos
Copy link
Member

lkdvos commented Oct 29, 2024

I'm honestly not so sure how no one has ever run into this, but I seem to have been able to circumvent the issue by just not differentiating through the collect(Iterators.product())calls. It might have something to do with our tmapgradient returning an array of nothing instead of simply nothing, but I don't really have the time to investigate that and this seems to work.

Copy link

codecov bot commented Oct 30, 2024

Codecov Report

Attention: Patch coverage is 93.54839% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/utility/diffable_threads.jl 83.33% 3 Missing ⚠️
src/operators/infinitepepo.jl 0.00% 1 Missing ⚠️
Files with missing lines Coverage Δ
src/PEPSKit.jl 100.00% <ø> (ø)
src/algorithms/ctmrg/ctmrg.jl 92.39% <100.00%> (+0.99%) ⬆️
src/algorithms/ctmrg/gaugefix.jl 94.87% <100.00%> (+0.13%) ⬆️
src/algorithms/toolbox.jl 98.33% <100.00%> (+0.02%) ⬆️
src/environments/ctmrg_environments.jl 72.29% <100.00%> (+1.17%) ⬆️
src/states/infinitepeps.jl 67.77% <100.00%> (+1.49%) ⬆️
src/utility/util.jl 54.21% <100.00%> (-1.60%) ⬇️
src/operators/infinitepepo.jl 18.86% <0.00%> (ø)
src/utility/diffable_threads.jl 83.33% <83.33%> (ø)

... and 1 file with indirect coverage changes

[skip ci]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants