Skip to content

Commit

Permalink
[clang][llvm][aarch64] Add aarch64_sme_in_streaming_mode intrinsic (l…
Browse files Browse the repository at this point in the history
…lvm#120265)

Replacing the extant streaming mode function call with an intrinsic
allows us to make further optimisations around it. For example, if it's
called within a function that has a known streaming mode, we can remove
the dead code, and avoid the redundant conditional branch.
  • Loading branch information
NickGuy-Arm authored Jan 7, 2025
1 parent 064da42 commit 21b531e
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 32 deletions.
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/arm_sme.td
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,8 @@ let SMETargetGuard = "sme2" in {
def SVZERO_ZT : Inst<"svzero_zt", "vi", "", MergeNone, "aarch64_sme_zero_zt", [IsOverloadNone, IsStreamingCompatible, IsOutZT0], [ImmCheck<0, ImmCheck0_0>]>;
}

// Declares the __arm_in_streaming_mode builtin and binds it to the LLVM
// intrinsic aarch64_sme_in_streaming_mode. It carries no immediate checks and
// is flagged IsStreamingCompatible (the accompanying tests call it from
// streaming, non-streaming and streaming-compatible functions).
// NOTE(review): the "sv" prototype / "Pc" type strings presumably encode
// "returns a scalar bool, takes no arguments" -- confirm against the Inst
// prototype-letter documentation.
def IN_STREAMING_MODE : Inst<"__arm_in_streaming_mode", "sv", "Pc", MergeNone, "aarch64_sme_in_streaming_mode", [IsOverloadNone, IsStreamingCompatible], []>;

//
// lookup table expand four contiguous registers
//
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11327,6 +11327,19 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
if (Builtin->LLVMIntrinsic == 0)
return nullptr;

// Special case for __arm_in_streaming_mode(): when the enclosing function's
// prototype already fixes the streaming mode, return an i1 constant here.
// If no constant is returned, control falls through to the generic builtin
// lowering that follows this block.
if (BuiltinID == SME::BI__builtin_sme___arm_in_streaming_mode) {
// If we already know the streaming mode, don't bother with the intrinsic
// and emit a constant instead
// NOTE(review): cast<> is unchecked; presumably it asserts if CurFuncDecl is
// null or is not a FunctionDecl -- confirm that cannot happen on this path.
const auto *FD = cast<FunctionDecl>(CurFuncDecl);
// Functions without a prototype carry no SME attributes to inspect, so they
// skip the fold.
if (const auto *FPT = FD->getType()->getAs<FunctionProtoType>()) {
unsigned SMEAttrs = FPT->getAArch64SMEAttributes();
// Only fold when the function is NOT streaming-compatible; a
// streaming-compatible function can be entered in either mode.
// NOTE(review): only the function *type's* attributes are consulted here. A
// declaration-level attribute that changes the body's mode (for example
// __arm_locally_streaming) would not be visible through FPT -- confirm
// whether that case needs handling.
if (!(SMEAttrs & FunctionType::SME_PStateSMCompatibleMask)) {
// true when the streaming-enabled bit is set on the type, false otherwise.
bool IsStreaming = SMEAttrs & FunctionType::SME_PStateSMEnabledMask;
return ConstantInt::getBool(Builder.getContext(), IsStreaming);
}
}
}

// Predicates must match the main datatype.
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
if (auto PredTy = dyn_cast<llvm::VectorType>(Ops[i]->getType()))
Expand Down
69 changes: 44 additions & 25 deletions clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,53 @@

#include <arm_sme.h>

// CHECK-LABEL: @test_in_streaming_mode(
// CHECK-LABEL: @test_in_streaming_mode_streaming_compatible(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR3:[0-9]+]]
// CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0
// CHECK-NEXT: [[AND_I:%.*]] = and i64 [[TMP1]], 1
// CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp ne i64 [[AND_I]], 0
// CHECK-NEXT: ret i1 [[TOBOOL_I]]
// CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
// CHECK-NEXT: ret i1 [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z22test_in_streaming_modev(
// CPP-CHECK-LABEL: @_Z43test_in_streaming_mode_streaming_compatiblev(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR3:[0-9]+]]
// CPP-CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0
// CPP-CHECK-NEXT: [[AND_I:%.*]] = and i64 [[TMP1]], 1
// CPP-CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp ne i64 [[AND_I]], 0
// CPP-CHECK-NEXT: ret i1 [[TOBOOL_I]]
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
// CPP-CHECK-NEXT: ret i1 [[TMP0]]
//
bool test_in_streaming_mode_streaming_compatible(void) __arm_streaming_compatible {
// Streaming-compatible caller: the expected IR above keeps the call to
// @llvm.aarch64.sme.in.streaming.mode rather than a constant.
return __arm_in_streaming_mode();
}

// CHECK-LABEL: @test_in_streaming_mode_streaming(
// CHECK-NEXT: entry:
// CHECK-NEXT: ret i1 true
//
// CPP-CHECK-LABEL: @_Z32test_in_streaming_mode_streamingv(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: ret i1 true
//
bool test_in_streaming_mode_streaming(void) __arm_streaming {
// Streaming caller: the expected IR above is the constant `ret i1 true`.
return __arm_in_streaming_mode();
}

// CHECK-LABEL: @test_in_streaming_mode_non_streaming(
// CHECK-NEXT: entry:
// CHECK-NEXT: ret i1 false
//
// CPP-CHECK-LABEL: @_Z36test_in_streaming_mode_non_streamingv(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: ret i1 false
//
bool test_in_streaming_mode(void) __arm_streaming_compatible {
bool test_in_streaming_mode_non_streaming(void) {
// Caller with no streaming attribute: the expected IR above is the constant
// `ret i1 false`.
return __arm_in_streaming_mode();
}

// CHECK-LABEL: @test_za_disable(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR3]]
// CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR7:[0-9]+]]
// CHECK-NEXT: ret void
//
// CPP-CHECK-LABEL: @_Z15test_za_disablev(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR3]]
// CPP-CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR7:[0-9]+]]
// CPP-CHECK-NEXT: ret void
//
void test_za_disable(void) __arm_streaming_compatible {
Expand All @@ -42,14 +61,14 @@ void test_za_disable(void) __arm_streaming_compatible {

// CHECK-LABEL: @test_has_sme(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR3]]
// CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR7]]
// CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0
// CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp slt i64 [[TMP1]], 0
// CHECK-NEXT: ret i1 [[TOBOOL_I]]
//
// CPP-CHECK-LABEL: @_Z12test_has_smev(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR3]]
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR7]]
// CPP-CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0
// CPP-CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp slt i64 [[TMP1]], 0
// CPP-CHECK-NEXT: ret i1 [[TOBOOL_I]]
Expand All @@ -72,12 +91,12 @@ void test_svundef_za(void) __arm_streaming_compatible __arm_out("za") {

// CHECK-LABEL: @test_sc_memcpy(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z14test_sc_memcpyPvPKvm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memcpy(void *dest, const void *src, size_t n) __arm_streaming_compatible {
Expand All @@ -86,12 +105,12 @@ void *test_sc_memcpy(void *dest, const void *src, size_t n) __arm_streaming_comp

// CHECK-LABEL: @test_sc_memmove(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z15test_sc_memmovePvPKvm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memmove(void *dest, const void *src, size_t n) __arm_streaming_compatible {
Expand All @@ -100,12 +119,12 @@ void *test_sc_memmove(void *dest, const void *src, size_t n) __arm_streaming_com

// CHECK-LABEL: @test_sc_memset(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z14test_sc_memsetPvim(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memset(void *s, int c, size_t n) __arm_streaming_compatible {
Expand All @@ -114,12 +133,12 @@ void *test_sc_memset(void *s, int c, size_t n) __arm_streaming_compatible {

// CHECK-LABEL: @test_sc_memchr(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z14test_sc_memchrPvim(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR3]]
// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memchr(void *s, int c, size_t n) __arm_streaming_compatible {
Expand Down
7 changes: 0 additions & 7 deletions clang/utils/TableGen/SveEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1640,13 +1640,6 @@ void SVEEmitter::createSMEHeader(raw_ostream &OS) {
OS << " return x0 & (1ULL << 63);\n";
OS << "}\n\n";

OS << "__ai bool __arm_in_streaming_mode(void) __arm_streaming_compatible "
"{\n";
OS << " uint64_t x0, x1;\n";
OS << " __builtin_arm_get_sme_state(&x0, &x1);\n";
OS << " return x0 & 1;\n";
OS << "}\n\n";

OS << "void *__arm_sc_memcpy(void *dest, const void *src, size_t n) __arm_streaming_compatible;\n";
OS << "void *__arm_sc_memmove(void *dest, const void *src, size_t n) __arm_streaming_compatible;\n";
OS << "void *__arm_sc_memset(void *s, int c, size_t n) __arm_streaming_compatible;\n";
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -2974,6 +2974,7 @@ let TargetPrefix = "aarch64" in {


def int_aarch64_sme_zero : DefaultAttrsIntrinsic<[], [llvm_i32_ty], [ImmArg<ArgIndex<0>>]>;
// Query intrinsic: takes no operands, returns i1, and is marked IntrNoMem.
// NOTE(review): the ClangBuiltin name here is __builtin_arm_in_streaming_mode,
// while the clang-side SME builtin added in the same change is
// __builtin_sme___arm_in_streaming_mode -- confirm this mapping is intended
// and actually used.
// NOTE(review): IntrNoMem presumably lets the optimizer treat repeated calls
// as interchangeable -- confirm the result cannot change within one function.
def int_aarch64_sme_in_streaming_mode : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrNoMem]>, ClangBuiltin<"__builtin_arm_in_streaming_mode">;

class SME_OuterProduct_Intrinsic
: DefaultAttrsIntrinsic<[],
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setMaxDivRemBitWidthSupported(128);

setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
// With SME available, also request custom handling for chainless intrinsics
// whose result type is i1.
// NOTE(review): presumably this is what routes
// aarch64_sme_in_streaming_mode into ReplaceNodeResults -- confirm.
if (Subtarget->hasSME())
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i1, Custom);

if (Subtarget->isNeonAvailable()) {
// FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
Expand Down Expand Up @@ -27429,6 +27431,15 @@ void AArch64TargetLowering::ReplaceNodeResults(
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
return;
}
case Intrinsic::aarch64_sme_in_streaming_mode: {
// Replace the i1-returning query with a runtime streaming-mode value that is
// truncated to i1.
SDLoc DL(N);
// The helper needs a chain; the DAG entry node is used as its starting point.
SDValue Chain = DAG.getEntryNode();
// NOTE(review): getRuntimePStateSM is not visible in this hunk. The expected
// assembly in sme-intrinsics-state.ll (bl __arm_sme_state; and w0, w0, #0x1)
// suggests it calls __arm_sme_state and masks bit 0 -- confirm. Also confirm
// that passing N's own result type as the requested type is intended, given
// the value is truncated to MVT::i1 right after.
SDValue RuntimePStateSM =
getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
Results.push_back(
DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, RuntimePStateSM));
return;
}
case Intrinsic::experimental_vector_match:
case Intrinsic::get_active_lane_mask: {
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
Expand Down
18 changes: 18 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-intrinsics-state.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s


; Lowering test for @llvm.aarch64.sme.in.streaming.mode inside a function that
; uses attribute group #0 ("aarch64_pstate_sm_compatible"). The expected
; assembly below calls __arm_sme_state and keeps only bit 0 of w0 as the
; returned i1.
define i1 @streaming_mode_streaming_compatible() #0 {
; CHECK-LABEL: streaming_mode_streaming_compatible:
; CHECK: // %bb.0:
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and w0, w0, #0x1
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%mode = tail call noundef i1 @llvm.aarch64.sme.in.streaming.mode()
ret i1 %mode
}


attributes #0 = {nounwind memory(none) "aarch64_pstate_sm_compatible"}

0 comments on commit 21b531e

Please sign in to comment.