Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Matrix] Propagate constexpr matrix layout even with O0 #16628

Merged
merged 8 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,20 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h"
#include <queue>
MrSidims marked this conversation as resolved.
Show resolved Hide resolved

#include "llvm/IR/IRBuilder.h"
#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h"
MrSidims marked this conversation as resolved.
Show resolved Hide resolved

using namespace llvm;

namespace {

static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";
static constexpr char MATRIX_LAYOUT[] =
"_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_"
"3ext6oneapi12experimental6matrix6layoutE";
YuriPlyakhin marked this conversation as resolved.
Show resolved Hide resolved

Type *getInnermostType(Type *Ty) {
while (auto *ArrayTy = dyn_cast<ArrayType>(Ty))
Expand Down Expand Up @@ -184,12 +188,73 @@ bool transformAccessChain(Function *F) {
}
return ModuleChanged;
}

// Per SPIR-V specification Layout of a matrix must be a constant instruction
// aka a constexpr or specialization constant. Meanwhile in SYCL headers
// layout is passed as a parameter to joint_matrix_load function, so even if
// that layout is a constant expression in the user's code - it's not possible
// to prove that to the compiler, so constant propagation will happen only
// after inlining, not in AST. That means, that with O0 layout would remain
// to be a runtime variable in LLVM IR.
// SYCL matrix layout is being mapped on SPIR-V matrix layout by
// joint_matrix_layout_to_spv function. The following routine finds calls to
// this function and replaces them with the found constant.
// This function also cleans up code, that becomes dead. Pattern of the dead
// code is stable, as user's code doesn't affect it.
bool propagateConstexprLayout(Function *F) {
MrSidims marked this conversation as resolved.
Show resolved Hide resolved
bool ModuleChanged = false;
std::queue<Instruction *> ToErase;
for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
User *U = *I++;
auto *CI = dyn_cast<CallInst>(U);
if (!CI)
continue;
auto *Op = dyn_cast<Instruction>(CI->getArgOperand(0));
if (!Op || !isa<LoadInst>(Op))
continue;
auto *Ptr = dyn_cast<Instruction>(cast<LoadInst>(Op)->getPointerOperand());
if (!Ptr)
continue;

ConstantInt *ConstLayout = nullptr;
for (const auto &U : Ptr->users()) {
if (!isa<StoreInst>(U))
continue;
assert(!ConstLayout && "More than 1 layout value was found");
auto *SI = cast<StoreInst>(U);
ConstLayout = dyn_cast<ConstantInt>(SI->getValueOperand());
if (ConstLayout) {
CI->replaceAllUsesWith(ConstLayout);
ToErase.push(CI);
ToErase.push(SI);
ModuleChanged = true;
MrSidims marked this conversation as resolved.
Show resolved Hide resolved
}
}
if (ModuleChanged) {
ToErase.push(Op);
ToErase.push(Ptr);
if (auto *Cast = dyn_cast<AddrSpaceCastInst>(Ptr)) {
auto *OrigPtr = Cast->getPointerOperand();
if (auto *AI = dyn_cast<AllocaInst>(OrigPtr))
ToErase.push(AI);
}
}
while (!ToErase.empty()) {
ToErase.front()->dropAllReferences();
ToErase.front()->eraseFromParent();
ToErase.pop();
}
}
return ModuleChanged;
}
} // namespace

PreservedAnalyses
SYCLJointMatrixTransformPass::run(Module &M, ModuleAnalysisManager &MAM) {
bool ModuleChanged = false;
for (Function &F : M) {
if (F.getName() == MATRIX_LAYOUT)
MrSidims marked this conversation as resolved.
Show resolved Hide resolved
ModuleChanged |= propagateConstexprLayout(&F);
if (!F.isDeclaration())
continue;
if (F.getName().starts_with(ACCESS_CHAIN))
Expand Down
68 changes: 68 additions & 0 deletions llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
; The test checks, that unused call to __spirv_AccessChain is eliminated.
MrSidims marked this conversation as resolved.
Show resolved Hide resolved

; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s

; ModuleID = 'test.bc'
source_filename = "test.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
target triple = "spir64-unknown-unknown"

$_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE = comdat any
MrSidims marked this conversation as resolved.
Show resolved Hide resolved

; CHECK: define weak_odr dso_local spir_kernel void @test
; CHECK-NEXT: entry:
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 0, i64 noundef{{.*}}
; CHECK-NEXT: ret void

define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix, i64 noundef %stride) {
entry:
%layout = alloca i32, align 4
%layout.ascast = addrspacecast ptr %layout to ptr addrspace(4)
store i32 0, ptr addrspace(4) %layout.ascast, align 4
%layout.val = load i32, ptr addrspace(4) %layout.ascast, align 4
%layout.spv = call spir_func noundef i32 @_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE(i32 noundef %layout.val)
%mload = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix, i32 noundef %layout.spv, i64 noundef %stride, i32 noundef 0)
ret void
}

declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef, i32 noundef, i64 noundef, i32 noundef)

define linkonce_odr dso_local spir_func noundef i32 @_ZN4sycl3_V16detail26joint_matrix_layout_to_spvENS0_3ext6oneapi12experimental6matrix6layoutE(i32 noundef %Layout) comdat {
entry:
%retval = alloca i32, align 4
%Layout.addr = alloca i32, align 4
%retval.ascast = addrspacecast ptr %retval to ptr addrspace(4)
%Layout.addr.ascast = addrspacecast ptr %Layout.addr to ptr addrspace(4)
store i32 %Layout, ptr addrspace(4) %Layout.addr.ascast, align 4
%0 = load i32, ptr addrspace(4) %Layout.addr.ascast, align 4
switch i32 %0, label %sw.epilog [
i32 0, label %sw.bb
i32 1, label %sw.bb1
i32 2, label %sw.bb2
i32 3, label %sw.bb3
]

sw.bb: ; preds = %entry
store i32 0, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.bb1: ; preds = %entry
store i32 1, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.bb2: ; preds = %entry
store i32 2, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.bb3: ; preds = %entry
store i32 3, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.epilog: ; preds = %entry
call void @llvm.trap()
unreachable

return: ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb
%1 = load i32, ptr addrspace(4) %retval.ascast, align 4
ret i32 %1
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ convertMatrixUseStringToEnum(const char *UseString) {
return std::nullopt;
}

inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(
constexpr inline __spv::MatrixLayout joint_matrix_layout_to_spv(
MrSidims marked this conversation as resolved.
Show resolved Hide resolved
sycl::ext::oneapi::experimental::matrix::layout Layout) {
Copy link
Contributor

@dkhaldi dkhaldi Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it was not possible to capture and preserve the layout as constant at this stage?
That's why you had to create this preservation as a function in SYCLJointMatrixTransform

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compiler considers it to be a runtime value as joint_matrix_load/store functions (that calls joint_matrix_layout_to_spv) can not be in constexpr context, as they call external functions (__spirv_CooperativeMatrixLoad/StoreKHR). I'm leaving constexpr keyword here as it's really our intention, yet not enforced by the compiler.

switch (Layout) {
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
Expand Down
Loading