Skip to content

Commit

Permalink
[MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns (
Browse files Browse the repository at this point in the history
llvm#112394)

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
xfer_read/xfer_write Ops are also supported:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```
  • Loading branch information
banach-space authored Oct 26, 2024
1 parent 4102625 commit 0cf7aaf
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 18 deletions.
67 changes: 50 additions & 17 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,13 @@ namespace {
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
FailureOr<Value>
matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
Expand All @@ -376,6 +378,10 @@ class TransferReadDropUnitDimsPattern
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
return failure();
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
// out.
if (reducedRank == 0 && maskingOp)
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
Expand Down Expand Up @@ -406,27 +412,37 @@ class TransferReadDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
rewriter.getBoolArrayAttr(inBounds));

if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
maskingOp.getMask());
newTransferReadOp = mlir::vector::maskOperation(
rewriter, newTransferReadOp, shapeCastMask);
}

auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, vectorType, newTransferReadOp);
rewriter.replaceOp(transferReadOp, shapeCast);
loc, vectorType, newTransferReadOp->getResults()[0]);

return success();
return shapeCast;
}
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
: public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
FailureOr<Value>
matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
Expand All @@ -444,6 +460,10 @@ class TransferWriteDropUnitDimsPattern
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
return failure();
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
// out.
if (reducedRank == 0 && maskingOp)
return failure();
// Check if the reduced vector shape matches the reduced destination shape.
// Otherwise, this case is not supported yet.
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
Expand Down Expand Up @@ -474,13 +494,26 @@ class TransferWriteDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
maskOp, rewriter.getBoolArrayAttr(inBounds));

if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
maskingOp.getMask());
newXferWrite =
mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
}

return success();
if (transferWriteOp.hasPureTensorSemantics())
return newXferWrite->getResults()[0];

// With Memref semantics, there's no return value. Use empty value to signal
// success.
return Value();
}
};

Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,15 @@ func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) ->

// -----

func.func @vector_mask_passthru_type_mismatch(%t0: tensor<f32>, %m0: vector<i1>) -> vector<f32> {
%ft0 = arith.constant 0.0 : f32
// expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
%0 = vector.mask %m0 { vector.transfer_read %t0[], %ft0 : tensor<f32>, vector<f32> } : vector<i1> -> vector<f32>
return %0 : vector<f32>
}

// -----

// expected-note@+1 {{prior use here}}
func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> {
%ft0 = arith.constant 0.0 : f32
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

//-----------------------------------------------------------------------------
// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
//-----------------------------------------------------------------------------

func.func @transfer_read_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
%c0 = arith.constant 0 : index
Expand All @@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]

func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
func.func @transfer_read_rank_reducing_masked(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
%mask: vector<3x2xi1>) -> vector<3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
%v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
} : vector<3x2xi1> -> vector<3x2xi8>
return %v : vector<3x2xi8>
}
// CHECK-LABEL: func @transfer_read_rank_reducing_masked
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.mask %[[MASK]]
// CHECK-SAME: vector.transfer_read %[[SUBVIEW]]

func.func @transfer_write_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
%vec : vector<3x2xi8>) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
Expand All @@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]

func.func @transfer_write_rank_reducing_masked(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
%vec : vector<3x2xi8>,
%mask: vector<3x2xi1>) {
%c0 = arith.constant 0 : index
vector.mask %mask {
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
} : vector<3x2xi1>
return
}
// CHECK-LABEL: func @transfer_write_rank_reducing_masked
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.mask %[[MASK]]
// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]

func.func @transfer_read_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -68,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>

func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
%arg : memref<1x1x1x1x1xf32>,
%mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {

%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst
: memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
} : vector<1x1x1xi1> -> vector<1x1x1xf32>
return %v : vector<1x1x1xf32>
}
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
// CHECK-NOT: vector.shape_cast
// CHECK-NOT: memref.subview

func.func @transfer_write_and_vector_rank_reducing_to_0d(
%arg : memref<1x1x1x1x1xf32>,
%vec : vector<1x1x1xf32>) {
Expand All @@ -82,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>

func.func @transfer_write_and_vector_rank_reducing_to_0d_masked(
%arg : memref<1x1x1x1x1xf32>,
%vec : vector<1x1x1xf32>,
%mask: vector<1x1x1xi1>) {

%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
vector.mask %mask {
vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] :
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
} : vector<1x1x1xi1>
return
}
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
// CHECK-NOT: vector.shape_cast
// CHECK-NOT: memref.subview

func.func @transfer_read_dynamic_rank_reducing(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
%c0 = arith.constant 0 : index
Expand Down

0 comments on commit 0cf7aaf

Please sign in to comment.