diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 0a9e32b4c5..0da2d5ad85 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "e2402615a5a76d46a433dfcc1de10b38a1263c9d" +LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24" -LLVM_SHA256 = "9c22349e1d38555b2f223e49951655f60c04c0c3467e0150aaf6c9f50484cc9f" +LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index e793a5b699..1ccf47891b 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -e2402615a5a76d46a433dfcc1de10b38a1263c9d +aa65f93b71dee8cacb22be1957673c8be6a3ec24 diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index bafcf5b606..e2ff617e10 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -780,5 +780,22 @@ bool isValidQuantizedDimension(Type type) { numScales == rankedType.getDimSize(quantDim)); } +bool hasSingleBoundedDimension(Type type) { + RankedTensorType rankedType = dyn_cast(type); + auto boundedAttr = + dyn_cast_or_null(rankedType.getEncoding()); + if (!boundedAttr) return false; + + // Count if bounded attr size is not kDynamic + int64_t numBoundedDims = llvm::count_if( + boundedAttr.getBounds(), + [](int64_t bound) { return !ShapedType::isDynamic(bound); }); + // Also check that there are only bounded dims and no unbounded dims. 
+  int64_t numDynamicDims = llvm::count_if( +      rankedType.getShape(), +      [](int64_t bound) { return ShapedType::isDynamic(bound); }); +  return numBoundedDims == 1 && numDynamicDims == 1; +} + } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index a65d54700f..36bfafa317 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -101,6 +101,9 @@ bool isValidStablehloQuantizedElementType(Type elementType); // mentioned in the StableHLO specification. bool isValidQuantizedDimension(Type type); +// Returns true if the given type has a single bounded dimension. +bool hasSingleBoundedDimension(Type type); + // TODO(zhouxin) Move type inference related methods to TypeInference.cpp std::pair inferConcatenatedDimAndBound(int64_t leftSize, diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index b995fcda31..73894dc3da 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -30,6 +30,20 @@ def I32RankedTensor : RankedTensorOf<[I32]>; def UI32RankedTensor : RankedTensorOf<[UI32]>; +//===----------------------------------------------------------------------===// +// HLO type constraints. +//===----------------------------------------------------------------------===// + +// Note: Bounded dynamism is largely unspecced and this feature needs more +// thought as it is adopted to modern frameworks. The current support is +// designed to allow existing TF programs to be representable in StableHLO and +// is subject to change as a formal design for bounded dynamism is developed. +def HLO_HasSingleBoundedDimensionPred + : CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">; + +def HLO_HasStaticOrSingleBoundedShapePred + : Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>; + //===----------------------------------------------------------------------===// // HLO type definitions. 
//===----------------------------------------------------------------------===// @@ -267,6 +281,9 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[ def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], [IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">; +def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], + [IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">; + def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>; def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>; diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index a83b696433..281d5191cb 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -1980,7 +1980,7 @@ def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim", DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/ ); - let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); + let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor); let hasVerifier = 1; @@ -2732,7 +2732,7 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape", let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand); - let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); + let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor); let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index daabe04d9a..9bde090cd3 100644 --- a/stablehlo/dialect/TypeInference.cpp 
+++ b/stablehlo/dialect/TypeInference.cpp @@ -3724,9 +3724,8 @@ LogicalResult verifyBroadcastInDimOp(std::optional location, Value operand, ArrayRef broadcastDimensions, Value result) { - auto operandType = cast(operand.getType()); - // broadcast_in_dim_c1 + auto operandType = cast(operand.getType()); if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType, result.getType()))) return failure(); @@ -4658,11 +4657,12 @@ LogicalResult verifyReshapeOp(std::optional location, Value operand, Value result) { // If the operand type is dynamically shaped there is nothing to verify. auto operandTy = cast(operand.getType()); - if (!operandTy.hasStaticShape()) return success(); + auto resultTy = cast(result.getType()); + if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape()) + return success(); // If the operand type is statically shaped (not required) the number of // elements must match that of the result type. - auto resultTy = cast(result.getType()); int64_t numResultElements = resultTy.getNumElements(); int64_t numOperandElements = operandTy.getNumElements(); if (numResultElements != numOperandElements) diff --git a/stablehlo/dialect/Version.cpp b/stablehlo/dialect/Version.cpp index 27cc196082..563a1e94c9 100644 --- a/stablehlo/dialect/Version.cpp +++ b/stablehlo/dialect/Version.cpp @@ -75,7 +75,7 @@ FailureOr Version::getBytecodeVersion() const { Version Version::fromCompatibilityRequirement( CompatibilityRequirement requirement) { // Compatibility requirement versions can be updated as needed, as long as the - // version satisifies the requirement. + // version satisfies the requirement. // The time frames used are from the date that the release was tagged on, not // merged. The tag date is when the version has been verified and exported to // XLA. 
See: https://github.com/openxla/stablehlo/tags diff --git a/stablehlo/dialect/VhloOps.cpp b/stablehlo/dialect/VhloOps.cpp index 3654af813e..96b5c062de 100644 --- a/stablehlo/dialect/VhloOps.cpp +++ b/stablehlo/dialect/VhloOps.cpp @@ -25,7 +25,7 @@ limitations under the License. #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/Attributes.h" @@ -40,7 +40,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/TypeID.h" -#include "stablehlo/dialect/AssemblyFormat.h" +#include "stablehlo/dialect/AssemblyFormat.h" // IWYU pragma: keep #include "stablehlo/dialect/Version.h" #include "stablehlo/dialect/VhloBytecode.h" #include "stablehlo/dialect/VhloTypes.h" @@ -184,12 +184,13 @@ ParseResult parseFunctionBody(OpAsmParser& parser, Attribute& name, return success(); } -void TensorV1Attr::print(mlir::AsmPrinter& p) const { - p << '<' - << DenseIntOrFPElementsAttr::getFromRawBuffer( - llvm::cast(convertTypeToBuiltinForPrint(getType())), - getData()) - << '>'; +void TensorV1Attr::print(mlir::AsmPrinter& odsPrinter) const { + odsPrinter << '<' + << DenseIntOrFPElementsAttr::getFromRawBuffer( + llvm::cast( + convertTypeToBuiltinForPrint(getType())), + getData()) + << '>'; } // Parse tensor elements using DenseIntOrFPElementsAttr printing. 
diff --git a/stablehlo/reference/Types.cpp b/stablehlo/reference/Types.cpp index 6f06a12de4..89d5de96b2 100644 --- a/stablehlo/reference/Types.cpp +++ b/stablehlo/reference/Types.cpp @@ -48,13 +48,12 @@ bool isSupportedIntegerType(Type type) { } bool isSupportedFloatType(Type type) { - return type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || - type.isFloat6E3M2FN() || type.isFloat8E3M4() || - type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3() || - type.isFloat8E4M3FN() || type.isFloat8E4M3FNUZ() || - type.isFloat8E5M2() || type.isFloat8E5M2FNUZ() || - type.isFloat8E8M0FNU() || type.isF16() || type.isBF16() || - type.isF32() || type.isF64(); + return llvm::isa< + mlir::Float4E2M1FNType, mlir::Float6E2M3FNType, mlir::Float6E3M2FNType, + mlir::Float8E3M4Type, mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3Type, + mlir::Float8E4M3FNType, mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type, + mlir::Float8E5M2FNUZType, mlir::Float8E8M0FNUType, mlir::Float16Type, + mlir::BFloat16Type, mlir::Float32Type, mlir::Float64Type>(type); } bool isSupportedComplexType(Type type) { diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 7bf113c52c..cd62b1a9f1 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -1274,6 +1274,22 @@ func.func @broadcast_in_dim_c5(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // ----- +// CHECK-LABEL: func @broadcast_in_dim_dynamic_i1 +func.func @broadcast_in_dim_dynamic_i1(%arg0: tensor) -> tensor<1x3xi32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor) -> tensor<1x3xi32> + return %0 : tensor<1x3xi32> +} + +// ----- + +func.func @broadcast_in_dim_dynamic_result(%arg0: tensor<3xi32>) -> tensor { + // expected-error@+1 {{must be statically shaped or single bounded dimension tensor}} + %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array} : (tensor<3xi32>) -> tensor + func.return %0 : tensor +} + +// ----- + // Regression test for b/180052624, where this 
was improperly marked as an // invalid stablehlo.broadcast_in_dim op. // CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand diff --git a/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir b/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir new file mode 100644 index 0000000000..38a6113ceb --- /dev/null +++ b/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir @@ -0,0 +1,63 @@ +// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s + +// This file captures some quirks to bounded dynamism in StableHLO that are +// included to allow StableHLO to represent existing TF programs. + +// CHECK-LABEL: reshape_with_single_bounded_dimension +func.func @reshape_with_single_bounded_dimension(%arg0: tensor>) -> tensor<2x?xf32, #stablehlo.bounds> { + %0 = stablehlo.reshape %arg0 : (tensor>) -> tensor<2x?xf32, #stablehlo.bounds> + // CHECK: return {{.*}} #stablehlo.bounds + return %0 : tensor<2x?xf32, #stablehlo.bounds> +} + +// ----- + +// CHECK-LABEL: reshape_scalar_with_single_bounded_dimension
func.func @reshape_scalar_with_single_bounded_dimension(%arg0: tensor>) -> tensor<1x?xf32, #stablehlo.bounds> { + %0 = stablehlo.reshape %arg0 : (tensor>) -> tensor<1x?xf32, #stablehlo.bounds> + // CHECK: return {{.*}} #stablehlo.bounds + return %0 : tensor<1x?xf32, #stablehlo.bounds> + } + +// ----- + +func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor>) -> tensor> { + // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}} + %0 = stablehlo.reshape %arg0 : (tensor>) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: broadcast_in_dim_with_single_bounded_dimension +func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32, #stablehlo.bounds>) -> tensor<2x1x?xf32, #stablehlo.bounds> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds>) -> tensor<2x1x?xf32, #stablehlo.bounds> + // CHECK: return 
{{.*}} #stablehlo.bounds + return %0 : tensor<2x1x?xf32, #stablehlo.bounds> +} + +// ----- + +func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor>) -> tensor<2x?x?xf32, #stablehlo.bounds> { + // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}} + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor>) -> tensor<2x?x?xf32, #stablehlo.bounds> + return %0 : tensor<2x?x?xf32, #stablehlo.bounds> +} + +// ----- + +// CHECK-LABEL: constant_splat_broadcast +func.func @constant_splat_broadcast() -> tensor<1x?xf32, #stablehlo.bounds> { + %0 = stablehlo.constant dense<1.0> : tensor + %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<1x?xf32, #stablehlo.bounds> + // CHECK: tensor<1x?xf32, #stablehlo.bounds> + return %1 : tensor<1x?xf32, #stablehlo.bounds> +} + +// ----- + +func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #stablehlo.bounds> { + // expected-error@+2 {{elements literal type must have static shape}} + %c = stablehlo.constant dense<1> : tensor<1x?xf32, #stablehlo.bounds> + return %c : tensor<1x?xf32, #stablehlo.bounds> +} diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir index 24e2398fe7..1f16f1a6d8 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir @@ -1940,6 +1940,17 @@ func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> { return %1 : tensor<12xi64> } +// ----- + +// CHECK-LABEL: @reorder_invalid_with_dynamic_shape +func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor) { + // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> + // CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor + // CHECK: return %[[CONVERT]] + %0 = 
stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> + %1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor + return %1 : tensor +} // ----- diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index bf01394cda..3fede6e9eb 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -25,12 +25,17 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/Version.h" namespace mlir { namespace stablehlo { #define GEN_PASS_DECL + +std::unique_ptr<::mlir::Pass> createStablehloAggressiveSimplificationPass( + GreedyRewriteConfig config); + #define GEN_PASS_REGISTRATION #include "stablehlo/transforms/Passes.h.inc" diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index a7780855d9..ebb6c027a3 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -21,6 +22,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" @@ -38,6 +40,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -1447,12 +1450,18 @@ struct ReorderElementwiseAndShapeOp final return rewriter.notifyMatchFailure( op, "defining operation of unexpected type"); + // Reshape and broadcast are not allowed to have dynamic shape. 
+ Value result = op->getResult(0); + if (isa(definingOp) && + !cast(result.getType()).hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "cannot reorder around reshape/broadcast with dynamic shape"); + // Only reorder if the defining op has no other uses. if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) return rewriter.notifyMatchFailure(op, "operation has more than one use"); Value input = definingOp->getOperand(0); - Value result = op->getResult(0); auto intermediateType = cast(input.getType()) .clone(getElementTypeOrSelf(result.getType())); @@ -1470,6 +1479,9 @@ struct ReorderElementwiseAndShapeOp final struct StablehloAggressiveSimplificationPass final : impl::StablehloAggressiveSimplificationPassBase< StablehloAggressiveSimplificationPass> { + StablehloAggressiveSimplificationPass() = default; + StablehloAggressiveSimplificationPass(GreedyRewriteConfig config) + : config(config) {} LogicalResult initialize(MLIRContext *context) override { RewritePatternSet patterns_(context); populateStablehloCanonicalizationPatterns(context, &patterns_); @@ -1478,11 +1490,12 @@ struct StablehloAggressiveSimplificationPass final } void runOnOperation() override { - if (failed(applyPatternsGreedily(getOperation(), patterns))) + if (failed(applyPatternsGreedily(getOperation(), patterns, config))) signalPassFailure(); } private: + GreedyRewriteConfig config; FrozenRewritePatternSet patterns; }; @@ -1515,5 +1528,10 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context, DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context); } +std::unique_ptr createStablehloAggressiveSimplificationPass( + GreedyRewriteConfig config) { + return std::make_unique(config); +} + } // namespace stablehlo } // namespace mlir