
Runtime dispatch overhead for small tensor contractions using the StridedBLAS backend #189

Closed
leburgel opened this issue Oct 2, 2024 · 33 comments



leburgel commented Oct 2, 2024

When profiling some code involving many contractions with small tensors I noticed that there is a lot of overhead due to runtime dispatch and resulting allocations and garbage collection when using the StridedBLAS backend. I've figured out some of it but I thought it would be good to report here to see if something more can be done.

I ran a profile for a dummy example contraction of small complex tensors which can sort of reproduce the issue:

using TensorOperations

T = ComplexF64
L = randn(T, 8, 8, 7)
R = randn(T, 8, 8, 7)
O = randn(T, 4, 4, 7, 7)

function local_update!(psi::Array{T,3})::Array{T,3} where {T}
    @tensor begin
        psi[-1, -2, -3] =
            psi[1, 3, 5] *
            L[-1, 1, 2] *
            O[-2, 3, 2, 4] *
            R[-3, 5, 4]
    end
    return psi
end

psi0 = randn(T, 8, 4, 8)
@profview begin  # @profview is provided by ProfileView.jl or the VS Code Julia extension
    for _ in 1:100000
        local_update!(psi0)
    end
end

which gives:
[profile flamegraph: small_contract_profile]

So a lot of runtime dispatch overhead in TensorOperations.tensoradd! and Strided._mapreduce_order!. The effect becomes negligible for large tensor dimensions, but it turns out to be a real pain if there are a lot of these small contractions being performed.

For TensorOperations.tensoradd!, I managed to track the problem down to an ambiguity caused by patterns like flag2op(conjA)(A), where the return type inferred at compile time is a Union of two concrete StridedView types, with typeof(identity) and typeof(conj) as their respective op field types. This leads to an issue here:

A′ = permutedims(flag2op(conjA)(A), linearize(pA))
op1 = Base.Fix2(scale, α)
op2 = Base.Fix2(scale, β)
Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
return C

where the last argument in the call to Strided._mapreducedim! is a Tuple with mixed concrete and abstract (the union described above) types in its type parameters, which seems to mess things up.
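For reference, the instability can be distilled into something like the following sketch (flag2op_sketch is a hypothetical stand-in for the real flag2op, whose flag value is only known at runtime):

flag2op_sketch(conjflag::Bool) = conjflag ? conj : identity

# The inferred return type is a Union of two function types, which then propagates
# into the op field of the StridedView produced by flag2op(conjA)(A):
Base.return_types(flag2op_sketch, (Bool,))  # returns [Union{typeof(conj), typeof(identity)}]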

At that level I managed to fix things by just splitting into two branches to de-confuse the compiler:

opA = flag2op(conjA)
if opA isa typeof(identity)
    A′ = permutedims(identity(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
elseif opA isa typeof(conj)
    A′ = permutedims(conj(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
end

which gets rid of the runtime dispatch in tensoradd! and already makes a big difference for small contractions.

I haven't descended all the way down into Strided._mapreduce_dim!, so I don't know what the issue is there.

So in the end my questions are:

  • Is there a cleaner way to deal with the conj flags and their effect on StridedViews that could avoid these kinds of type ambiguities?
  • Do you have any idea if the runtime dispatch in Strided._mapreduce_dim! could be prevented by changing the StridedBLAS backend implementation here, and if not, if it could be avoided altogether by a reasonable change to Strided.jl?

Jutho commented Oct 3, 2024

So I keep getting sucked into the lowest levels of the code, never having time to properly study PEPSKit.jl 😄 .


Jutho commented Oct 3, 2024

Certainly, for conjA, the hope was that constant propagation would resolve the ambiguity, but apparently it does not. Maybe we need some more aggressive constant propagation.


Jutho commented Oct 3, 2024

I assume the actual coding pattern does not include a function local_update! that captures global variables, right?


Jutho commented Oct 3, 2024

Ok, it seems adding @constprop :aggressive before every tensorcontract! and tensoradd! definition solves the problem, but it is of course not particularly nice.
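For concreteness, the annotation in question would look something like this (the function below is a made-up stand-in, not the actual tensoradd! signature):

Base.@constprop :aggressive function myadd!(C, A, conjA::Bool, α, β)
    # with aggressive constant propagation, a literal true/false for conjA
    # lets the compiler resolve op to a single concrete function
    op = conjA ? conj : identity
    C .= β .* C .+ α .* op.(A)
    return C
end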

An alternative would be to explicitly encode the conjugation flag in the type domain, e.g. by having a parametric singleton struct with instances ConjugationFlag{true}() and ConjugationFlag{false}(). I am interested in hearing the opinion of @lkdvos .
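A minimal sketch of that idea (ConjugationFlag is hypothetical, mirroring how the istemp flag already uses Val):

struct ConjugationFlag{C} end

flag2op(::ConjugationFlag{false}) = identity
flag2op(::ConjugationFlag{true}) = conj

# flag2op(ConjugationFlag{true}())(A) now infers to a single concrete type,
# since the op is encoded in the argument's type rather than in a runtime value.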

With Cthulhu.jl, I couldn't spot an additional ambiguity/dynamic dispatch in mapreduce_order!, so I am not sure if those two are related or if that is yet another issue.


lkdvos commented Oct 3, 2024

We are already doing this for the istemp flag with a Val{true}, so in principle I have nothing against this change. It's a bit unfortunate that this has to be breaking, so we might want to investigate a bit more...
Do you know if we can do const propagation for single arguments?


Jutho commented Oct 3, 2024

No, that doesn't work. I also tried only annotating the flag2op function, but that doesn't work either. Maybe there is some in-between balance where you only need to annotate some of the tensor operation definitions, but that can only be found by trial and error, I think.

Another non-breaking solution is to manually split the code based on the flag values, but there are quite a few places where flag2op is called, and in the contraction there are two of them, so one then has four different cases.


lkdvos commented Oct 3, 2024

To be fair, if just splitting the flag2op call into if statements works, that might be the cleanest. I do see how that function is inherently type unstable, so either this information needs to be in the type domain or we do a manual union split, which I had hoped the compiler would figure out by itself, but apparently it doesn't.


Jutho commented Oct 3, 2024

Ok, I'll take a stab at it asap.


leburgel commented Oct 3, 2024

I assume the actual coding pattern does not include a function local_update! that captures global variables, right?

No, in the realistic setting everything is passed through locally and all of the concrete array types are properly inferred in the call to tensorcontract!. I wasn't really paying attention to the globals in the example, but it runs into the same problem lower down.


leburgel commented Oct 3, 2024

With Cthulhu.jl, I couldn't spot an additional ambiguity/dynamic dispatch in mapreduce_order!, so I am not sure if those two are related or if that is yet another issue.

After the branch splitting with an if statement the ambiguity in tensoradd! is completely removed, but the runtime dispatch in mapreduce_order! remains. So that might be another issue, but I couldn't really figure it out.


Jutho commented Oct 3, 2024

I think the dynamic dispatch in the mapreduce lowering chain comes from the fact that in _mapreduce_fuse! and _mapreduce_block!, the compiler doesn't specialise on the function arguments f, op and initop. I think this is known behaviour of the Julia compiler to reduce compile times: if you want to specialise on function arguments, you should actually make them type parameters of your method. In this case, the non-specialisation is in principle fine, as it has no effect on any of the types of the variables in those function bodies, so everything is perfectly type stable.

But then, when finally selecting _mapreduce_kernel, which is a generated function, it does again specialise on f etc., and so there is some dynamic dispatch going on. Not sure what the impact of that is.
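A minimal sketch of that specialisation rule (hypothetical function names):

inner(f, xs) = map(f, xs)

forward_default(f, xs) = inner(f, xs)              # f is only passed on, so no specialisation on typeof(f)
forward_forced(f::F, xs) where {F} = inner(f, xs)  # the type parameter forces specialisation on typeof(f)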


Jutho commented Oct 3, 2024

Ok, I am not sure how reliable this is; it comes from simply timing the first call (so timing compilation) and then benchmarking the call on my computer. Both runs include the fix from the PR, but the first one uses the vanilla version of Strided.jl:

julia> @time local_update!(psi0);
 19.850534 seconds (34.72 M allocations: 2.236 GiB, 4.29% gc time, 100.00% compilation time)

julia> @benchmark local_update!($psi0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  30.459 μs … 123.357 ms  ┊ GC (min … max):  0.00% … 99.86%
 Time  (median):     41.032 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   71.495 μs ±   1.255 ms  ┊ GC (mean ± σ):  32.50% ±  4.91%

  ▄▇▄▂▂▅█▅▁▂▂▄▃                     ▁▂▂▂▃▂▂▂▂▂▁                ▂
  ██████████████▇▆▇▆▇▇▆▆▇▆▄▄▆▅▄▆▅▆▇██████████████▇▇▇▆▆▆▅▄▄▅▁▃▅ █
  30.5 μs       Histogram: log(frequency) by time       130 μs <

 Memory estimate: 135.19 KiB, allocs estimate: 59.

and then with a modified Strided.jl where f::F1, op::F2 and initop::F3 are now explicit type parameters of the relevant methods:

julia> @time local_update!(psi0);
 36.712724 seconds (27.60 M allocations: 1.740 GiB, 1.54% gc time, 100.00% compilation time)

julia> @benchmark local_update!($psi0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  26.941 μs … 122.038 ms  ┊ GC (min … max):  0.00% … 99.86%
 Time  (median):     38.208 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   69.025 μs ±   1.243 ms  ┊ GC (mean ± σ):  33.64% ±  4.82%

  ▃▆▇▃▂▄█▇▃▁▂▄▂                     ▁▁▂▂▂▂▂▂▂▁▁▁               ▂
  ██████████████▇▇█▇▅▆▆▇▆▆▅▅▄▃▂▄▄▅▆█████████████████▇▇▆▅▆▄▅▅▄▄ █
  26.9 μs       Histogram: log(frequency) by time       127 μs <

 Memory estimate: 131.98 KiB, allocs estimate: 15.

So allocations and runtime are a little bit lower, but compilation time is up by a factor of almost two.


leburgel commented Oct 4, 2024

Thanks for figuring out the mapreduce issue! It is a small effect for a lot more compilation time, so I'm not sure whether you think it's worth actually making the change in Strided.jl?

I can anyway make the change locally now that I understand what's going on, so as far as I'm concerned my problems are solved :)


Jutho commented Oct 4, 2024

Ok, let us know if you encounter cases where specialising the methods in Strided does actually lead to a noticeable speedup. I will close this and merge the PR.

Jutho closed this as completed on Oct 4, 2024

lkdvos commented Oct 4, 2024

We could consider putting the extra compilation time into a precompile call in TensorOperations: we keep the non-specialised generic implementation in Strided.jl, but make a specialised method in TensorOperations, which we precompile there. I do feel like this might be worthwhile: in TensorKit, with the symmetries, we call these methods quite often, so it might actually add up.

We should probably test this first however, as this is quite a bit of extra complexity to maintain.
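A hedged sketch of what that could look like from the TensorOperations side, using PrecompileTools.jl (the workload below is an arbitrary example, not the actual set of methods that would need to be covered):

using TensorOperations, PrecompileTools

@setup_workload begin
    A = randn(ComplexF64, 2, 2, 2)
    B = randn(ComplexF64, 2, 2, 2)
    @compile_workload begin
        # contractions executed here get compiled (and cached) at precompile time
        @tensor C[a, c] := A[a, b, x] * B[x, b, c]
    end
end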


Jutho commented Oct 4, 2024

Wouldn't this de facto mean a decoupling of TensorOperations from Strided? I don't really see how to have a separate call stack for TensorOperations without duplicating most of the Strided.jl code and simply adding the specialisations.


lkdvos commented Oct 4, 2024

I guess so, but it would also mean that not everyone necessarily has to pay the compilation time price... I'm not entirely sure how much of the call stack would need to be duplicated, but for these low-level features this might just be worth it.
Of course, it's also an option to just keep the compile time in Strided, precompile the specific TensorOperations kernels to mitigate it, and then see if it becomes a hindrance for other Strided usage?

In a perfect world it should really be possible to have non-allocating usage of @tensor with either preallocated temporaries or Bumper.jl, which is currently not the case. Keeping in mind that this function is called once for every subblock associated with a fusion tree, this can easily add up to quite a large amount of overhead. We should probably try out some Hubbard-type symmetry, but it's not urgent; I can add it to the to-do list :)

As an aside, just playing games with the compiler here: what happens if you add a @nospecialize in the kernel as well? I guess that would get rid of the type instability but hurt performance too much? Does that even work?


Jutho commented Oct 4, 2024

I guess that would mean that dynamic dispatch is happening in the hot loop of the kernel? Since it is the function that is called in the loop body itself that is unknown, whereas the variables in the loop are all inferred, it might be that the dynamic dispatch of selecting the actual function is hoisted out of the loop. Just speculating though.

There used to be @nospecialize annotations before the function arguments in the _mapreduce methods (not sure if that included the kernel), and at some point I removed them and was happy to see that this did not have any significant effect on compilation time. I guess I now understand why 😄.
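For reference, a minimal sketch of the two patterns (hypothetical names):

apply_nospec(@nospecialize(f), x) = f(x)  # no fresh method instance per f, but the call to f is dynamic
apply_spec(f::F, x) where {F} = f(x)      # compiles a specialised instance for every distinct f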


Jutho commented Oct 4, 2024

Also, there are other small allocations in Strided.jl, which have to do with how to block the problem or how to distribute it across threads.


lkdvos commented Oct 4, 2024

But I guess for problems of this size the multithreading does not yet kick in?


Jutho commented Oct 4, 2024

No, and that's anyway only if there is more than one thread. I am not sure if the extra allocations are only there for multithreading though; I should check that. There might also be some for other parts of the blocking and loop reordering, even though I tried to make everything based on type-stable tuples.


amilsted commented Oct 4, 2024

@Jutho Was the thing you tried literally just adding type parameters for those function arguments down the _mapreduce call chain? If so, I guess we can reproduce that easily and see if it helps our case. Might it be enough to only introduce the specialization on the methods that actually call the functions?

Wait, I see you mention this above. I guess this is currently what happens and you get dynamic dispatch when calling the kernel.


amilsted commented Oct 4, 2024

I think the overhead is kind of important, as it prevents Strided's version of permutedims from being a general replacement. I didn't use it in QuantumOpticsBase in the LazyTensor implementation because this overhead is noticeable for small systems. Maybe some explicit inlining could bring down the compile time a bit?


Jutho commented Oct 6, 2024

@amilsted and @leburgel: yes, the runtime dynamic dispatch was eliminated by fully specializing on the f, op and initop arguments, i.e. by giving them type parameters in the relevant methods. I have pushed this to a branch to facilitate trying it out:
https://github.com/Jutho/Strided.jl/tree/jh/specialize

However, I am not sure if the overhead you see comes from that, or simply from the kind of computations that Strided does to find a good loop order and blocking scheme. That overhead can be non-negligible for small systems. Maybe it pays off to have a specific optimized implementation for tensor permutation; I'd certainly be interested in collaborating on that.


lkdvos commented Oct 6, 2024

Do we know how something like LoopVectorization.jl stacks up? From what I know, that should also attempt to find some loop order and blocking scheme and unrolling etc, which might be important for these kinds of systems...


Jutho commented Oct 6, 2024

Last time I checked (which is a while ago), LoopVectorization was great at loop ordering, unrolling and inserting vectorised operations, but it did not do blocking; Octavian does this separately, I think. The other major issue was that LoopVectorization did not support composite types such as Complex, which is kind of important :-).

Also, LoopVectorization would lead to a separate compilation stage for every new permutation. I don't know if we want to go that way? Maybe that's fine for a custom backend, but not for the default one.

The other thing to revive is HPTT; I once wrote the jll package for that using BinaryBuilder, so it shouldn't be much work to write the wrappers to make an HPTTBLAS backend (meaning HPTT for permutations and (permute+BLAS) for contractions).


Jutho commented Oct 6, 2024

Also, on the Strided side, maybe there is a way to "cache" the blocking scheme etc, but I am not really sure if that is worth it.


lkdvos commented Oct 6, 2024

And of course we should compare to TBLIS as well; I don't think I benchmarked that for smaller systems (and I think it leverages HPTT under the hood).

Let me revive https://github.com/Jutho/TensorOperations.jl/tree/ld/benchmark as well at some point, to benchmark all implementations and changes a bit more rigorously. I have a set of permutations and contractions there, but @amilsted or @leburgel, if you could supply some typical contractions and/or array sizes, we can tailor our efforts a bit more towards those.


amilsted commented Oct 7, 2024

Is it worth considering switching to Julia's own permutedims for small arrays?


lkdvos commented Oct 7, 2024

You should be able to try this out already; the BaseCopy and BaseView backends should do this, if I'm not mistaken.
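For completeness, selecting it for a single @tensor call would look something like this (the backend keyword syntax is assumed from the TensorOperations v5 interface; double-check the docs for the exact form):

using TensorOperations

A = randn(ComplexF64, 8, 8, 7)
B = randn(ComplexF64, 8, 8, 7)
@tensor backend = TensorOperations.BaseCopy() C[a, b] := A[a, x, y] * B[b, x, y]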


amilsted commented Oct 7, 2024

@leburgel looks like BaseCopy would be the one to try.


leburgel commented Oct 7, 2024

I'll give it a go!


Jutho commented Oct 7, 2024

The Base... methods were not really meant for performance. The problem with using Base.permutedims is that it doesn't support the alpha and beta parameters. And while in the case of a tensor network contraction we don't typically need them (i.e. they take their default values), it is annoying to have to check and correct for that. Maybe copying the Base implementation and generalising it to include alpha and beta is one way forward. But as far as I remember from looking at it quite some time ago, it goes via PermutedDimsArray, which encodes the permutation in its type parameters, and thus also leads to new compilation for every new permutation. Maybe that is the way to go, and it is not excessive in real-life applications.
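A rough sketch of such a generalisation over PermutedDimsArray (permutedadd! is a hypothetical helper; note that the permutation ends up in the type parameters, so every new perm value triggers fresh compilation):

function permutedadd!(C, A, perm, α, β)
    # C ← β * C + α * permutedims(A, perm), without materialising the permuted copy
    C .= β .* C .+ α .* PermutedDimsArray(A, perm)
    return C
end

A = randn(ComplexF64, 4, 3, 2)
C = zeros(ComplexF64, 2, 4, 3)
permutedadd!(C, A, (3, 1, 2), 1.0, 0.0)  # C now equals permutedims(A, (3, 1, 2))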

I often benchmarked a number of randomly generated permutations which were run just once, and then of course such approaches fail miserably. But I guess that is not a very realistic benchmark.
