Runtime dispatch overhead for small tensor contractions using the StridedBLAS backend #189
Comments
So I keep getting sucked into the lowest levels of the code, never having time to properly study PEPSKit.jl 😄.
Certainly, for …
I assume the actual coding pattern does not include a function …
Ok, it seems adding …

An alternative would be to explicitly encode the conjugation flag in the type domain, e.g. by having a parametric singleton struct …

With Cthulhu.jl, I couldn't spot an additional ambiguity/dynamic dispatch in …
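As a sketch of what encoding the flag in the type domain could look like (the names here are made up, not an actual Strided/TensorOperations API):

```julia
# hypothetical parametric singleton carrying the conjugation flag in the type
struct ConjFlag{C} end

# dispatch on the type parameter, so each flag value has a concrete return type
# instead of a Union of two StridedView types
apply_flag(::ConjFlag{false}, A) = A
apply_flag(::ConjFlag{true}, A) = conj(A)

# usage: apply_flag(ConjFlag{true}(), A)
```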
We are already doing this for the …
No, that doesn't work. I also tried only annotating the …

Another nonbreaking solution is to manually split the code based on the flag values, but there are quite a few instances where …
To be fair, if just splitting the …
Ok, I'll take a stab at it asap.
No, in the realistic setting everything is passed through locally and all of the concrete array types are properly inferred in the call to …
After the branch splitting with an if statement, the ambiguity in …
I think the dynamic dispatch in the mapreduce lowering chain comes from the fact that in …

But then in finally selecting …
Ok, so I am not sure how reliable this is; this is from simply timing the first call (so timing compilation) and then benchmarking the call on my computer, both including the fix from the PR, but the first one with the vanilla version of Strided.jl:

```julia
julia> @time local_update!(psi0);
 19.850534 seconds (34.72 M allocations: 2.236 GiB, 4.29% gc time, 100.00% compilation time)

julia> @benchmark local_update!($psi0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  30.459 μs … 123.357 ms  ┊ GC (min … max):  0.00% … 99.86%
 Time  (median):     41.032 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   71.495 μs ±   1.255 ms  ┊ GC (mean ± σ):  32.50% ±  4.91%

  ▄▇▄▂▂▅█▅▁▂▂▄▃ ▁▂▂▂▃▂▂▂▂▂▁ ▂
  ██████████████▇▆▇▆▇▇▆▆▇▆▄▄▆▅▄▆▅▆▇██████████████▇▇▇▆▆▆▅▄▄▅▁▃▅ █
  30.5 μs        Histogram: log(frequency) by time        130 μs <

 Memory estimate: 135.19 KiB, allocs estimate: 59.
```

and then with a modified Strided.jl where the relevant call chain is fully specialised:

```julia
julia> @time local_update!(psi0);
 36.712724 seconds (27.60 M allocations: 1.740 GiB, 1.54% gc time, 100.00% compilation time)

julia> @benchmark local_update!($psi0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  26.941 μs … 122.038 ms  ┊ GC (min … max):  0.00% … 99.86%
 Time  (median):     38.208 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   69.025 μs ±   1.243 ms  ┊ GC (mean ± σ):  33.64% ±  4.82%

  ▃▆▇▃▂▄█▇▃▁▂▄▂ ▁▁▂▂▂▂▂▂▂▁▁▁ ▂
  ██████████████▇▇█▇▅▆▆▇▆▆▅▅▄▃▂▄▄▅▆█████████████████▇▇▆▅▆▄▅▅▄▄ █
  26.9 μs        Histogram: log(frequency) by time        127 μs <

 Memory estimate: 131.98 KiB, allocs estimate: 15.
```

So allocations and runtime are a little bit lower, but compilation time is up by a factor of almost two.
Thanks for figuring out the …

I can anyway make the change locally now that I understand what's going on, so as far as I'm concerned my problems are solved :)
Ok, let us know if you encounter cases where specialising the methods in Strided does actually lead to a noticeable speedup. I will close this and merge the PR.
We could consider putting the extra compilation time into a precompile call in TensorOperations: we keep the non-specialised generic implementation in Strided.jl, but make a specialised method in TensorOperations, which we precompile there. I do feel like this might be worthwhile, as in TensorKit, with the symmetries, we call these methods quite often, so it might actually add up. We should probably test this first, however, as this is quite a bit of extra complexity to maintain.
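A rough sketch of what the precompile side could look like, assuming PrecompileTools is used for the workload (the element types and sizes are just placeholders):

```julia
# hypothetical precompile workload on the TensorOperations side (sketch only)
using TensorOperations, PrecompileTools

@setup_workload begin
    A = randn(ComplexF64, 2, 2, 2)
    B = randn(ComplexF64, 2, 2, 2)
    @compile_workload begin
        # cover both the conjugated and non-conjugated code paths
        @tensor C1[a, d] := conj(A[a, b, c]) * B[d, b, c]
        @tensor C2[a, d] := A[a, b, c] * B[d, b, c]
    end
end
```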
Wouldn't this de facto mean a decoupling of TensorOperations from Strided? I don't really see how to have a separate call stack for TensorOperations without duplicating most of the Strided.jl code and simply adding the specialisations.
I guess so, but it would also mean that not everyone necessarily has to pay the compilation time price... I'm not entirely sure how much of the call stack would need to be duplicated, but for these low-level features this might just be worth it. In a perfect world it should really be possible to have non-allocating usage of …

As an aside, just playing games with the compiler here: what happens if you add a …
I guess that would mean that dynamic dispatch is happening in the hot loop of the kernel? Since it is the function that is called in the loop body itself that is unknown, whereas the variables in the loop are all inferred, it might be that the dynamic dispatch of selecting the actual function is hoisted out of the loop. Just speculating though.

There used to be …
Also, there are other small allocations in Strided.jl, which have to do with how to block the problem or how to distribute it across threads.
But I guess for problems of this size the multithreading does not yet kick in?
No, and that's anyway only if there is more than one thread. I am not sure if the extra allocations are only there for multithreading though; I should check that. There might also be some for other parts of the blocking and loop reordering, even though I tried to make everything based on type-stable tuples.
@Jutho Was the thing you tried literally just to add type parameters on those function arguments down the `_mapreduce` call chain? If so, I guess we can reproduce that easily and see if it helps our case. Might it be enough to only introduce the specialization on the methods that actually call the functions? Wait, I see you mention this above. I guess this is currently what happens and you get dynamic dispatch when calling the kernel.
I think the overhead is kind of important, as it prevents Strided's version of permutedims from being a general replacement. I didn't use it in …
@amilsted and @leburgel: yes, getting rid of the runtime dynamic dispatch was obtained by fully specializing the …

However, I am not sure if the overhead you see is from that, or simply from the kind of computations that Strided does to find a good loop order and blocking scheme. That overhead can be non-negligible in the case of small systems. Maybe it pays off to have a specific optimized implementation for tensor permutation. I'd certainly be interested in collaborating on that.
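For reference, the kind of change being discussed, shown on a toy passthrough rather than the actual `_mapreduce` call chain (all names here are made up):

```julia
# kernel that actually calls `f`
function _kernel!(f, dst, src)
    for i in eachindex(dst, src)
        dst[i] = f(src[i])
    end
    return dst
end

# `f` is only passed through here; by Julia's specialization heuristics this
# method is not specialized on `typeof(f)`, so the inner call can become a
# runtime dispatch
apply!(f, dst, src) = _kernel!(f, dst, src)

# adding an explicit type parameter forces specialization on `typeof(f)`,
# removing the runtime dispatch at the cost of one compiled instance per `f`
apply_specialized!(f::F, dst, src) where {F} = _kernel!(f, dst, src)
```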
Do we know how something like LoopVectorization.jl stacks up? From what I know, that should also attempt to find some loop order and blocking scheme and unrolling, etc., which might be important for these kinds of systems...
Last time I checked (which is a while ago), LoopVectorization was great at loop ordering, unrolling and inserting vectorised operations, but it did not do blocking; Octavian does this separately, I think. The other major issue was that LoopVectorization did not support composite types such as …

Finally, LoopVectorization would lead to a separate compilation stage for every new permutation. I don't know if we want to go that way? Maybe that's fine for a custom backend, but not for the default one.

The other thing to revive is HPTT; I once wrote the jll package for that using BinaryBuilder, so it shouldn't be much work to write the wrappers to make an …
Also, on the Strided side, maybe there is a way to "cache" the blocking scheme etc., but I am not really sure if that is worth it.
And of course we should compare to tblis as well; I don't think I benchmarked that for smaller systems (and I think it leverages hptt under the hood). Let me revive https://github.com/Jutho/TensorOperations.jl/tree/ld/benchmark as well at some point, to more rigorously benchmark all implementations and changes. I have a set of permutations and contractions there, but maybe @amilsted or @leburgel, if you could supply some typical contractions and/or array sizes, we can tailor our efforts a bit more towards that.
Is it worth considering switching to Julia's own …
You should be able to try this out already, the …
@leburgel, looks like BaseCopy would be the one to try.
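Something along these lines should do it; the `backend` keyword of `@tensor` and the `BaseCopy()` backend object are assumed here from recent TensorOperations versions, so the exact spelling is worth double-checking against the docs:

```julia
using TensorOperations

A = randn(ComplexF64, 4, 4, 4)
B = randn(ComplexF64, 4, 4, 4)

# assumed keyword syntax for explicitly selecting the Base-based backend
@tensor backend = TensorOperations.BaseCopy() C[a, d] := conj(A[a, b, c]) * B[d, b, c]
```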
I'll give it a go!
The …

I often benchmarked a number of randomly generated permutations which were run just once, and then of course such approaches fail miserably. But I guess that is not a very realistic benchmark.
When profiling some code involving many contractions with small tensors, I noticed that there is a lot of overhead due to runtime dispatch, and the resulting allocations and garbage collection, when using the `StridedBLAS` backend. I've figured out some of it, but I thought it would be good to report here to see if something more can be done. I ran a profile for a dummy example contraction of small complex tensors which can sort of reproduce the issue.
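A minimal sketch of such a dummy contraction (the sizes, labels and function name here are made up, not the original snippet):

```julia
using TensorOperations, BenchmarkTools, Profile

# small complex tensors: at these sizes the dispatch overhead dominates
A = randn(ComplexF64, 4, 4, 4)
B = randn(ComplexF64, 4, 4, 4)

f(A, B) = @tensor C[a, d] := conj(A[a, b, c]) * B[d, b, c]

f(A, B)                                   # compile first
@profile for _ in 1:100_000; f(A, B); end # runtime dispatch shows up in the profile
Profile.print()

@benchmark f($A, $B)
```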
This shows a lot of runtime dispatch overhead in `TensorOperations.tensoradd!` and `Strided._mapreduce_order!`. The effect becomes negligible for large tensor dimensions, but it turns out to be a real pain if there are a lot of these small contractions being performed.

For `TensorOperations.tensoradd!`, I managed to track the problem to an ambiguity caused by patterns like `flag2op(conjA)(A)`, where the return type at compile time can be a `Union` of two concrete `StridedView` types, with `typeof(identity)` and `typeof(conj)` as their `op` field types respectively. This leads to an issue in TensorOperations.jl/src/implementation/strided.jl (lines 99 to 103 in c1e37ec), where the last argument in the call to `Strided._mapreducedim!` is a `Tuple` with mixed concrete and abstract (the union described above) types in its type parameters, which seems to mess things up. At that level I managed to fix things by just splitting into two branches to de-confuse the compiler, sketched below.
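Schematically (a toy illustration rather than the actual diff), the split looks like this:

```julia
using Strided

# toy version: the branch split gives each arm a single concrete StridedView
# type, instead of the Union produced by a pattern like `flag2op(conjA)(A)`
function _add_maybe_conj!(C::AbstractArray, A::AbstractArray, conjA::Bool)
    Cs = StridedView(C)
    if conjA
        Cs .+= conj(StridedView(A))   # op field is typeof(conj)
    else
        Cs .+= StridedView(A)         # op field is typeof(identity)
    end
    return C
end
```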
This gets rid of the runtime dispatch in `tensoradd!` and already makes a big difference for small contractions. I haven't descended all the way down into `Strided._mapreduce_dim!`, so I don't know what the issue is there.

So in the end my questions are:
1. Whether there is a better way or pattern of handling the `conj` flags and their effect on `StridedView`s that could avoid these kinds of type ambiguities.
2. Whether the runtime dispatch in `Strided._mapreduce_dim!` could be prevented by changing the `StridedBLAS` backend implementation here, and if not, whether it could be avoided altogether by a reasonable change to Strided.jl.