Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate LLVM at llvm/llvm-project@aa65f93b71de #2701

Merged
merged 4 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "e2402615a5a76d46a433dfcc1de10b38a1263c9d"
LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24"

LLVM_SHA256 = "9c22349e1d38555b2f223e49951655f60c04c0c3467e0150aaf6c9f50484cc9f"
LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
e2402615a5a76d46a433dfcc1de10b38a1263c9d
aa65f93b71dee8cacb22be1957673c8be6a3ec24
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,5 +780,22 @@ bool isValidQuantizedDimension(Type type) {
numScales == rankedType.getDimSize(quantDim));
}

bool hasSingleBoundedDimension(Type type) {
RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
auto boundedAttr =
dyn_cast_or_null<BoundedAttrInterface>(rankedType.getEncoding());
if (!boundedAttr) return false;

// Count if bounded attr size is not kDynamic
int64_t numBoundedDims = llvm::count_if(
boundedAttr.getBounds(),
[](int64_t bound) { return !ShapedType::isDynamic(bound); });
// Also check that there are only bounded dims and no unbounded dims.
int64_t numDynamicDims = llvm::count_if(
rankedType.getShape(),
[](int64_t bound) { return ShapedType::isDynamic(bound); });
return numBoundedDims == 1 && numDynamicDims == 1;
}

} // namespace hlo
} // namespace mlir
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ bool isValidStablehloQuantizedElementType(Type elementType);
// mentioned in the StableHLO specification.
bool isValidQuantizedDimension(Type type);

// Returns true if the given type has a single bounded dimension.
bool hasSingleBoundedDimension(Type type);

// TODO(zhouxin) Move type inference related methods to TypeInference.cpp

std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
Expand Down
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def I32RankedTensor : RankedTensorOf<[I32]>;

def UI32RankedTensor : RankedTensorOf<[UI32]>;

//===----------------------------------------------------------------------===//
// HLO type constraints.
//===----------------------------------------------------------------------===//

// Note: Bounded dynamisms is largely unspecced and this feature needs more
// thoguht as it is adopted to modern frameworks. The current support is
// designed to allow existing TF programs to be representable in StableHLO and
// is subject to change as a formal design for boudned dynamism is developed.
def HLO_HasSingleBoundedDimensionPred
: CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">;

def HLO_HasStaticOrSingleBoundedShapePred
: Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>;

//===----------------------------------------------------------------------===//
// HLO type definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -267,6 +281,9 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;

def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">;

def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;

def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>;
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1980,7 +1980,7 @@ def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim",
DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/
);

let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);

let hasVerifier = 1;

Expand Down Expand Up @@ -2732,7 +2732,7 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape",

let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand);

let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3724,9 +3724,8 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
Value operand,
ArrayRef<int64_t> broadcastDimensions,
Value result) {
auto operandType = cast<RankedTensorType>(operand.getType());

// broadcast_in_dim_c1
auto operandType = cast<RankedTensorType>(operand.getType());
if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType,
result.getType())))
return failure();
Expand Down Expand Up @@ -4658,11 +4657,12 @@ LogicalResult verifyReshapeOp(std::optional<Location> location, Value operand,
Value result) {
// If the operand type is dynamically shaped there is nothing to verify.
auto operandTy = cast<RankedTensorType>(operand.getType());
if (!operandTy.hasStaticShape()) return success();
auto resultTy = cast<RankedTensorType>(result.getType());
if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
return success();

// If the operand type is statically shaped (not required) the number of
// elements must match that of the result type.
auto resultTy = cast<RankedTensorType>(result.getType());
int64_t numResultElements = resultTy.getNumElements();
int64_t numOperandElements = operandTy.getNumElements();
if (numResultElements != numOperandElements)
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ FailureOr<int64_t> Version::getBytecodeVersion() const {
Version Version::fromCompatibilityRequirement(
CompatibilityRequirement requirement) {
// Compatibility requirement versions can be updated as needed, as long as the
// version satisifies the requirement.
// version satisfies the requirement.
// The time frames used are from the date that the release was tagged on, not
// merged. The tag date is when the version has been verified and exported to
// XLA. See: https://github.com/openxla/stablehlo/tags
Expand Down
17 changes: 9 additions & 8 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Attributes.h"
Expand All @@ -40,7 +40,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"
#include "stablehlo/dialect/AssemblyFormat.h"
#include "stablehlo/dialect/AssemblyFormat.h" // IWYU pragma: keep
#include "stablehlo/dialect/Version.h"
#include "stablehlo/dialect/VhloBytecode.h"
#include "stablehlo/dialect/VhloTypes.h"
Expand Down Expand Up @@ -184,12 +184,13 @@ ParseResult parseFunctionBody(OpAsmParser& parser, Attribute& name,
return success();
}

void TensorV1Attr::print(mlir::AsmPrinter& p) const {
p << '<'
<< DenseIntOrFPElementsAttr::getFromRawBuffer(
llvm::cast<ShapedType>(convertTypeToBuiltinForPrint(getType())),
getData())
<< '>';
void TensorV1Attr::print(mlir::AsmPrinter& odsPrinter) const {
odsPrinter << '<'
<< DenseIntOrFPElementsAttr::getFromRawBuffer(
llvm::cast<ShapedType>(
convertTypeToBuiltinForPrint(getType())),
getData())
<< '>';
}

// Parse tensor elements using DenseIntOrFPElementsAttr printing.
Expand Down
13 changes: 6 additions & 7 deletions stablehlo/reference/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ bool isSupportedIntegerType(Type type) {
}

bool isSupportedFloatType(Type type) {
return type.isFloat4E2M1FN() || type.isFloat6E2M3FN() ||
type.isFloat6E3M2FN() || type.isFloat8E3M4() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3() ||
type.isFloat8E4M3FN() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E5M2() || type.isFloat8E5M2FNUZ() ||
type.isFloat8E8M0FNU() || type.isF16() || type.isBF16() ||
type.isF32() || type.isF64();
return llvm::isa<
mlir::Float4E2M1FNType, mlir::Float6E2M3FNType, mlir::Float6E3M2FNType,
mlir::Float8E3M4Type, mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3Type,
mlir::Float8E4M3FNType, mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
mlir::Float8E5M2FNUZType, mlir::Float8E8M0FNUType, mlir::Float16Type,
mlir::BFloat16Type, mlir::Float32Type, mlir::Float64Type>(type);
}

bool isSupportedComplexType(Type type) {
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,22 @@ func.func @broadcast_in_dim_c5(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {

// -----

// CHECK-LABEL: func @broadcast_in_dim_dynamic_i1
func.func @broadcast_in_dim_dynamic_i1(%arg0: tensor<?xi32>) -> tensor<1x3xi32> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<?xi32>) -> tensor<1x3xi32>
return %0 : tensor<1x3xi32>
}

// -----

func.func @broadcast_in_dim_dynamic_result(%arg0: tensor<3xi32>) -> tensor<?x3xi32> {
// expected-error@+1 {{must be statically shaped or single bounded dimension tensor}}
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1>} : (tensor<3xi32>) -> tensor<?x3xi32>
func.return %0 : tensor<?x3xi32>
}

// -----

// Regression test for b/180052624, where this was improperly marked as an
// invalid stablehlo.broadcast_in_dim op.
// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
Expand Down
63 changes: 63 additions & 0 deletions stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file captures some quirks to bounded dynamism in StableHLO that are
// included to allow StableHLO to repersent existing TF programs.

// CHECK-LABEL: reshape_with_single_bounded_dimension
func.func @reshape_with_single_bounded_dimension(%arg0: tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>> {
%0 = stablehlo.reshape %arg0 : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
// CHECK: return {{.*}} #stablehlo.bounds<?, 5>
return %0 : tensor<2x?xf32, #stablehlo.bounds<?, 5>>
}

// -----

// CHECK-LABEL: reshape_scalar_with_single_bounded_dimension
func.func @reshape_scalar_with_single_bounded_dimension(%arg0: tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
%0 = stablehlo.reshape %arg0 : (tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
// CHECK: return {{.*}} #stablehlo.bounds<?, 5>
return %0 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
}

// -----

func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>> {
// expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
%0 = stablehlo.reshape %arg0 : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>>
return %0 : tensor<?x?xf32, #stablehlo.bounds<5, 5>>
}

// -----

// CHECK-LABEL: broadcast_in_dim_with_single_bounded_dimension
func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
// CHECK: return {{.*}} #stablehlo.bounds<?, ?, 5>
return %0 : tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
}

// -----

func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>> {
// expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
return %0 : tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
}

// -----

// CHECK-LABEL: constant_splat_broadcast
func.func @constant_splat_broadcast() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
%0 = stablehlo.constant dense<1.0> : tensor<f32>
%1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
// CHECK: tensor<1x?xf32, #stablehlo.bounds<?, 5>>
return %1 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
}

// -----

func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
// expected-error@+2 {{elements literal type must have static shape}}
%c = stablehlo.constant dense<1> : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
return %c : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
}
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,17 @@ func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> {
return %1 : tensor<12xi64>
}

// -----

// CHECK-LABEL: @reorder_invalid_with_dynamic_shape
func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor<?x4xf32>) {
// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
// CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor<?x4xf32>
// CHECK: return %[[CONVERT]]
%0 = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
%1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor<?x4xf32>
return %1 : tensor<?x4xf32>
}

// -----

Expand Down
5 changes: 5 additions & 0 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ limitations under the License.
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/Version.h"

namespace mlir {
namespace stablehlo {

#define GEN_PASS_DECL

std::unique_ptr<::mlir::Pass> createStablehloAggressiveSimplificationPass(
GreedyRewriteConfig config);

#define GEN_PASS_REGISTRATION
#include "stablehlo/transforms/Passes.h.inc"

Expand Down
22 changes: 20 additions & 2 deletions stablehlo/transforms/StablehloAggressiveSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <utility>

Expand All @@ -21,6 +22,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
Expand All @@ -38,6 +40,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
Expand Down Expand Up @@ -1447,12 +1450,18 @@ struct ReorderElementwiseAndShapeOp final
return rewriter.notifyMatchFailure(
op, "defining operation of unexpected type");

// Reshape and broadcast are not allowed to have dynamic shape.
Value result = op->getResult(0);
if (isa<ReshapeOp, BroadcastOp>(definingOp) &&
!cast<ShapedType>(result.getType()).hasStaticShape())
return rewriter.notifyMatchFailure(
op, "cannot reorder around reshape/broadcast with dynamic shape");

// Only reorder if the defining op has no other uses.
if (!llvm::hasSingleElement(definingOp->getResult(0).getUses()))
return rewriter.notifyMatchFailure(op, "operation has more than one use");

Value input = definingOp->getOperand(0);
Value result = op->getResult(0);
auto intermediateType = cast<ShapedType>(input.getType())
.clone(getElementTypeOrSelf(result.getType()));

Expand All @@ -1470,6 +1479,9 @@ struct ReorderElementwiseAndShapeOp final
struct StablehloAggressiveSimplificationPass final
: impl::StablehloAggressiveSimplificationPassBase<
StablehloAggressiveSimplificationPass> {
StablehloAggressiveSimplificationPass() = default;
StablehloAggressiveSimplificationPass(GreedyRewriteConfig config)
: config(config) {}
LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet patterns_(context);
populateStablehloCanonicalizationPatterns(context, &patterns_);
Expand All @@ -1478,11 +1490,12 @@ struct StablehloAggressiveSimplificationPass final
}

void runOnOperation() override {
if (failed(applyPatternsGreedily(getOperation(), patterns)))
if (failed(applyPatternsGreedily(getOperation(), patterns, config)))
signalPassFailure();
}

private:
GreedyRewriteConfig config;
FrozenRewritePatternSet patterns;
};

Expand Down Expand Up @@ -1515,5 +1528,10 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
}

std::unique_ptr<Pass> createStablehloAggressiveSimplificationPass(
GreedyRewriteConfig config) {
return std::make_unique<StablehloAggressiveSimplificationPass>(config);
}

} // namespace stablehlo
} // namespace mlir