-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement multi-threading using OhMyThreads and make it differentiable #70
base: master
Are you sure you want to change the base?
Conversation
@lkdvos I could use some help on this one... I still find writing rrules very confusing sometimes, so could you maybe take a look at the rrule for Other than that we have to think about how to pass along the threading kwargs to the I'll also add some rrule tests in the end. |
I think tforeach is going to be a bit hard, because it has no outputs, and is thus necessarily in-place... I have no clue how the zygote buffer magic works, so I can't really say I know how to deal with that either. I should have some more time to think this through next week though! For the global variables, I would maybe suggest ScopedVariables.jl instead, this is a little more flexible and shouldn't incur too much runtime costs |
How about we just stick to
Wasn't aware of ScopedValues.jl yet, that looks like a great solution. But I don't quite understand the necessity of a scoped value here since we never need to access the threading settings inside a multi-threaded map, right? In any case, we can probably just have a global scoped Anyways, I will give these things a go and then we can review next week, when you have time :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I should have elaborated a bit more on what I had in mind with the scoped values: I would keep the threading strategies in scoped values, such that they can always be accessed as if they are global values. (this prevents them from bloating all of our algorithms)
The benefit of having them as a scoped value however means that users could still change them by calling the peps function from within a scope with a modified scoped value, thus changing the scheduler.
src/utility/diffable_threads.jl
Outdated
backevals = tmap(CartesianIndices(A); kwargs...) do idx | ||
last(el_rrules[idx])(dy[idx]) | ||
end | ||
df = ProjectTo(f)(sum(first, backevals)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Somehow it should be possible to do a tmapreduce
and combine all of this, but I'll try and look into it a bit more
…Defaults docstring
I was trying to fix the Zygote error but with no luck, I really don't know how to handle the
|
I'm honestly not so sure how no one has ever run into this, but I seem to have been able to circumvent the issue by just not differentiating through the |
cbc1fe7
to
f3d7192
Compare
f3d7192
to
cab6c69
Compare
Codecov ReportAttention: Patch coverage is
|
Here we'll replace the
@fwdthreads
macro withtmap
andforeach
calls. Additionally, we will code up reverse rules such that the backwards pass also runs in parallel.