Fix the compilation crash on Ampere for block_k=16 mma #4768

Open · wants to merge 19 commits into main
Conversation

@bingyizh233 (Author) commented Sep 20, 2024

The original issue is reported in #3435. The crash happens during compilation when arith.sitofp (from i8 to fp16) operates on a tensor operand that has a dot_op layout with the first dimension of the tensor equal to 16 and opIdx = 1. For example: %104 = arith.sitofp %103 : tensor<16x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>. Investigation shows that the bug is in the TritonGPUToLLVM pass: in this corner case (block_k = 16 and opIdx = 1), extra elements are unpacked in include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h (lines 186-194). The code unpacks extra elements because of an implicit assumption in lib/Dialect/TritonGPU/IR/Dialect.h (line 2000) that at least 4 reps are loaded. Therefore, in this patch, the extra loaded elements are dropped in that corner case.
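For context, a minimal repro sketch of the failing configuration (the kernel, shapes, and driver function below are illustrative assumptions based on the description above, not code from this PR): a mixed-precision GEMM whose int8 operand is up-cast to fp16 inside the kernel, tiled with BLOCK_K = 16, which gives the converted dot_op operand a K dimension of 16 on Ampere.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def mixed_gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                          stride_am, stride_ak, stride_bk, stride_bn,
                          stride_cm, stride_cn,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                          BLOCK_K: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_ptr + offs_m[:, None] * stride_am
                        + (k + offs_k)[None, :] * stride_ak)
            b = tl.load(b_ptr + (k + offs_k)[:, None] * stride_bk
                        + offs_n[None, :] * stride_bn)
            # The i8 -> f16 cast below is what lowers to arith.sitofp on a
            # dot_op-layout tensor whose first dimension is BLOCK_K (16 here).
            acc += tl.dot(a, b.to(tl.float16))
        tl.store(c_ptr + offs_m[:, None] * stride_cm
                 + offs_n[None, :] * stride_cn, acc)


    def run_repro():
        M, N, K = 128, 128, 128
        a = torch.randn((M, K), device="cuda", dtype=torch.float16)
        b = torch.randint(-8, 8, (K, N), device="cuda", dtype=torch.int8)
        c = torch.empty((M, N), device="cuda", dtype=torch.float32)
        block_m, block_n, block_k = 64, 16, 16
        grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n))
        mixed_gemm_kernel[grid](a, b, c, M, N, K,
                                a.stride(0), a.stride(1),
                                b.stride(0), b.stride(1),
                                c.stride(0), c.stride(1),
                                BLOCK_M=block_m, BLOCK_N=block_n,
                                BLOCK_K=block_k)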

The core Triton team is a small group of people, and we receive many PRs (thank
you!). To help us review your code more quickly, if you are a new
contributor (fewer than 3 PRs merged) we ask that you complete the following
tasks and include the filled-out checklist in your PR description.

Complete the following tasks before sending your PR, and replace [ ] with
[x] to indicate you have done them.

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • [ ] This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@Jokeren (Contributor) commented Sep 20, 2024

Hi, thanks for your information. However, a PR like this is not acceptable.

Most importantly, I don't see any test case included

@gflegar gflegar assigned gflegar and unassigned gflegar Sep 20, 2024
@gflegar gflegar requested review from chsigg and removed request for ptillet September 20, 2024 16:26
@gflegar (Collaborator) commented Sep 20, 2024

We can take a first pass at this from the Google side, since we asked Nvidia for help with this, and once it's in an acceptable state we will assign to a core maintainer for final approval. @Moerafaat can take a first look (unfortunately, I don't think I can assign him officially as the reviewer), since he's most familiar with that part of the codebase. And then @chsigg can serve as the final quality gate, as agreed between OAI & Google.

@bingyizh233 (Author)

> Hi, thanks for your information. However, a PR like this is not acceptable.
>
> Most importantly, I don't see any test case included

Thank you for your feedback! I’ll work on adding the test cases. Additionally, could you please let me know if there’s anything else that might be missing?

@Jokeren (Contributor) commented Sep 20, 2024

The mixed-precision stuff is quite fragile as far as I know. I'm not sure, as I haven't taken a look in a while.

test/Conversion/tritongpu_to_llvm_ampere.mlir

This is the test case for converting a tensor from s8 to fp16 when the first
dimension of the tensor is 16 and opIdx == 1. Previously, the compilation
could crash during convert-triton-gpu-to-llvm on NVIDIA Ampere GPUs. The new
patch resolves the issue, and this test case verifies that the crash no
longer occurs.
@Moerafaat (Contributor) commented Sep 23, 2024

Thanks for the effort here, I will take a deeper look. Just a few notes for now:

  • Can you please provide a permalink for this? "lib/Dialect/TritonGPU/IR/Dialect.h, at line 2000, at least 4 rep will be loaded."
  • At the time of reporting the issue, this was one instance of failures that happened due to "size mismatch during packing in LLVM". Can you verify that the matmul with swapped input order works, i.e. S8xF16 ?
  • The way we currently work around this issue is in the auto-tuner. It is currently handled in a generic way (i.e. not specific to S8 for example). You can see here that we are disabling tiling configurations that would potentially lead to running into this issue. Ideally a solution would be accepted if we can remove this restriction in the auto-tuner without any failing tests. Can you include this as part of your change as well and see if we get any failures? Also note that we have a similar restriction here for sparsity that we would like to remove as well.

@Moerafaat (Contributor)

I was also trying to see other scenarios that we ran into that could be classified as the same issue. For example, looking at this sample HLO:

triton_dot {
  parameter = f8e4m3fn[32,32]{1,0} parameter(0)
  parameter.1 = f8e4m3fn[32,32]{1,0} parameter(1)
  ROOT dot = f32[32,32]{1,0} dot(parameter,parameter.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

ENTRY e {
  parameter = f8e4m3fn[32,32]{1,0} parameter(0)
  parameter.1 = f8e4m3fn[32,32]{1,0} parameter(1)
  ROOT dot = f32[32,32]{1,0} fusion(parameter,parameter.1), kind=kCustom, calls=triton_dot,
    backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":16,"block_k":16,"split_k":1,"num_stages":2,"num_warps":2,"num_ctas":1}}}
}

With your change this will fail with the same "size mismatch" error. I suspect that in this case it would be because the adjustment will not apply to the left-hand side because of the restriction on the OpIdx. Is that restriction necessary or is something else the reason for failing in this example?

@Moerafaat (Contributor) commented Sep 23, 2024

Another example where the current solution would fail is for this example:

  parameter_0 = s8[64,64]{1,0} parameter(0)
  convert.2 = bf16[64,64]{1,0} convert(parameter_0)
  parameter_1 = bf16[128,32]{1,0} parameter(1)
  parameter_2 = u16[64,8]{1,0} parameter(2)
  ROOT dot.0 = bf16[64,32]{1,0} dot(convert.2, parameter_1, parameter_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4

This one goes through tt.sparse_dot (which only exists if you use XLA's Triton). It's fine if you can't reproduce this example, but what I could also observe with this example for some tilings (for example shape[0] = 64, shape[1] = 16 for OpIdx = 0) is that it also runs into the same failure. So I'm questioning whether the condition strictly applies to shape[0]?
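For what it's worth, a small illustration of the shape question above (my own reading, not code from the PR): the K dimension of a dot_op operand is shape[1] for opIdx == 0 (an M x K tile) and shape[0] for opIdx == 1 (a K x N tile), i.e. shape[1 - opIdx], which covers both the original 16x64 RHS tile and the 64x16 LHS tile from this sparse example.

    # Illustration only: which shape entry is the K dimension of a dot_op operand.
    def k_dim(shape, op_idx):
        # opIdx == 0 is the LHS (M x K); opIdx == 1 is the RHS (K x N).
        return shape[1 - op_idx]

    assert k_dim((16, 64), 1) == 16  # RHS tile from the original crash report
    assert k_dim((64, 16), 0) == 16  # LHS tile from the sparse example above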

@Moerafaat (Contributor)

Also make sure to run pre-commit on the change.

[128, 128, 128, 64, 16, 16, torch.float16, torch.int8]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 8, reason="Requires compute capability == 8")
def test_gemm_mixed_dtype(M, N, K, block_m, block_n, block_k, dtype_a, dtype_b):
Contributor

Can you leave a comment explaining the purpose of this test? There are mixed-type tests that exist. The emphasis here should be focused on mixed-types with small tile sizes. I would suggest also renaming the file to indicate emphasis on small tiles.

Also can we add more variety in the parametrization? Perhaps use fp8 types as well instead of just int8.
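One possible shape for the extended parametrization (a sketch only: the case list, the test name, and float8 availability in the installed torch are assumptions, and the test body is elided):

    import pytest
    import torch

    # Mixed-dtype GEMM cases with deliberately small tiles (block_k = 16), the
    # configuration that used to crash on Ampere.
    SMALL_TILE_CASES = [
        # M,  N,   K,   block_m, block_n, block_k, dtype_a,        dtype_b
        (128, 128, 128, 64,      16,      16,      torch.float16,  torch.int8),
        (128, 128, 128, 16,      64,      16,      torch.int8,     torch.float16),
        (128, 128, 128, 32,      32,      16,      torch.bfloat16, torch.int8),
    ]
    if hasattr(torch, "float8_e4m3fn"):  # float8 dtypes only exist in newer torch
        SMALL_TILE_CASES.append(
            (128, 128, 128, 64, 16, 16, torch.float16, torch.float8_e4m3fn))


    @pytest.mark.parametrize(
        "M, N, K, block_m, block_n, block_k, dtype_a, dtype_b", SMALL_TILE_CASES)
    @pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 8,
                        reason="Requires compute capability == 8")
    def test_gemm_mixed_dtype_small_tile(M, N, K, block_m, block_n, block_k,
                                         dtype_a, dtype_b):
        # Body as in the PR's test: run the mixed-dtype matmul with the given
        # tile sizes and compare against a torch reference.
        ...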


#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 3072 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @amperes8tofp16conversion(%1 : tensor<16x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
Contributor

Nit: Triton's convention for naming here is snake_case. Please adjust.

Author

I have adjusted the naming to follow snake_case and added a test case for 'opIdx == 1 and shape[1] == 16'.

ret.push_back(values[i + 13]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
if(in_shape[0] == 16 && inEncoding.getOpIdx() == 1){ // In the corner case where in_shape[0] == 16 and getOpIdx() == 1, extra elements will be loaded. It is necessary to discard these additional elements.
Contributor

As I mentioned in the comments I left on the PR, I'm questioning the condition here as some cases will not match this and fail for the same reason.

I tried to change this condition to
in_shape[0] == 16 || out_shape == 16 instead, and it passes the 2 cases I indicated previously that fail. That being said, I'm not sure if this is the correct way to handle it or not.

Author

Thanks for the information. I have looked into this. Changing the condition to (in_shape[0] == 16 && inEncoding.getOpIdx() == 1) || (in_shape[1] == 16 && inEncoding.getOpIdx() == 0) resolves the issues and might be a better way; it also covers the s8 x fp16 case. I will look at the other cases you mention.

@@ -56,23 +58,37 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
}
if (inBitWidth == 8 && ouBitWidth == 16) {
Contributor

Is this fix only applicable to conversions from 8-bit to 16-bit or is this extendable to other combinations as well?

Author

This only works for conversions from 8-bit to 16-bit. Other data types (e.g., 32-bit) have a different data layout for triton_gpu.dot_op.

@bingyizh233 (Author)

> Thanks for the effort here, I will take a deeper look. Just a few notes for now:
>
>   • Can you please provide a permalink for this? "lib/Dialect/TritonGPU/IR/Dialect.h, at line 2000, at least 4 rep will be loaded."
>   • At the time of reporting the issue, this was one instance of failures that happened due to "size mismatch during packing in LLVM". Can you verify that the matmul with swapped input order works, i.e. S8xF16?
>   • The way we currently work around this issue is in the auto-tuner. It is currently handled in a generic way (i.e. not specific to S8 for example). You can see here that we are disabling tiling configurations that would potentially lead to running into this issue. Ideally a solution would be accepted if we can remove this restriction in the auto-tuner without any failing tests. Can you include this as part of your change as well and see if we get any failures? Also note that we have a similar restriction here for sparsity that we would like to remove as well.

Thanks for the detailed comments. I updated the commit log by adding the permalink. I will test the other cases and take a deeper look into it. Could you provide example code for running this gemm_fusion_autotuner?

@bingyizh233 (Author)

> Thanks for the effort here, I will take a deeper look. Just a few notes for now:
>
>   • Can you please provide a permalink for this? "lib/Dialect/TritonGPU/IR/Dialect.h, at line 2000, at least 4 rep will be loaded."
>   • At the time of reporting the issue, this was one instance of failures that happened due to "size mismatch during packing in LLVM". Can you verify that the matmul with swapped input order works, i.e. S8xF16?
>   • The way we currently work around this issue is in the auto-tuner. It is currently handled in a generic way (i.e. not specific to S8 for example). You can see here that we are disabling tiling configurations that would potentially lead to running into this issue. Ideally a solution would be accepted if we can remove this restriction in the auto-tuner without any failing tests. Can you include this as part of your change as well and see if we get any failures? Also note that we have a similar restriction here for sparsity that we would like to remove as well.

Can I ask what the tiling-configuration restrictions are here beyond Triton's own restrictions? As far as I know, Triton requires block_m >= 16, block_n >= 16, block_k >= 16, and block_m, block_n, block_k to be powers of 2. Besides the Triton restrictions and the block_k != 16 restriction (since block_k == 16 leads to the crash), what other restrictions are there in the autotuner? Thanks!

@Moerafaat (Contributor)

>> Thanks for the effort here, I will take a deeper look. Just a few notes for now:
>>
>>   • Can you please provide a permalink for this? "lib/Dialect/TritonGPU/IR/Dialect.h, at line 2000, at least 4 rep will be loaded."
>>   • At the time of reporting the issue, this was one instance of failures that happened due to "size mismatch during packing in LLVM". Can you verify that the matmul with swapped input order works, i.e. S8xF16?
>>   • The way we currently work around this issue is in the auto-tuner. It is currently handled in a generic way (i.e. not specific to S8 for example). You can see here that we are disabling tiling configurations that would potentially lead to running into this issue. Ideally a solution would be accepted if we can remove this restriction in the auto-tuner without any failing tests. Can you include this as part of your change as well and see if we get any failures? Also note that we have a similar restriction here for sparsity that we would like to remove as well.
>
> Thanks for the detailed comments. I updated the commit log by adding the permalink. I will test the other cases and take a deeper look into it. Could you provide example code for running this gemm_fusion_autotuner?

You can run the tests here that are purposed for this. You don't have to worry about coverage for now; I can test internally and report other issues to you if they arise.

@Moerafaat (Contributor)

>> Thanks for the effort here, I will take a deeper look. Just a few notes for now:
>>
>>   • Can you please provide a permalink for this? "lib/Dialect/TritonGPU/IR/Dialect.h, at line 2000, at least 4 rep will be loaded."
>>   • At the time of reporting the issue, this was one instance of failures that happened due to "size mismatch during packing in LLVM". Can you verify that the matmul with swapped input order works, i.e. S8xF16?
>>   • The way we currently work around this issue is in the auto-tuner. It is currently handled in a generic way (i.e. not specific to S8 for example). You can see here that we are disabling tiling configurations that would potentially lead to running into this issue. Ideally a solution would be accepted if we can remove this restriction in the auto-tuner without any failing tests. Can you include this as part of your change as well and see if we get any failures? Also note that we have a similar restriction here for sparsity that we would like to remove as well.
>
> Can I ask what the tiling-configuration restrictions are here beyond Triton's own restrictions? As far as I know, Triton requires block_m >= 16, block_n >= 16, block_k >= 16, and block_m, block_n, block_k to be powers of 2. Besides the Triton restrictions and the block_k != 16 restriction (since block_k == 16 leads to the crash), what other restrictions are there in the autotuner? Thanks!

The restriction here is Triton-specific and only at the mentioned line. The idea is that block_k = 16 is specific to cases where we have 8-bit types; the way it is written there is just the generalization for any bit width.
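For concreteness, one way to read that generalization (my own sketch; the threshold constant and the names are assumptions, not the actual autotuner code) is that a tiling is excluded whenever block_k times the narrowest operand bit width is smaller than one ldmatrix fragment's worth of bits:

    # Illustrative reading of the generic restriction described above; the
    # constant and names are assumptions, not the real XLA code.
    LDMATRIX_GRANULARITY_BITS = 256  # assumed per-fragment granularity

    def is_restricted(block_k: int, min_operand_bit_width: int) -> bool:
        return block_k * min_operand_bit_width < LDMATRIX_GRANULARITY_BITS

    assert is_restricted(16, 8)       # the crashing block_k = 16, int8 case
    assert not is_restricted(32, 8)   # int8 is fine from block_k = 32
    assert not is_restricted(16, 16)  # fp16/bf16 are fine at block_k = 16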

ret.push_back(values[i + 13]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
if ((in_shape[0] == 16 && inEncoding.getOpIdx() == 1) ||
Collaborator

I think this would be more readable:

    bool loadsExtraElements = in_shape[1-inEncoding.getOpIdx()] == 16;
    for (unsigned i = 0; i < values.size(); i += 16) {
      ret.push_back(values[i + 0]);
      ret.push_back(values[i + 1]);
      ret.push_back(values[i + 2]);
      ret.push_back(values[i + 3]);
      ret.push_back(values[i + 8]);
      ret.push_back(values[i + 9]);
      ret.push_back(values[i + 10]);
      ret.push_back(values[i + 11]);
      if (loadsExtraElements)
        continue;  // Discard elements that aren't needed.
      ret.push_back(values[i + 4]);
      ret.push_back(values[i + 5]);
      ret.push_back(values[i + 6]);
      ret.push_back(values[i + 7]);
      ret.push_back(values[i + 12]);
      ret.push_back(values[i + 13]);
      ret.push_back(values[i + 14]);
      ret.push_back(values[i + 15]);
    }

Author

Thanks for the suggestions. I have restructured the code in the newest commit. All the tests and pre-commit checks have been run again.

@Moerafaat (Contributor)

After addressing the last comment from @chsigg I think the PR would be ready for OAI to have a look.

@bingyizh233 (Author)

> After addressing the last comment from @chsigg I think the PR would be ready for OAI to have a look.

I have restructured the code in the newest commit. All the tests and pre-commit checks have been run again.

@chsigg chsigg marked this pull request as ready for review September 27, 2024 14:38
@chsigg (Collaborator) commented Sep 27, 2024

@Jokeren, this should be ready for review now. Would you mind taking a look? Thanks!

@Jokeren (Contributor) commented Oct 1, 2024

cc @ThomasRaoux

@ThomasRaoux (Collaborator)

This seems like a very ad hoc workaround for the crash.
It is clear we need to rework this, as we shouldn't have to move data around when doing type conversion. This code is definitely a workaround for a more fundamental problem with dot_encoding layouts. I'm a bit concerned about adding more hacks on top of the existing hack to make things work.

I assume you are not blocked by this on the XLA side since you have this patch already?
Would you be interested in helping do a proper fix for the dot_encoding layout instead? WDYT?

@Jokeren (Contributor) commented Oct 2, 2024

IMHO, it's fine to merge if it's an urgent issue that prevents you from moving forward, and we can use your code as a reference while we're refactoring the dot operand layout.

@Moerafaat (Contributor)

@ThomasRaoux it is indeed working around the issue, but not in an invasive manner, I would say, as the change is small and local. We work around it in XLA by entirely omitting tiling configurations that include such cases. We really would like to include these sizes in our tiling configs, and this fix will allow us to do so.

We have many cases where we maintain some fixes internally (some are hacks), but ideally, where we can, it would be great to have the changes upstream to minimize the differences. This also allows us to monitor issues and collaborate more easily if the diff is smaller.

@Moerafaat (Contributor)

@ThomasRaoux just a friendly ping on this thread when you have time to take a look :)

@bingyizh233 (Author)

I am trying to rebase and test this on Ampere. However, upstream pytest currently has errors on the Ampere platform. I reported the issue here: #4990

ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
if (loadsExtraElements)
Contributor

You can also repeat extra elements

Author

Sorry, I did not understand "repeat extra elements". Any clarification is appreciated. Thanks!

Contributor

Push i=0 to 7 again into ret

Author

      ret.push_back(values[i]);
      ret.push_back(values[i + 1]);
      ret.push_back(values[i + 4]);
      ret.push_back(values[i + 5]);
      ret.push_back(values[i + 2]);
      ret.push_back(values[i + 3]);
      ret.push_back(values[i + 6]);
      ret.push_back(values[i + 7]);
      if (loadsExtraElements) {
        // Push the first eight (reordered) elements again instead of the
        // unused second half, then move on to the next group of 16.
        ret.push_back(values[i]);
        ret.push_back(values[i + 1]);
        ret.push_back(values[i + 4]);
        ret.push_back(values[i + 5]);
        ret.push_back(values[i + 2]);
        ret.push_back(values[i + 3]);
        ret.push_back(values[i + 6]);
        ret.push_back(values[i + 7]);
        continue;
      }

I just tried pushing i = 0 to 7 again. However, it gives an error when testing python/test/unit/ampere/test_gemm_mixed_dtype_small_tile.py. Packing extra elements here may cause a mismatch in the number of registers.

Contributor

@Jokeren I'm not sure I understand what you mean. The intent of the change is to actually not pack the extra elements as they result in a size mismatch during lowering to LLVM.

Contributor

Let's wait a bit until this PR has landed: #5044

Maybe a lot of the issues you saw previously will be gone.

Contributor

We merged this change internally (via a patch) for now, as we are currently facing issues with Triton at main (including with the change in #5044).

Contributor

I'll look into it today

copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 11, 2024
…ut types involve an 8-bit. This change is porting in this [PR](triton-lang/triton#4768).

PiperOrigin-RevId: 695275077
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 11, 2024
…ut types involve an 8-bit. This change is porting in this [PR](triton-lang/triton#4768).

PiperOrigin-RevId: 695275077
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 11, 2024
…ut types involve an 8-bit. This change is porting in this [PR](triton-lang/triton#4768).

PiperOrigin-RevId: 695311919
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 11, 2024
…ut types involve an 8-bit. This change is porting in this [PR](triton-lang/triton#4768).

PiperOrigin-RevId: 695311919