[SYCL][Matrix] Propagate constexpr matrix layout even with O0 #16628

Merged
merged 8 commits into from
Jan 20, 2025
87 changes: 85 additions & 2 deletions llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp
@@ -21,6 +21,7 @@ namespace {

static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";
static constexpr char MATRIX_LAYOUT[] = "joint_matrix_layout_to_spv";

Type *getInnermostType(Type *Ty) {
  while (auto *ArrayTy = dyn_cast<ArrayType>(Ty))
@@ -184,17 +185,99 @@ bool transformAccessChain(Function *F) {
  }
  return ModuleChanged;
}

// Returns the closest store to Ptr that precedes Load within Load's basic
// block, or nullptr if there is no such store.
StoreInst *findLastStoreBeforeLoad(Value *Ptr, Instruction *Load) {
  BasicBlock::iterator It(Load);
  while (It != Load->getParent()->begin()) {
    --It;
    if (auto *Store = dyn_cast<StoreInst>(&*It))
      if (Store->getPointerOperand() == Ptr)
        return Store;
  }
  return nullptr;
}
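
The backward scan can be illustrated outside LLVM. The sketch below is a simplified model, not the pass itself: `Inst`, `Kind`, `Slot`, and `lastStoreBefore` are hypothetical names introduced for illustration, and a basic block is modeled as a plain vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for the instructions of one basic block.
enum class Kind { Store, Load, Other };

struct Inst {
  Kind K;
  int Slot;  // identifies the memory location the instruction accesses
  int Value; // value written by a Store; unused otherwise
};

// Scans backwards from the load at LoadIdx and returns the closest preceding
// store to the same slot, or nullptr if the block contains none.
const Inst *lastStoreBefore(const std::vector<Inst> &BB, std::size_t LoadIdx) {
  assert(LoadIdx < BB.size() && "load index out of range");
  for (std::size_t I = LoadIdx; I-- > 0;)
    if (BB[I].K == Kind::Store && BB[I].Slot == BB[LoadIdx].Slot)
      return &BB[I];
  return nullptr;
}
```

As in the function above, the scan never leaves the block, so a store in a predecessor block is not found and the caller has to give up in that case.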

// Per the SPIR-V specification, the layout of a matrix must be the result of
// a constant instruction, i.e. a constexpr or a specialization constant.
// In the SYCL headers, however, the layout is passed as a parameter to the
// joint_matrix_load function, so even if the layout is a constant expression
// in the user's code, the compiler cannot prove that: constant propagation
// happens only after inlining, not in the AST. As a result, at O0 the layout
// remains a runtime variable in LLVM IR.
// The SYCL matrix layout is mapped onto the SPIR-V matrix layout by the
// joint_matrix_layout_to_spv function. The following routine finds calls to
// this function and replaces them with the constant it finds.
// It also cleans up the code that becomes dead. The pattern of the dead
// code is stable, as the user's code doesn't affect it.
bool propagateConstexprLayout(Function *F) {
  llvm::SmallVector<Instruction *, 8> ToErase;
  for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
    User *U = *I++;
    auto *CI = dyn_cast<CallInst>(U);
    if (!CI)
      continue;
    auto *Op = dyn_cast<Instruction>(CI->getArgOperand(0));
    if (!Op || !isa<LoadInst>(Op))
      continue;
    auto *Ptr = dyn_cast<Instruction>(cast<LoadInst>(Op)->getPointerOperand());
    if (!Ptr)
      continue;

    ConstantInt *ConstLayout = nullptr;
    StoreInst *SI = findLastStoreBeforeLoad(Ptr, Op);
    if (!SI)
      continue;
    ConstLayout = dyn_cast<ConstantInt>(SI->getValueOperand());
    if (ConstLayout) {
      CI->replaceAllUsesWith(ConstLayout);
      ToErase.push_back(CI);
      ToErase.push_back(SI);
      ToErase.push_back(Op);
      ToErase.push_back(Ptr);
      if (auto *Cast = dyn_cast<AddrSpaceCastInst>(Ptr)) {
        auto *OrigPtr = Cast->getPointerOperand();
        if (auto *AI = dyn_cast<AllocaInst>(OrigPtr))
          ToErase.push_back(AI);
      }
    }
  }

  // A single instruction can be pushed to ToErase more than once, e.g. when
  // its result is used by several calls. Track the instructions that were
  // already erased in a set, and consult it before touching the pointer, to
  // avoid a use-after-free or a double free.
  SmallPtrSet<Instruction *, 8> Erased;
  for (Instruction *II : ToErase) {
    if (Erased.contains(II))
      continue;
    if (!II->use_empty())
      continue;
    II->dropAllReferences();
    II->eraseFromParent();
    Erased.insert(II);
  }
  return !ToErase.empty();
}
} // namespace
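
One subtlety of such a cleanup loop is that a pointer listed twice must not be dereferenced after its first erasure. A minimal sketch of an ordering that avoids this, assuming a toy `Node` type in place of `llvm::Instruction` (`Node`, `Uses`, and `eraseDead` are hypothetical names, not LLVM API):

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an instruction; Uses counts remaining users.
struct Node {
  int Uses;
};

// Deletes every node in Worklist that has no remaining uses, exactly once,
// even if the same pointer is listed several times. Returns the number of
// nodes deleted.
std::size_t eraseDead(const std::vector<Node *> &Worklist) {
  std::unordered_set<Node *> Erased;
  for (Node *N : Worklist) {
    assert(N && "worklist must not contain null entries");
    // Consult the set first: N may already be deleted, so it must not be
    // dereferenced before we know it is still alive.
    if (Erased.count(N))
      continue;
    if (N->Uses != 0)
      continue;
    delete N;
    Erased.insert(N);
  }
  return Erased.size();
}
```

Only the pointer value is compared after deletion; the deleted object itself is never accessed again.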

PreservedAnalyses
SYCLJointMatrixTransformPass::run(Module &M, ModuleAnalysisManager &MAM) {
  bool ModuleChanged = false;
  llvm::SmallVector<Function *, 1> ToErase;
  for (Function &F : M) {
-    if (!F.isDeclaration())
-      continue;
    if (!F.isDeclaration()) {
      if (F.getName() == MATRIX_LAYOUT) {
        ModuleChanged |= propagateConstexprLayout(&F);
        ToErase.push_back(&F);
      } else
        continue;
    }
    if (F.getName().starts_with(ACCESS_CHAIN))
      ModuleChanged |= transformAccessChain(&F);
  }

  for (auto *F : ToErase)
    if (F->users().empty())
      F->eraseFromParent();

  return ModuleChanged ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
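
The per-function dispatch in `run()` after this change can be summarized as a small decision function. This is only a sketch of the control flow: `Action` and `classify` are hypothetical names, and the real pass operates on `llvm::Function` objects rather than on strings.

```cpp
#include <string>

// Hypothetical summary of what run() does with each function in the module.
enum class Action { Skip, PropagateLayout, TransformAccessChain };

Action classify(const std::string &Name, bool IsDeclaration) {
  static const std::string AccessChain = "_Z19__spirv_AccessChain";
  static const std::string MatrixLayout = "joint_matrix_layout_to_spv";
  if (!IsDeclaration) {
    // Among function definitions, only the layout helper is of interest.
    return Name == MatrixLayout ? Action::PropagateLayout : Action::Skip;
  }
  // Declarations are handled only if they are access-chain builtins.
  if (Name.compare(0, AccessChain.size(), AccessChain) == 0)
    return Action::TransformAccessChain;
  return Action::Skip;
}
```

The layout helper is matched by its exact name, while the access-chain builtins are matched by prefix because their mangled names carry type suffixes.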
90 changes: 90 additions & 0 deletions llvm/test/SYCLLowerIR/JointMatrixTransform/constexpr-layout.ll
@@ -0,0 +1,90 @@
; The test checks that uses of the result of a joint_matrix_layout_to_spv call
; are replaced with the layout constant.

; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s

; ModuleID = 'test.bc'
source_filename = "test.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
target triple = "spir64-unknown-unknown"

$joint_matrix_layout_to_spv = comdat any

; CHECK: define weak_odr dso_local spir_kernel void @test
; CHECK-NEXT: entry:
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 0, i64 noundef{{.*}}
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 1, i64 noundef{{.*}}
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 1, i64 noundef{{.*}}
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 2, i64 noundef{{.*}}
; CHECK-NEXT: ret void

; CHECK-NOT: joint_matrix_layout_to_spv

define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix.1, ptr addrspace(1) %matrix.2, i64 noundef %stride) {
entry:
%layout.1 = alloca i32, align 4
%layout.2 = alloca i32, align 4
%layout.ascast.1 = addrspacecast ptr %layout.1 to ptr addrspace(4)
%layout.ascast.2 = addrspacecast ptr %layout.2 to ptr addrspace(4)
store i32 0, ptr addrspace(4) %layout.ascast.1, align 4
store i32 1, ptr addrspace(4) %layout.ascast.2, align 4

%layout.val.1 = load i32, ptr addrspace(4) %layout.ascast.1, align 4
%layout.spv.1 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.1)
%mload.1 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.1, i32 noundef %layout.spv.1, i64 noundef %stride, i32 noundef 0)

%layout.val.2 = load i32, ptr addrspace(4) %layout.ascast.2, align 4
%layout.spv.2 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.2)
%mload.2 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.2, i64 noundef %stride, i32 noundef 0)

%layout.spv.3 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.2)
%mload.3 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.3, i64 noundef %stride, i32 noundef 0)

store i32 2, ptr addrspace(4) %layout.ascast.2, align 4
%layout.val.4 = load i32, ptr addrspace(4) %layout.ascast.2, align 4
%layout.spv.4 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.4)
%mload.4 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.4, i64 noundef %stride, i32 noundef 0)
ret void
}

declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef, i32 noundef, i64 noundef, i32 noundef)

define linkonce_odr dso_local spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %Layout) comdat {
entry:
%retval = alloca i32, align 4
%Layout.addr = alloca i32, align 4
%retval.ascast = addrspacecast ptr %retval to ptr addrspace(4)
%Layout.addr.ascast = addrspacecast ptr %Layout.addr to ptr addrspace(4)
store i32 %Layout, ptr addrspace(4) %Layout.addr.ascast, align 4
%0 = load i32, ptr addrspace(4) %Layout.addr.ascast, align 4
switch i32 %0, label %sw.epilog [
i32 0, label %sw.bb
i32 1, label %sw.bb1
i32 2, label %sw.bb2
i32 3, label %sw.bb3
]

sw.bb: ; preds = %entry
store i32 0, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.bb1: ; preds = %entry
store i32 1, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.bb2: ; preds = %entry
store i32 2, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.bb3: ; preds = %entry
store i32 3, ptr addrspace(4) %retval.ascast, align 4
br label %return

sw.epilog: ; preds = %entry
call void @llvm.trap()
unreachable

return: ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb
%1 = load i32, ptr addrspace(4) %retval.ascast, align 4
ret i32 %1
}
@@ -68,7 +68,9 @@ convertMatrixUseStringToEnum(const char *UseString) {
  return std::nullopt;
}

-inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(
// propagateConstexprLayout uses the exact name of the function, so we use
// extern "C" here.
extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv(
Contributor

Do we not need the inlining attribute anymore?

Contributor Author

No, as constexpr also implies inline; see the comment from Alexey.

Contributor

Ah, good. Sorry, I missed that as it was hidden.

sycl::ext::oneapi::experimental::matrix::layout Layout) {
Contributor

@dkhaldi dkhaldi Jan 14, 2025

I guess it was not possible to capture and preserve the layout as a constant at this stage?
That's why you had to implement this preservation as a function in SYCLJointMatrixTransform.

Contributor Author

The compiler considers it a runtime value because the joint_matrix_load/store functions (which call joint_matrix_layout_to_spv) cannot be in a constexpr context, as they call external functions (__spirv_CooperativeMatrixLoad/StoreKHR). I'm leaving the constexpr keyword here as it's really our intention, even though it is not enforced by the compiler.
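
The distinction can be sketched in plain C++. The example is illustrative only: `Layout`, `layoutToSpv`, `externalLoad`, and `load` are hypothetical stand-ins for the SYCL and SPIR-V entities, and the enum values merely mirror the identity mapping seen in the test IR.

```cpp
// Hypothetical layout enum; values chosen to mirror the test IR's switch.
enum class Layout { RowMajor = 0, ColMajor = 1, Packed = 2, Dynamic = 3 };

// A constexpr mapping: usable in constant expressions when the argument is
// itself a constant expression.
constexpr int layoutToSpv(Layout L) {
  switch (L) {
  case Layout::RowMajor:
    return 0;
  case Layout::ColMajor:
    return 1;
  case Layout::Packed:
    return 2;
  case Layout::Dynamic:
    return 3;
  }
  return -1;
}

// With a constant argument the mapping folds at compile time.
static_assert(layoutToSpv(Layout::ColMajor) == 1, "folds at compile time");

// Hypothetical stand-in for an external, non-constexpr builtin. It simply
// echoes the layout it receives.
int externalLoad(int SpvLayout) { return SpvLayout; }

// load() calls a non-constexpr function, so a call to it is never a constant
// expression, and its parameter L is not a constant expression inside it.
int load(Layout L) {
  // static_assert(layoutToSpv(L) == 1, ""); // ill-formed: L is not constant
  return externalLoad(layoutToSpv(L)); // an ordinary call without inlining
}
```

Inside `load`, the compiler front end has no constant to work with, which is why the constant has to be recovered later at the IR level.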

switch (Layout) {
case sycl::ext::oneapi::experimental::matrix::layout::row_major: