[CanonicalizeDoublyStridedOp] Fix for interleaved unit and linear dims (#564)

Fixes an issue exposed by a 128x32x64 matmul:
#556.

In the case of a strided pattern like:

```
offsets: [0, 0, 0, 0] 
sizes: [2, 1, 64, 64]
strides: [4096, 64, 64, 1]
```

the unit dimension (size == 1) in the middle will block the recognition
that this is a linear access pattern, resulting in the following
canonicalized strided pattern:

```
offsets: [0, 0] 
sizes: [2, 4096]
strides: [4096, 1]
```

If the unit dimension is first removed, the strided pattern can be
canonicalized further:

```
offsets: [] 
sizes: []
strides: []
```

meaning a complete linear access.
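
To make the ordering effect concrete, here is a minimal standalone C++ sketch of the two folds. It is an illustration under assumptions, not the pass's code: `Pattern` and these free-function versions of `foldUnitDims` and `foldLinearDims` are hypothetical simplifications (the real helpers take a rewriter and a DMA op), and the merge rule used here (outer offset is 0 and outer stride equals inner size times inner stride) is an assumed reading of "linear".

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one side of a strided DMA access; not an AMDAIE type.
struct Pattern {
  std::vector<int64_t> offsets, sizes, strides;
};

// Drop every dimension with size 1 and offset 0: it never moves the address.
Pattern foldUnitDims(const Pattern &p) {
  assert(p.offsets.size() == p.sizes.size() &&
         p.sizes.size() == p.strides.size());
  Pattern out;
  for (std::size_t i = 0; i < p.sizes.size(); ++i) {
    if (p.sizes[i] == 1 && p.offsets[i] == 0) continue;
    out.offsets.push_back(p.offsets[i]);
    out.sizes.push_back(p.sizes[i]);
    out.strides.push_back(p.strides[i]);
  }
  return out;
}

// Walk from the innermost dimension outwards and merge a dimension into its
// inner neighbour when its offset is 0 and stride == inner size * inner stride.
Pattern foldLinearDims(const Pattern &p) {
  assert(p.offsets.size() == p.sizes.size() &&
         p.sizes.size() == p.strides.size());
  Pattern out;  // built innermost-first, reversed before returning
  for (std::size_t k = p.sizes.size(); k-- > 0;) {
    if (!out.sizes.empty() && p.offsets[k] == 0 &&
        p.strides[k] == out.sizes.back() * out.strides.back()) {
      out.sizes.back() *= p.sizes[k];  // absorb dimension k into the inner one
      continue;
    }
    out.offsets.push_back(p.offsets[k]);
    out.sizes.push_back(p.sizes[k]);
    out.strides.push_back(p.strides[k]);
  }
  std::reverse(out.offsets.begin(), out.offsets.end());
  std::reverse(out.sizes.begin(), out.sizes.end());
  std::reverse(out.strides.begin(), out.strides.end());
  return out;
}
```

Applied to the pattern above, `foldUnitDims(foldLinearDims(p))` gives sizes `[2, 4096]` with strides `[4096, 1]`, because the unit dimension with stride 64 stops the merge. `foldLinearDims(foldUnitDims(p))` gives a single dimension of size `8192` with stride `1`, which the pass's later "make single dimension implicit" step can presumably turn into the empty pattern.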

NOTE: with this fix, the above matmul shape is still not functional, but it
exhibits the same behaviour as `128x32x128` etc.
jtuyls authored Jul 17, 2024
1 parent 3b427f7 commit ba32d19
Showing 2 changed files with 42 additions and 4 deletions.
@@ -121,14 +121,16 @@ void AMDAIECanonicalizeDoublyStridedOpPass::runOnOperation() {
   Operation *parentOp = getOperation();
   IRRewriter rewriter(parentOp->getContext());

-  // Fold linear dimensions within a DMA op.
+  // Fold DMA unit dimensions. Needs to happen before folding linear dimensions
+  // to avoid blocking detection of linear dimension folding opportunities due
+  // to a unit dimension in between.
   parentOp->walk([&](AMDAIE::DoublyStridedOpInterface dmaOp) {
-    (void)foldDmaOpLinearDims(rewriter, dmaOp);
+    (void)foldDmaOpUnitDims(rewriter, dmaOp);
   });

-  // Fold DMA unit dimensions.
+  // Fold linear dimensions within a DMA op.
   parentOp->walk([&](AMDAIE::DoublyStridedOpInterface dmaOp) {
-    (void)foldDmaOpUnitDims(rewriter, dmaOp);
+    (void)foldDmaOpLinearDims(rewriter, dmaOp);
   });

   // Make DMA accesses with single dimension implicit.
@@ -76,6 +76,18 @@ func.func @circular_dma_cpy_nd_unit(%arg0: !amdaie.logicalobjectfifo<memref<1x1x

 // -----

+// CHECK-LABEL: func.func @circular_dma_cpy_nd_unit_between_linear
+// CHECK: amdaie.circular_dma_cpy_nd
+// CHECK-SAME: [] [] []
+// CHECK-SAME: [] [] []
+func.func @circular_dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectfifo<memref<1x1x2x2x4x8xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>) {
+  %0 = amdaie.circular_dma_cpy_nd(%arg0[0, 0, 0, 0, 0, 0] [1, 2, 2, 4, 1, 8] [128, 64, 32, 8, 8, 1], %arg1[0, 0, 0, 0, 0, 0] [2, 2, 1, 4, 8, 1] [64, 32, 32, 8, 1, 1]) : (!amdaie.logicalobjectfifo<memref<1x1x2x2x4x8xi32, 1>>, !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>)
+  amdaie.logicalobjectfifo.consume(%0)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func.func @circular_dma_cpy_nd_non_zero_offset
 // CHECK: amdaie.circular_dma_cpy_nd
 // CHECK-SAME: [1, 1, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1]
@@ -179,6 +191,18 @@ func.func @dma_cpy_nd_unit(%arg0: !amdaie.logicalobjectfifo<memref<1x1x2x2x4x8xi

 // -----

+// CHECK-LABEL: func.func @dma_cpy_nd_unit_between_linear
+// CHECK: amdaie.dma_cpy_nd
+// CHECK-SAME: [] [] []
+// CHECK-SAME: [] [] []
+func.func @dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectfifo<memref<1x1x2x2x4x8xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>) {
+  %0 = amdaie.dma_cpy_nd(%arg0[0, 0, 0, 0, 0, 0] [2, 2, 1, 1, 4, 8] [64, 32, 32, 32, 8, 1], %arg1[0, 0, 0, 0, 0, 0] [2, 1, 2, 1, 4, 8] [64, 64, 32, 32, 8, 1]) : (!amdaie.logicalobjectfifo<memref<1x1x2x2x4x8xi32, 1>>, !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>)
+  amdaie.logicalobjectfifo.consume(%0)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func.func @dma_cpy_nd_non_zero_offset
 // CHECK: amdaie.dma_cpy_nd
 // CHECK-SAME: [1, 1, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1]
@@ -282,6 +306,18 @@ func.func @npu_dma_cpy_nd_unit(%arg0: !amdaie.logicalobjectfifo<memref<1x1x2x2x4

 // -----

+// CHECK-LABEL: func.func @npu_dma_cpy_nd_unit_between_linear
+// CHECK: amdaie.npu.dma_cpy_nd
+// CHECK-SAME: [] [] []
+// CHECK-SAME: [] [] []
+func.func @npu_dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>) {
+  %0 = amdaie.circular_dma_cpy_nd(%arg0[] [] [], %arg1[] [] []) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>)
+  %1 = amdaie.npu.dma_cpy_nd %0([0, 0, 0, 0] [2, 1, 64, 64] [4096, 64, 64, 1], [0, 0, 0, 0] [2, 1, 1, 64] [64, 64, 64, 1])
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func.func @npu_dma_cpy_nd_non_zero_offset
 // CHECK: amdaie.npu.dma_cpy_nd
 // CHECK-SAME: [1, 1, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1]