Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Software guard direct calls in large code model #109377

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,12 @@ void RISCVMCCodeEmitter::expandFunctionCall(const MCInst &MI,
MCRegister Ra;
if (MI.getOpcode() == RISCV::PseudoTAIL) {
Func = MI.getOperand(0);
Ra = RISCV::X6;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems gratuitous?

// For Zicfilp, PseudoTAIL should be expanded to a software guarded branch.
// It means to use t2(x7) as rs1 of JALR to expand PseudoTAIL.
if (STI.hasFeature(RISCV::FeatureStdExtZicfilp))
Ra = RISCV::X7;
else
Ra = RISCV::X6;
} else if (MI.getOpcode() == RISCV::PseudoCALLReg) {
Func = MI.getOperand(1);
Ra = MI.getOperand(0).getReg();
Expand Down
23 changes: 20 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19696,11 +19696,14 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
// If the callee is a GlobalAddress/ExternalSymbol node, turn it into a
// TargetGlobalAddress/TargetExternalSymbol node so that legalize won't
// split it and then direct call can be matched by PseudoCALL.
bool CalleeIsLargeExternalSymbol = false;
if (getTargetMachine().getCodeModel() == CodeModel::Large) {
if (auto *S = dyn_cast<GlobalAddressSDNode>(Callee))
Callee = getLargeGlobalAddress(S, DL, PtrVT, DAG);
else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee))
else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
Callee = getLargeExternalSymbol(S, DL, PtrVT, DAG);
CalleeIsLargeExternalSymbol = true;
}
} else if (GlobalAddressSDNode *S = dyn_cast<GlobalAddressSDNode>(Callee)) {
const GlobalValue *GV = S->getGlobal();
Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, RISCVII::MO_CALL);
Expand Down Expand Up @@ -19736,16 +19739,28 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
// Emit the call.
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);

// Use software guarded branch for large code model non-indirect calls
// Tail call to external symbol will have a null CLI.CB and we need another
// way to determine the callsite type
bool NeedSWGuarded = false;
if (getTargetMachine().getCodeModel() == CodeModel::Large &&
Subtarget.hasStdExtZicfilp() &&
((CLI.CB && !CLI.CB->isIndirectCall()) || CalleeIsLargeExternalSymbol))
NeedSWGuarded = true;

if (IsTailCall) {
MF.getFrameInfo().setHasTailCall();
SDValue Ret = DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
unsigned CallOpc =
NeedSWGuarded ? RISCVISD::SW_GUARDED_TAIL : RISCVISD::TAIL;
SDValue Ret = DAG.getNode(CallOpc, DL, NodeTys, Ops);
if (CLI.CFIType)
Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
return Ret;
}

Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
unsigned CallOpc = NeedSWGuarded ? RISCVISD::SW_GUARDED_CALL : RISCVISD::CALL;
Chain = DAG.getNode(CallOpc, DL, NodeTys, Ops);
if (CLI.CFIType)
Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
Expand Down Expand Up @@ -20193,6 +20208,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(CZERO_EQZ)
NODE_NAME_CASE(CZERO_NEZ)
NODE_NAME_CASE(SW_GUARDED_BRIND)
NODE_NAME_CASE(SW_GUARDED_CALL)
NODE_NAME_CASE(SW_GUARDED_TAIL)
NODE_NAME_CASE(TUPLE_INSERT)
NODE_NAME_CASE(TUPLE_EXTRACT)
NODE_NAME_CASE(SF_VC_XV_SE)
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,12 @@ enum NodeType : unsigned {
CZERO_EQZ, // vt.maskc for XVentanaCondOps.
CZERO_NEZ, // vt.maskcn for XVentanaCondOps.

/// Software guarded BRIND node. Operand 0 is the chain operand and
/// operand 1 is the target address.
// Software guarded BRIND node. Operand 0 is the chain operand and
// operand 1 is the target address.
SW_GUARDED_BRIND,
// Software guarded calls for large code model
SW_GUARDED_CALL,
SW_GUARDED_TAIL,

SF_VC_XV_SE,
SF_VC_IV_SE,
Expand Down
20 changes: 18 additions & 2 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_CallSeqEnd,
def riscv_call : SDNode<"RISCVISD::CALL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_sw_guarded_call : SDNode<"RISCVISD::SW_GUARDED_CALL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_ret_glue : SDNode<"RISCVISD::RET_GLUE", SDTNone,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
def riscv_sret_glue : SDNode<"RISCVISD::SRET_GLUE", SDTNone,
Expand All @@ -69,6 +72,9 @@ def riscv_brcc : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC,
def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_sw_guarded_tail : SDNode<"RISCVISD::SW_GUARDED_TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_sw_guarded_brind : SDNode<"RISCVISD::SW_GUARDED_BRIND",
SDTBrind, [SDNPHasChain]>;
def riscv_sllw : SDNode<"RISCVISD::SLLW", SDT_RISCVIntBinOpW>;
Expand Down Expand Up @@ -1555,10 +1561,15 @@ let Predicates = [NoStdExtZicfilp] in
def PseudoCALLIndirect : Pseudo<(outs), (ins GPRJALR:$rs1),
[(riscv_call GPRJALR:$rs1)]>,
PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
let Predicates = [HasStdExtZicfilp] in
let Predicates = [HasStdExtZicfilp] in {
def PseudoCALLIndirectNonX7 : Pseudo<(outs), (ins GPRJALRNonX7:$rs1),
[(riscv_call GPRJALRNonX7:$rs1)]>,
PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
// For large code model, non-indirect calls could be software-guarded
def PseudoCALLIndirectX7 : Pseudo<(outs), (ins GPRX7:$rs1),
[(riscv_sw_guarded_call GPRX7:$rs1)]>,
PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
}
}

let isBarrier = 1, isReturn = 1, isTerminator = 1 in
Expand All @@ -1579,10 +1590,15 @@ let Predicates = [NoStdExtZicfilp] in
def PseudoTAILIndirect : Pseudo<(outs), (ins GPRTC:$rs1),
[(riscv_tail GPRTC:$rs1)]>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, 0)>;
let Predicates = [HasStdExtZicfilp] in
let Predicates = [HasStdExtZicfilp] in {
def PseudoTAILIndirectNonX7 : Pseudo<(outs), (ins GPRTCNonX7:$rs1),
[(riscv_tail GPRTCNonX7:$rs1)]>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, 0)>;
// For large code model, non-indirect calls could be software-guarded
def PseudoTAILIndirectX7 : Pseudo<(outs), (ins GPRX7:$rs1),
[(riscv_sw_guarded_tail GPRX7:$rs1)]>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, 0)>;
}
}

def : Pat<(riscv_tail (iPTR tglobaladdr:$dst)),
Expand Down
Loading
Loading