From b469cd0653ae667da4aa663a56df8eba3768ef29 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Tue, 14 Jan 2025 06:55:16 -0800 Subject: [PATCH 1/8] [SYCL][Matrix] Propagate constexpr matrix layout even with O0 Per SPIR-V specification Layout of a matrix must be a constant instruction aka a constexpr or specialization constant. Meanwhile in SYCL headers layout is passed as a parameter to joint_matrix_load function, so even if that layout is a constant expression in the user's code - it's not possible to prove that to the compiler, so constant propagation will happen only after inlining, not in AST. That means, that with O0 layout would remain to be a runtime variable in LLVM IR. SYCL matrix layout is being mapped on SPIR-V matrix layout by joint_matrix_layout_to_spv function. This patch adds routine that finds calls to this function and replaces them with the found constant. To help this routine always_inline attribute was removed from joint_matrix_layout_to_spv function. Signed-off-by: Sidorov, Dmitry --- .../SYCLLowerIR/SYCLJointMatrixTransform.cpp | 65 +++++++++++++++++- .../JointMatrixTransform/constexpr-layout.ll | 68 +++++++++++++++++++ .../oneapi/matrix/matrix-unified-utils.hpp | 2 +- 3 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index c5c03b2ae1c16..2f0166986c7bb 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -11,8 +11,9 @@ // //===----------------------------------------------------------------------===// -#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h" +#include +#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h" #include "llvm/IR/IRBuilder.h" using namespace llvm; @@ -21,6 +22,7 @@ namespace { static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain"; static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR"; +static constexpr char MATRIX_LAYOUT[] = "_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE"; Type *getInnermostType(Type *Ty) { while (auto *ArrayTy = dyn_cast(Ty)) @@ -184,12 +186,73 @@ bool transformAccessChain(Function *F) { } return ModuleChanged; } + +// Per SPIR-V specification Layout of a matrix must be a constant instruction +// aka a constexpr or specialization constant. Meanwhile in SYCL headers +// layout is passed as a parameter to joint_matrix_load function, so even if +// that layout is a constant expression in the user's code - it's not possible +// to prove that to the compiler, so constant propagation will happen only +// after inlining, not in AST. That means, that with O0 layout would remain +// to be a runtime variable in LLVM IR. +// SYCL matrix layout is being mapped on SPIR-V matrix layout by +// joint_matrix_layout_to_spv function. The following routine finds calls to +// this function and replaces them with the found constant. +// This function also cleans up code, that becomes dead. Pattern of the dead +// code is stable, as user's code doesn't affect it. +bool propagateConstexprLayout(Function *F) { + bool ModuleChanged = false; + std::queue ToErase; + for (auto I = F->user_begin(), E = F->user_end(); I != E;) { + User *U = *I++; + auto *CI = dyn_cast(U); + if (!CI) + continue; + auto *Op = dyn_cast(CI->getArgOperand(0)); + if (!Op || !isa(Op)) + continue; + auto *Ptr = dyn_cast(cast(Op)->getPointerOperand()); + if (!Ptr) + continue; + + ConstantInt *ConstLayout = nullptr; + for (const auto &U : Ptr->users()) { + if (!isa(U)) + continue; + assert(!ConstLayout && "More than 1 layout value was found"); + auto *SI = cast(U); + ConstLayout = dyn_cast(SI->getValueOperand()); + if (ConstLayout) { + CI->replaceAllUsesWith(ConstLayout); + ToErase.push(CI); + ToErase.push(SI); + ModuleChanged = true; + } + } + if (ModuleChanged) { + ToErase.push(Op); + ToErase.push(Ptr); + if (auto *Cast = dyn_cast(Ptr)) { + auto *OrigPtr = Cast->getPointerOperand(); + if (auto *AI = dyn_cast(OrigPtr)) + ToErase.push(AI); + } + } + while (!ToErase.empty()) { + ToErase.front()->dropAllReferences(); + ToErase.front()->eraseFromParent(); + ToErase.pop(); + } + } + return ModuleChanged; +} } // namespace PreservedAnalyses SYCLJointMatrixTransformPass::run(Module &M, ModuleAnalysisManager &MAM) { bool ModuleChanged = false; for (Function &F : M) { + if (F.getName() == MATRIX_LAYOUT) + ModuleChanged |= propagateConstexprLayout(&F); if (!F.isDeclaration()) continue; if (F.getName().starts_with(ACCESS_CHAIN)) diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll new file mode 100644 index 0000000000000..98a37516d656c --- /dev/null +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll @@ -0,0 +1,68 @@ +; The test checks, that unused call to __spirv_AccessChain is eliminated. + +; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s + +; ModuleID = 'test.bc' +source_filename = "test.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +$_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE = comdat any + +; CHECK: define weak_odr dso_local spir_kernel void @test +; CHECK-NEXT: entry: +; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR +; CHECK-NEXT: ret void + +define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix, i64 noundef %stride) { +entry: + %layout = alloca i32, align 4 + %layout.ascast = addrspacecast ptr %layout to ptr addrspace(4) + store i32 0, ptr addrspace(4) %layout.ascast, align 4 + %layout.val = load i32, ptr addrspace(4) %layout.ascast, align 4 + %layout.spv = call spir_func noundef i32 @_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE(i32 noundef %layout.val) + %mload = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix, i32 noundef %layout.spv, i64 noundef %stride, i32 noundef 0) + ret void +} + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef, i32 noundef, i64 noundef, i32 noundef) + +define linkonce_odr dso_local spir_func noundef i32 @_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE(i32 noundef %Layout) comdat { +entry: + %retval = alloca i32, align 4 + %Layout.addr = alloca i32, align 4 + %retval.ascast = addrspacecast ptr %retval to ptr addrspace(4) + %Layout.addr.ascast = addrspacecast ptr %Layout.addr to ptr addrspace(4) + store i32 %Layout, ptr addrspace(4) %Layout.addr.ascast, align 4 + %0 = load i32, ptr addrspace(4) %Layout.addr.ascast, align 4 + switch i32 %0, label %sw.epilog [ + i32 0, label %sw.bb + i32 1, label %sw.bb1 + i32 2, label %sw.bb2 + i32 3, label %sw.bb3 + ] + +sw.bb: ; preds = %entry + store i32 0, ptr addrspace(4) %retval.ascast, align 4 + br label %return + +sw.bb1: ; preds = %entry + store i32 1, ptr addrspace(4) %retval.ascast, align 4 + br label %return + +sw.bb2: ; preds = %entry + store i32 2, ptr addrspace(4) %retval.ascast, align 4 + br label %return + +sw.bb3: ; preds = %entry + store i32 3, ptr addrspace(4) %retval.ascast, align 4 + br label %return + +sw.epilog: ; preds = %entry + call void @llvm.trap() + unreachable + +return: ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb + %1 = load i32, ptr addrspace(4) %retval.ascast, align 4 + ret i32 %1 +} diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index ec14cf6da1931..d2143e6f6fb65 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -68,7 +68,7 @@ convertMatrixUseStringToEnum(const char *UseString) { return std::nullopt; } -inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv( +constexpr inline __spv::MatrixLayout joint_matrix_layout_to_spv( sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { case sycl::ext::oneapi::experimental::matrix::layout::row_major: From 2efb0821b2a42a6f9f4e5207cf9c9666cde3e6da Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Tue, 14 Jan 2025 07:08:21 -0800 Subject: [PATCH 2/8] format and test Signed-off-by: Sidorov, Dmitry --- llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp | 6 ++++-- .../SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 2f0166986c7bb..3cc097a57313b 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -13,8 +13,8 @@ #include -#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h" using namespace llvm; @@ -22,7 +22,9 @@ namespace { static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain"; static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR"; -static constexpr char MATRIX_LAYOUT[] = "_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE"; +static constexpr char MATRIX_LAYOUT[] = + "_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_" + "3ext6oneapi12experimental6matrix6layoutE"; Type *getInnermostType(Type *Ty) { while (auto *ArrayTy = dyn_cast(Ty)) diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll index 98a37516d656c..81ffa57f09710 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll @@ -11,7 +11,7 @@ $_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6m ; CHECK: define weak_odr dso_local spir_kernel void @test ; CHECK-NEXT: entry: -; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR +; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 0, i64 noundef{{.*}} ; CHECK-NEXT: ret void define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix, i64 noundef %stride) { From a391ecb9c74cb3169f869302dfaedd93133d878a Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Wed, 15 Jan 2025 10:42:41 -0800 Subject: [PATCH 3/8] apply comments Signed-off-by: Sidorov, Dmitry --- .../SYCLLowerIR/SYCLJointMatrixTransform.cpp | 44 +++++++++++-------- .../JointMatrixTransform/constexpr-layout.ll | 2 + .../oneapi/matrix/matrix-unified-utils.hpp | 2 +- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 3cc097a57313b..3b4a20a9ec91d 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -11,10 +11,8 @@ // //===----------------------------------------------------------------------===// -#include - -#include "llvm/IR/IRBuilder.h" #include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h" +#include "llvm/IR/IRBuilder.h" using namespace llvm; @@ -203,7 +201,7 @@ bool transformAccessChain(Function *F) { // code is stable, as user's code doesn't affect it. bool propagateConstexprLayout(Function *F) { bool ModuleChanged = false; - std::queue ToErase; + llvm::SmallVector ToErase; for (auto I = F->user_begin(), E = F->user_end(); I != E;) { User *U = *I++; auto *CI = dyn_cast(U); @@ -225,25 +223,25 @@ bool propagateConstexprLayout(Function *F) { ConstLayout = dyn_cast(SI->getValueOperand()); if (ConstLayout) { CI->replaceAllUsesWith(ConstLayout); - ToErase.push(CI); - ToErase.push(SI); + ToErase.push_back(CI); + ToErase.push_back(SI); ModuleChanged = true; } } if (ModuleChanged) { - ToErase.push(Op); - ToErase.push(Ptr); + ToErase.push_back(Op); + ToErase.push_back(Ptr); if (auto *Cast = dyn_cast(Ptr)) { auto *OrigPtr = Cast->getPointerOperand(); - if (auto *AI = dyn_cast(OrigPtr)) - ToErase.push(AI); + if (auto *AI = dyn_cast(OrigPtr)) { + ToErase.push_back(AI); + } } } - while (!ToErase.empty()) { - ToErase.front()->dropAllReferences(); - ToErase.front()->eraseFromParent(); - ToErase.pop(); - } + } + for (Instruction *II : ToErase) { + II->dropAllReferences(); + II->eraseFromParent(); } return ModuleChanged; } @@ -252,14 +250,22 @@ bool propagateConstexprLayout(Function *F) { PreservedAnalyses SYCLJointMatrixTransformPass::run(Module &M, ModuleAnalysisManager &MAM) { bool ModuleChanged = false; + llvm::SmallVector ToErase; for (Function &F : M) { - if (F.getName() == MATRIX_LAYOUT) - ModuleChanged |= propagateConstexprLayout(&F); - if (!F.isDeclaration()) - continue; + if (!F.isDeclaration()) { + if (F.getName() == MATRIX_LAYOUT) { + ModuleChanged |= propagateConstexprLayout(&F); + ToErase.push_back(&F); + } else + continue; + } if (F.getName().starts_with(ACCESS_CHAIN)) ModuleChanged |= transformAccessChain(&F); } + for (auto *F : ToErase) + if (F->users().empty()) + F->eraseFromParent(); + return ModuleChanged ? PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll index 81ffa57f09710..46205cd44d3e1 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll @@ -14,6 +14,8 @@ $_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6m ; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 0, i64 noundef{{.*}} ; CHECK-NEXT: ret void +; CHECK-NOT: _ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE + define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix, i64 noundef %stride) { entry: %layout = alloca i32, align 4 diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index d2143e6f6fb65..af0dff14a0d4c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -68,7 +68,7 @@ convertMatrixUseStringToEnum(const char *UseString) { return std::nullopt; } -constexpr inline __spv::MatrixLayout joint_matrix_layout_to_spv( +constexpr __spv::MatrixLayout joint_matrix_layout_to_spv( sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { case sycl::ext::oneapi::experimental::matrix::layout::row_major: From 445789e9f900401889f7bb604000ed40283f60ec Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Wed, 15 Jan 2025 10:47:12 -0800 Subject: [PATCH 4/8] apply another comment Signed-off-by: Sidorov, Dmitry --- llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 3b4a20a9ec91d..7493bc0e3e5b9 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -200,7 +200,6 @@ bool transformAccessChain(Function *F) { // This function also cleans up code, that becomes dead. Pattern of the dead // code is stable, as user's code doesn't affect it. bool propagateConstexprLayout(Function *F) { - bool ModuleChanged = false; llvm::SmallVector ToErase; for (auto I = F->user_begin(), E = F->user_end(); I != E;) { User *U = *I++; @@ -225,10 +224,9 @@ bool propagateConstexprLayout(Function *F) { CI->replaceAllUsesWith(ConstLayout); ToErase.push_back(CI); ToErase.push_back(SI); - ModuleChanged = true; } } - if (ModuleChanged) { + if (!ToErase.empty()) { ToErase.push_back(Op); ToErase.push_back(Ptr); if (auto *Cast = dyn_cast(Ptr)) { @@ -243,7 +241,7 @@ bool propagateConstexprLayout(Function *F) { II->dropAllReferences(); II->eraseFromParent(); } - return ModuleChanged; + return !ToErase.empty(); } } // namespace From 2bc617ffb3aa1b83f8f4a18111bf4b2cef18b5e4 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Thu, 16 Jan 2025 06:23:14 -0800 Subject: [PATCH 5/8] wip Signed-off-by: Sidorov, Dmitry --- .../SYCLLowerIR/SYCLJointMatrixTransform.cpp | 8 ++--- .../JointMatrixTransform/constexpr-layout.ll | 34 +++++++++++++------ .../oneapi/matrix/matrix-unified-utils.hpp | 2 +- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 7493bc0e3e5b9..76de113104844 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -20,9 +20,7 @@ namespace { static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain"; static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR"; -static constexpr char MATRIX_LAYOUT[] = - "_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_" - "3ext6oneapi12experimental6matrix6layoutE"; +static constexpr char MATRIX_LAYOUT[] = "joint_matrix_layout_to_spv"; Type *getInnermostType(Type *Ty) { while (auto *ArrayTy = dyn_cast(Ty)) @@ -226,7 +224,7 @@ bool propagateConstexprLayout(Function *F) { ToErase.push_back(SI); } } - if (!ToErase.empty()) { + if (ConstLayout) { ToErase.push_back(Op); ToErase.push_back(Ptr); if (auto *Cast = dyn_cast(Ptr)) { @@ -238,6 +236,8 @@ bool propagateConstexprLayout(Function *F) { } } for (Instruction *II : ToErase) { + if (!II) + continue; II->dropAllReferences(); II->eraseFromParent(); } diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll index 46205cd44d3e1..8387f92e1445a 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll @@ -1,4 +1,5 @@ -; The test checks, that unused call to __spirv_AccessChain is eliminated. +; The test checks, that users of the call to joint_matrix_layout_to_spv matrix +; are replaced with the layout constant. ; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s @@ -7,29 +8,40 @@ source_filename = "test.cpp" target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" target triple = "spir64-unknown-unknown" -$_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE = comdat any +$joint_matrix_layout_to_spv = comdat any ; CHECK: define weak_odr dso_local spir_kernel void @test ; CHECK-NEXT: entry: ; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 0, i64 noundef{{.*}} ; CHECK-NEXT: ret void -; CHECK-NOT: _ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE +; CHECK-NOT: joint_matrix_layout_to_spv -define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix, i64 noundef %stride) { +define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix.1, ptr addrspace(1) %matrix.2, i64 noundef %stride) { entry: - %layout = alloca i32, align 4 - %layout.ascast = addrspacecast ptr %layout to ptr addrspace(4) - store i32 0, ptr addrspace(4) %layout.ascast, align 4 - %layout.val = load i32, ptr addrspace(4) %layout.ascast, align 4 - %layout.spv = call spir_func noundef i32 @_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE(i32 noundef %layout.val) - %mload = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix, i32 noundef %layout.spv, i64 noundef %stride, i32 noundef 0) + %layout.1 = alloca i32, align 4 + %layout.2 = alloca i32, align 4 + %layout.ascast.1 = addrspacecast ptr %layout.1 to ptr addrspace(4) + %layout.ascast.2 = addrspacecast ptr %layout.2 to ptr addrspace(4) + store i32 0, ptr addrspace(4) %layout.ascast.1, align 4 + store i32 1, ptr addrspace(4) %layout.ascast.2, align 4 + + %layout.val.1 = load i32, ptr addrspace(4) %layout.ascast.1, align 4 + %layout.spv.1 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.1) + %mload.1 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.1, i32 noundef %layout.spv.1, i64 noundef %stride, i32 noundef 0) + + %layout.val.2 = load i32, ptr addrspace(4) %layout.ascast.2, align 4 + %layout.spv.2 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.2) + %mload.2 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.2, i64 noundef %stride, i32 noundef 0) + + %layout.spv.3 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.2) + %mload.3 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.3, i64 noundef %stride, i32 noundef 0) ret void } declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef, i32 noundef, i64 noundef, i32 noundef) -define linkonce_odr dso_local spir_func noundef i32 @_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE(i32 noundef %Layout) comdat { +define linkonce_odr dso_local spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %Layout) comdat { entry: %retval = alloca i32, align 4 %Layout.addr = alloca i32, align 4 diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index af0dff14a0d4c..9df493bcdc333 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -68,7 +68,7 @@ convertMatrixUseStringToEnum(const char *UseString) { return std::nullopt; } -constexpr __spv::MatrixLayout joint_matrix_layout_to_spv( +extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv( sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { case sycl::ext::oneapi::experimental::matrix::layout::row_major: From d34fe36abe2aa82779fdf21871be4dac1a73b157 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Thu, 16 Jan 2025 08:07:57 -0800 Subject: [PATCH 6/8] fix Signed-off-by: Sidorov, Dmitry --- .../SYCLLowerIR/SYCLJointMatrixTransform.cpp | 46 ++++++++++++------- .../JointMatrixTransform/constexpr-layout.ll | 5 ++ .../oneapi/matrix/matrix-unified-utils.hpp | 2 + 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 76de113104844..657adc3b8a61c 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -13,6 +13,7 @@ #include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h" #include "llvm/IR/IRBuilder.h" +//#include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -185,6 +186,17 @@ bool transformAccessChain(Function *F) { return ModuleChanged; } +StoreInst *findLastStoreBeforeLoad(Value *Ptr, Instruction *Load) { + BasicBlock::iterator It(Load); + while (It != Load->getParent()->begin()) { + --It; + if (auto *Store = dyn_cast(&*It)) + if (Store->getPointerOperand() == Ptr) + return Store; + } + return nullptr; +} + // Per SPIR-V specification Layout of a matrix must be a constant instruction // aka a constexpr or specialization constant. Meanwhile in SYCL headers // layout is passed as a parameter to joint_matrix_load function, so even if @@ -198,7 +210,7 @@ bool transformAccessChain(Function *F) { // This function also cleans up code, that becomes dead. Pattern of the dead // code is stable, as user's code doesn't affect it. bool propagateConstexprLayout(Function *F) { - llvm::SmallVector ToErase; + llvm::SmallVector ToErase; for (auto I = F->user_begin(), E = F->user_end(); I != E;) { User *U = *I++; auto *CI = dyn_cast(U); @@ -212,34 +224,36 @@ bool propagateConstexprLayout(Function *F) { continue; ConstantInt *ConstLayout = nullptr; - for (const auto &U : Ptr->users()) { - if (!isa(U)) - continue; - assert(!ConstLayout && "More than 1 layout value was found"); - auto *SI = cast(U); - ConstLayout = dyn_cast(SI->getValueOperand()); - if (ConstLayout) { - CI->replaceAllUsesWith(ConstLayout); - ToErase.push_back(CI); - ToErase.push_back(SI); - } - } + StoreInst *SI = findLastStoreBeforeLoad(Ptr, Op); + if (!SI) + continue; + ConstLayout = dyn_cast(SI->getValueOperand()); if (ConstLayout) { + CI->replaceAllUsesWith(ConstLayout); + ToErase.push_back(CI); + ToErase.push_back(SI); ToErase.push_back(Op); ToErase.push_back(Ptr); if (auto *Cast = dyn_cast(Ptr)) { auto *OrigPtr = Cast->getPointerOperand(); - if (auto *AI = dyn_cast(OrigPtr)) { + if (auto *AI = dyn_cast(OrigPtr)) ToErase.push_back(AI); - } } } } + + // There are possible cases, when a single instruction result is used multiple + // times. For this case we have to use a vector to store such instructions + // and keep track if we have removed them before to avoid double free(). + SmallPtrSet Erased; for (Instruction *II : ToErase) { - if (!II) + if (!II->use_empty()) + continue; + if (Erased.contains(II)) continue; II->dropAllReferences(); II->eraseFromParent(); + Erased.insert(II); } return !ToErase.empty(); } diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll index 8387f92e1445a..d4308aacbe338 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll @@ -36,6 +36,11 @@ entry: %layout.spv.3 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.2) %mload.3 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.3, i64 noundef %stride, i32 noundef 0) + + store i32 2, ptr addrspace(4) %layout.ascast.2, align 4 + %layout.val.4 = load i32, ptr addrspace(4) %layout.ascast.2, align 4 + %layout.spv.4 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.4) + %mload.4 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.4, i64 noundef %stride, i32 noundef 0) ret void } diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index 9df493bcdc333..f0ddd647c62e9 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -68,6 +68,8 @@ convertMatrixUseStringToEnum(const char *UseString) { return std::nullopt; } +// propagateConstexprLayout uses the exact name of the function, so we use +// extern "C" here. extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv( sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { From 667cf10120023280d4bcdb6fd5baddc14cfd05a2 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Thu, 16 Jan 2025 08:19:28 -0800 Subject: [PATCH 7/8] and extend the test Signed-off-by: Sidorov, Dmitry --- llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll index d4308aacbe338..b2d352a809be9 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll @@ -13,6 +13,9 @@ $joint_matrix_layout_to_spv = comdat any ; CHECK: define weak_odr dso_local spir_kernel void @test ; CHECK-NEXT: entry: ; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 0, i64 noundef{{.*}} +; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 1, i64 noundef{{.*}} +; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 1, i64 noundef{{.*}} +; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 2, i64 noundef{{.*}} ; CHECK-NEXT: ret void ; CHECK-NOT: joint_matrix_layout_to_spv From 7513d30f0ddd3d128aa1fdb518981de90a60cafd Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Thu, 16 Jan 2025 08:23:03 -0800 Subject: [PATCH 8/8] remote include Signed-off-by: Sidorov, Dmitry --- llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 657adc3b8a61c..1a39c994b1ede 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h" + #include "llvm/IR/IRBuilder.h" -//#include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm;