[PJRT] Fix stablehlo attribute parameters for buffer transpose and broadcast #19488
When a tensor is copied from host to device in a multi-GPU environment (i.e. sharding is enabled, so the copy is not trivial), the frontend call `jax.device_put` eventually reaches `DeviceInstance::TransposeBroadcastDeviceBuffer`. This function generates a StableHLO program that is then compiled and executed. The generated code contains two operations: `stablehlo.broadcast_in_dim` and `stablehlo.transpose`. According to the StableHLO spec (and the tblgen definitions), the `permutation` attribute of `stablehlo.transpose` is typed `DenseI64ArrayAttr`. However, `DeviceInstance::TransposeBroadcastDeviceBuffer` currently emits code such as `"stablehlo.transpose"(%x) {permutation = dense<[1,2,3]> : tensor<3xi64>} : ...`, which does not match the definition of `stablehlo.transpose` and should instead be something like `"stablehlo.transpose"(%x) {permutation = array<i64: 1,2,3>} : ...`. This PR fixes that.
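
For illustration only, here is a minimal sketch of what a corrected generated program could look like, using the `array<i64: ...>` syntax of `DenseI64ArrayAttr` for both attributes. The shapes, value names, and the exact op order are made up for this example and are not taken from `TransposeBroadcastDeviceBuffer` itself:

```mlir
// Hypothetical example program; shapes and op order are illustrative.
func.func @main(%arg0: tensor<2x3xf32>) -> tensor<1x3x2xf32> {
  // permutation is a DenseI64ArrayAttr, written as array<i64: ...>.
  %0 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 1, 0>}
      : (tensor<2x3xf32>) -> tensor<3x2xf32>
  // broadcast_dimensions uses the same attribute syntax.
  %1 = "stablehlo.broadcast_in_dim"(%0) {broadcast_dimensions = array<i64: 1, 2>}
      : (tensor<3x2xf32>) -> tensor<1x3x2xf32>
  return %1 : tensor<1x3x2xf32>
}
```

The old `dense<[...]> : tensor<Nxi64>` form corresponds to a `DenseIntElementsAttr`, which the current StableHLO definitions of these ops no longer accept for these parameters.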
ci-exactly: build_packages, test_pjrt