
Commit dd1a80d

Align 'linalg-to-xegpu' pass with patched XeGPU dialect (#201)
This PR updates the linalg-to-xegpu pass to make it compatible with the xegpu-to-vc-func pass from IMEX. It also adds a simple e2e test for the linalg->xegpu->gpu execution pipeline. --------- Signed-off-by: dchigarev <[email protected]>
1 parent a58150d commit dd1a80d

File tree

7 files changed: +161 -23 lines

cmake/imex.cmake (+1 -1)

@@ -8,7 +8,7 @@ if (NOT DEFINED IMEX_INCLUDES)

 # TODO: Change to main https://github.com/intel/mlir-extensions when all the
 # required functionality is merged.
-gc_fetch_content(imex 496b240093b5e132b60c5ee69878300fe69be300 https://github.com/Menooker/mlir-extensions
+gc_fetch_content(imex d5bbd635dee500b8cff138686833bacfac5ade78 https://github.com/Menooker/mlir-extensions
     SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0
 )

include/gc/Transforms/Passes.td (+1 -1)

@@ -47,7 +47,7 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
            "DPAS register block sizes MxNxK">,
   ];
 }
-#endif
+#endif // GC_USE_IMEX

 def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
                                     "func::FuncOp"> {

lib/gc/ExecutionEngine/Driver/CMakeLists.txt (+1 -1)

@@ -27,7 +27,7 @@ else()
 endif()

 set(GC_PASSES GcInterface GcPasses)
-if(GC_UNABLE_GPU)
+if(GC_ENABLE_IMEX)
   list(APPEND GC_PASSES GcGpuPasses)
 endif()

lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp (+29 -5)

@@ -129,11 +129,35 @@ template <typename T> size_t countUntil(T *ptr, T &&elem) {
 } // namespace

 static cl_device_id getDevice(cl_device_type *devtype) {
-  cl_platform_id platform; // OpenCL platform
-  cl_device_id device;     // device ID
-  CL_SAFE_CALL(clGetPlatformIDs(1, &platform, NULL));
-  CL_SAFE_CALL(clGetDeviceIDs(platform, *devtype, 1, &device, NULL));
-  return device;
+  cl_uint numPlatforms;
+  CL_SAFE_CALL(clGetPlatformIDs(0, nullptr, &numPlatforms)) // get num platforms
+
+  std::vector<cl_platform_id> platforms(numPlatforms);
+  CL_SAFE_CALL(clGetPlatformIDs(numPlatforms, platforms.data(),
+                                nullptr)); // get available platforms
+
+  for (cl_uint i = 0; i < numPlatforms; ++i) {
+    // Get GPU device IDs for each platform
+    cl_uint numDevices;
+    cl_int status =
+        clGetDeviceIDs(platforms[i], *devtype, 0, /*devices.data()=*/nullptr,
+                       &numDevices); // get num devices with 'devtype'
+    if (status != CL_SUCCESS) {
+      if (status == CL_DEVICE_NOT_FOUND) {
+        continue; // No GPU devices found on this platform
+      }
+      fprintf(stderr, "CL error %d @ line=%d (%s)\n", status, __LINE__,
+              "Error getting device IDs");
+      abort();
+    }
+
+    std::vector<cl_device_id> devices(numDevices);
+    clGetDeviceIDs(platforms[i], *devtype, numDevices, devices.data(), nullptr);
+    return devices[0];
+  }
+
+  fprintf(stderr, "No suitable devices found.");
+  abort();
 }

 struct GPUCLQUEUE {
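
The new lookup no longer assumes the first platform hosts a GPU: platforms that report CL_DEVICE_NOT_FOUND are skipped, and the first device of the requested type on any platform is returned. Note that CL_SAFE_CALL is used above but defined outside this hunk; a minimal sketch of such an error-checking wrapper (an assumption for illustration only, the actual macro in OpenCLRuntimeWrappers.cpp may differ) could look like:

#include <CL/cl.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical sketch, not taken from the repo. A block-style macro is
// consistent with the invocation above that omits the trailing semicolon.
#define CL_SAFE_CALL(call)                                                     \
  {                                                                            \
    cl_int err_ = (call);                                                      \
    if (err_ != CL_SUCCESS) {                                                  \
      fprintf(stderr, "CL error %d @ line=%d (%s)\n", err_, __LINE__, #call);  \
      abort();                                                                 \
    }                                                                          \
  }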

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp (+28 -11)

@@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
                                              Location loc, ValueRange tiles,
                                              ArrayRef<int64_t> offsets) {
   SmallVector<Value> updatedTiles;
+  // convert static offsets to dynamic because of this IMEX bug:
+  // https://github.com/intel/mlir-extensions/issues/815
+  std::vector<Value> dynOffsets;
+  for (auto &x : offsets) {
+    Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x);
+    dynOffsets.push_back(offset);
+  }
+  ValueRange newOffsets{dynOffsets};
   for (auto tile : tiles) {
-    auto updatedTile =
-        rewriter
-            .create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile,
-                                             /*offsets=*/ValueRange{}, offsets)
-            .getResult();
+    auto updatedTile = rewriter
+                           .create<xegpu::UpdateNdOffsetOp>(
+                               loc, tile.getType(), tile,
+                               /*offsets=*/newOffsets,
+                               SmallVector<int64_t>{ShapedType::kDynamic,
+                                                    ShapedType::kDynamic})
+                           .getResult();
     updatedTiles.push_back(updatedTile);
   }

@@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,

   SmallVector<Value> tiles;
   for (int i = 0; i < loadShape[0]; i += descTile[0]) {
+    // convert static offsets to dynamic because of this IMEX bug:
+    // https://github.com/intel/mlir-extensions/issues/815
+    Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
+      Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
-                          /*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j})
+                          /*offsets=*/ValueRange{newRowOffs, newColOffs},
+                          SmallVector<int64_t>{ShapedType::kDynamic,
+                                               ShapedType::kDynamic})
                       .getResult();
       tiles.push_back(tile);
     }

@@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,

   VectorType vecLoadType =
       VectorType::get(tileType.getShape(), tileType.getElementType());
-  UnitAttr vnniAxisAttr = nullptr;
+  mlir::UnitAttr packedAttr = nullptr;
   if (vnniConf) {
-    vnniAxisAttr = UnitAttr::get(rewriter.getContext());
     vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
                                 *vnniConf);
+    packedAttr = mlir::UnitAttr::get(rewriter.getContext());
   }
-
+  IntegerAttr transpose_bit = nullptr;
   SmallVector<Value> loadVec;
   for (auto tile : loadTiles) {
+
     auto loadOp = rewriter.create<xegpu::LoadNdOp>(
-        loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr,
+        loc, vecLoadType, tile, packedAttr, transpose, transpose_bit,
         /*l1_hint=*/hint,
         /*l2_hint=*/hint, /*l3_hint=*/hint);
     loadVec.push_back(loadOp);

@@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

   // Load A sub-tiles.
   SmallVector<Value> loadVecA =
-      loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
+      loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
   auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

   // Load B sub-tiles.
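
The first two hunks above repeat the same static-to-dynamic offset conversion to work around https://github.com/intel/mlir-extensions/issues/815. A possible follow-up refactor (a sketch only, not part of this commit; the helper name is hypothetical) would factor that pattern into one utility:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;

// Materialize static offsets as index constants so that
// xegpu::UpdateNdOffsetOp receives dynamic offsets.
static SmallVector<Value> convertToDynamicOffsets(PatternRewriter &rewriter,
                                                  Location loc,
                                                  ArrayRef<int64_t> offsets) {
  SmallVector<Value> dynOffsets;
  for (int64_t offset : offsets)
    dynOffsets.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
  return dynOffsets;
}

Once the IMEX issue is fixed, callers could switch back to passing the static offsets directly and drop the workaround.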

test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir (+4 -4)

@@ -18,7 +18,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

 // Create output initial value load tiles.
 // CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
-// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0]
+// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
 // CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]]

 // Load initial accumulator values.

@@ -31,9 +31,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

 // Create input load tiles.
 // CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
-// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0]
+// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
 // CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
-// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0]
+// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
 // CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]]

 // Create DPAS computation loop over tiled reduction dimension.

@@ -63,7 +63,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

 // Extract DPAS-sized chunks from larger loaded tile A.
 // Tile B is already in the correct shape.
-// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16>
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
 // CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
 // CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
 // CHECK-COUNT-3: vector.extract_strided_slice
