Skip to content

Commit

Permalink
Enable support for broadcasting to a 2D array of aie cores (Xilinx#633)
Browse files Browse the repository at this point in the history
* Enable support for broadcasting to a 2D array of aie cores

* Remove assertion to allow 2d broadcasting to lower to mlir-aie dialect

* Switch the convolution board test to show 2d herd scaling
  • Loading branch information
erwei-xilinx authored Jul 1, 2024
1 parent f198477 commit 07c69a4
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 60 deletions.
1 change: 0 additions & 1 deletion mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,6 @@ void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelGetOp memcpyOp) {
indices_uint[dim] = 0;
}
}
assert(indices_uint[0] != 1 || indices_uint[1] != 1);
if (areIdenticalVectors(indices_uint, position)) {
alloc_id = iter;
}
Expand Down
14 changes: 0 additions & 14 deletions mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1118,20 +1118,6 @@ LogicalResult ChannelOp::verify() {
auto broadcast_shape = getBroadcastShape();
if (bundle_size.size() != broadcast_shape.size())
return emitOpError("bundle size should match broadcast_shape size");
int diffDims = 0;
int broadcastDim = -1;
for (int i = 0; i < (int)bundle_size.size(); i++)
if (dyn_cast<IntegerAttr>(bundle_size[i]).getInt() !=
dyn_cast<IntegerAttr>(broadcast_shape[i]).getInt()) {
diffDims++;
broadcastDim = i;
}
if (diffDims > 1)
return emitOpError("bundle sizes and broadcast_shape should only differ "
"along one dimension");
if (dyn_cast<IntegerAttr>(bundle_size[broadcastDim]).getInt() != 1)
return emitOpError("along the broadcast dimension the index in the "
"channel bundle sizes should be equal to 1");
}
return success();
}
Expand Down
18 changes: 16 additions & 2 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2326,8 +2326,8 @@ struct BroadcastDetection {
mlir::IntegerSetAttr::get(int_set));
}
} else if (!isVariantWrtHerdRows && !isVariantWrtHerdCols) {
// If a dma op is independent of herd induction vars, then we only
// broadcast it along either rows or columns, not both.
// If a dma op is independent of herd induction vars, then we broadcast
// it to every core in the herd.
if (numRows > 1 && numCols == 1) {
SmallVector<AffineExpr, 5> constraints{
getAffineDimExpr(0, ctx), numRows - 1 - getAffineDimExpr(0, ctx),
Expand All @@ -2348,6 +2348,20 @@ struct BroadcastDetection {
auto int_set = IntegerSet::get(2, 1, constraints, eqflags);
dma_op->setAttr("broadcast_pattern",
mlir::IntegerSetAttr::get(int_set));
} else {
// Broadcast to a 2d array of cores
SmallVector<AffineExpr, 6> constraints{
getAffineDimExpr(0, ctx),
numRows - 1 - getAffineDimExpr(0, ctx),
getAffineDimExpr(1, ctx),
numCols - 1 - getAffineDimExpr(1, ctx),
getAffineSymbolExpr(0, ctx),
-getAffineSymbolExpr(0, ctx)};
SmallVector<bool, 5> eqflags{false, false, false,
false, false, false};
auto int_set = IntegerSet::get(2, 1, constraints, eqflags);
dma_op->setAttr("broadcast_pattern",
mlir::IntegerSetAttr::get(int_set));
}
}
}
Expand Down
84 changes: 83 additions & 1 deletion mlir/test/Conversion/ConvertToAIR/broadcast_to_channel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
//
//===----------------------------------------------------------------------===//

// RUN: air-opt %s -air-dma-to-channel -canonicalize -cse | FileCheck %s
// RUN: air-opt %s -air-dma-to-channel -canonicalize -cse --split-input-file | FileCheck %s

#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 32)>
#set = affine_set<(d0, d1)[s0] : (d0 - s0 == 0, d1 >= 0, -d1 + 1 >= 0, s0 >= 0, -s0 + 1 >= 0)>
module {
// CHECK: air.channel @channel_0 [2, 1] {broadcast_shape = [2, 2]}
// CHECK-LABEL: @mmult
func.func @mmult(%arg0: memref<512x512xbf16>) {
%c8 = arith.constant 8 : index
%0 = air.launch async (%arg1, %arg2) in (%arg3=%c8, %arg4=%c8) args(%arg5=%arg0) : memref<512x512xbf16> attributes {id = 3 : i32} {
Expand Down Expand Up @@ -76,3 +77,84 @@ module {
}
}

// -----

// Broadcast to a 2D array of cores.

// CHECK-LABEL: func.func @conv
// CHECK: air.launch
// CHECK: air.segment
// CHECK-NOT: scf.parallel
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: air.channel.put{{.*}}@channel_0[]
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: air.herd
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: affine.if
// CHECK-NEXT: air.channel.get{{.*}}@channel_0
// CHECK-NEXT: affine.yield
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: scf.yield

#set = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 >= 0, -s1 + 3 >= 0)>
module {
func.func @conv() {
%c3 = arith.constant 3 : index
%c16 = arith.constant 16 : index
%0 = air.launch async (%arg0, %arg1, %arg2) in (%arg3=%c3, %arg4=%c3, %arg5=%c16) attributes {id = 3 : i32} {
%1 = air.segment @segment_0 async attributes {id = 2 : i32} {
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%async_token, %results = air.execute -> (memref<3x3x32x4xi32, 1 : i32>) {
%alloc = memref.alloc() : memref<3x3x32x4xi32, 1 : i32>
air.execute_terminator %alloc : memref<3x3x32x4xi32, 1 : i32>
} {id = 1 : i32}
%2 = air.herd @herd_0 async [%async_token] tile (%arg6, %arg7) in (%arg8=%c2, %arg9=%c4) args(%arg10=%results) : memref<3x3x32x4xi32, 1 : i32> attributes {id = 1 : i32} {
%c128 = arith.constant 128 : index
%c384 = arith.constant 384 : index
%c4_1 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c3_2 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%3 = air.wait_all async {id = 6 : i32}
%4 = scf.for %arg11 = %c0 to %c3_2 step %c1 iter_args(%arg12 = %3) -> (!air.async.token) {
%5 = scf.for %arg13 = %c0 to %c3_2 step %c1 iter_args(%arg14 = %arg12) -> (!air.async.token) {
%7 = scf.for %arg15 = %c0 to %c32 step %c8 iter_args(%arg16 = %arg14) -> (!air.async.token) {
%async_token_3, %results_4 = air.execute -> (memref<1x1x8x4xi32, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x8x4xi32, 2 : i32>
air.execute_terminator %alloc : memref<1x1x8x4xi32, 2 : i32>
} {id = 2 : i32}
%9 = affine.if #set()[%arg6, %arg7] -> !air.async.token {
%11 = air.dma_memcpy_nd async [%arg16, %async_token_3] (%results_4[] [] [], %arg10[%arg11, %arg13, %arg15, %c0] [%c1, %c1, %c8, %c4_1] [%c384, %c128, %c4_1, %c1]) {broadcast_set = #set, id = 1 : i32} : (memref<1x1x8x4xi32, 2 : i32>, memref<3x3x32x4xi32, 1 : i32>)
affine.yield %11 : !air.async.token
} else {
%11 = air.wait_all async [%arg16, %async_token_3]
affine.yield %11 : !air.async.token
}
%10 = air.wait_all async [%arg16, %9] {id = 1 : i32}
scf.yield %10 : !air.async.token
}
%8 = air.wait_all async [%arg14, %7] {id = 3 : i32}
scf.yield %8 : !air.async.token
}
%6 = air.wait_all async [%arg12, %5] {id = 5 : i32}
scf.yield %6 : !air.async.token
}
}
%async_token_0 = air.execute [%2] {
memref.dealloc %results : memref<3x3x32x4xi32, 1 : i32>
} {id = 3 : i32}
}
}
return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,45 @@ module {
return
}
}

// -----

// 2D broadcasting to all cores in a herd.

// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 1 >= 0, d1 >= 0, -d1 + 3 >= 0, s0 >= 0, -s0 >= 0)>
// CHECK-LABEL: func.func @func3
// CHECK: %[[EVENT0:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET0]]{{.*}} : (memref<1x1x8x4xi32, 2 : i32>, memref<3x3x32x4xi32, 1 : i32>)

module {
func.func @func3() {
%c3 = arith.constant 3 : index
%c16 = arith.constant 16 : index
air.launch (%arg3, %arg4, %arg5) in (%arg6=%c3, %arg7=%c3, %arg8=%c16) {
air.segment @segment_0 {
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%alloc = memref.alloc() : memref<3x3x32x4xi32, 1 : i32>
air.herd @herd_0 tile (%arg9, %arg10) in (%arg11=%c2, %arg12=%c4) args(%arg13=%alloc) : memref<3x3x32x4xi32, 1 : i32> {
%c128 = arith.constant 128 : index
%c384 = arith.constant 384 : index
%c4_0 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c3_1 = arith.constant 3 : index
%c1 = arith.constant 1 : index
scf.for %arg14 = %c0 to %c3_1 step %c1 {
scf.for %arg15 = %c0 to %c3_1 step %c1 {
scf.for %arg16 = %c0 to %c32 step %c8 {
%alloc_2 = memref.alloc() : memref<1x1x8x4xi32, 2 : i32>
air.dma_memcpy_nd (%alloc_2[] [] [], %arg13[%arg14, %arg15, %arg16, %c0] [%c1, %c1, %c8, %c4_0] [%c384, %c128, %c4_0, %c1]) : (memref<1x1x8x4xi32, 2 : i32>, memref<3x3x32x4xi32, 1 : i32>)
}
}
}
}
memref.dealloc %alloc : memref<3x3x32x4xi32, 1 : i32>
}
}
return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,63 @@ module {
return
}
}

// -----

// DMA broadcast to a 2D array of cores.

// CHECK: [[$SET0:#set[0-9]*]] = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 >= 0, -s1 + 3 >= 0)>
// CHECK-LABEL: @func4
// CHECK: air.herd
// CHECK: %[[EVENT0:.*]] = affine.if [[$SET0]]
// CHECK: %[[EVENT1:.*]] = air.dma_memcpy_nd {{.*}}broadcast_set = [[$SET0]]{{.*}}

#set = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 1 >= 0, d1 >= 0, -d1 + 3 >= 0, s0 >= 0, -s0 >= 0)>
module {
func.func @func4() {
%c3 = arith.constant 3 : index
%c16 = arith.constant 16 : index
%0 = air.launch async (%arg0, %arg1, %arg2) in (%arg3=%c3, %arg4=%c3, %arg5=%c16) attributes {id = 3 : i32} {
%1 = air.segment @segment_0 async attributes {id = 2 : i32} {
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%async_token, %results = air.execute -> (memref<3x3x32x4xi32, 1 : i32>) {
%alloc = memref.alloc() : memref<3x3x32x4xi32, 1 : i32>
air.execute_terminator %alloc : memref<3x3x32x4xi32, 1 : i32>
} {id = 1 : i32}
%2 = air.herd @herd_0 async [%async_token] tile (%arg6, %arg7) in (%arg8=%c2, %arg9=%c4) args(%arg10=%results) : memref<3x3x32x4xi32, 1 : i32> attributes {id = 1 : i32} {
%c128 = arith.constant 128 : index
%c384 = arith.constant 384 : index
%c4_1 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c3_2 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%3 = air.wait_all async {id = 6 : i32}
%4 = scf.for %arg11 = %c0 to %c3_2 step %c1 iter_args(%arg12 = %3) -> (!air.async.token) {
%5 = scf.for %arg13 = %c0 to %c3_2 step %c1 iter_args(%arg14 = %arg12) -> (!air.async.token) {
%7 = scf.for %arg15 = %c0 to %c32 step %c8 iter_args(%arg16 = %arg14) -> (!air.async.token) {
%async_token_3, %results_4 = air.execute -> (memref<1x1x8x4xi32, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x8x4xi32, 2 : i32>
air.execute_terminator %alloc : memref<1x1x8x4xi32, 2 : i32>
} {id = 2 : i32}
%9 = air.dma_memcpy_nd async [%arg16, %async_token_3] (%results_4[] [] [], %arg10[%arg11, %arg13, %arg15, %c0] [%c1, %c1, %c8, %c4_1] [%c384, %c128, %c4_1, %c1]) {broadcast_pattern = #set, id = 1 : i32} : (memref<1x1x8x4xi32, 2 : i32>, memref<3x3x32x4xi32, 1 : i32>)
%10 = air.wait_all async [%arg16, %9] {id = 1 : i32}
scf.yield %10 : !air.async.token
}
%8 = air.wait_all async [%arg14, %7] {id = 3 : i32}
scf.yield %8 : !air.async.token
}
%6 = air.wait_all async [%arg12, %5] {id = 5 : i32}
scf.yield %6 : !air.async.token
}
}
%async_token_0 = air.execute [%2] {
memref.dealloc %results : memref<3x3x32x4xi32, 1 : i32>
} {id = 3 : i32}
}
}
return
}
}
Loading

0 comments on commit 07c69a4

Please sign in to comment.