Skip to content

Commit

Permalink
[RISCV] Software-guard direct calls in large code model
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidTw committed Sep 20, 2024
1 parent 8308d47 commit b96996d
Show file tree
Hide file tree
Showing 5 changed files with 713 additions and 125 deletions.
27 changes: 24 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19696,11 +19696,14 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
// If the callee is a GlobalAddress/ExternalSymbol node, turn it into a
// TargetGlobalAddress/TargetExternalSymbol node so that legalize won't
// split it and then direct call can be matched by PseudoCALL.
bool CalleeIsLargeExternalSymbol = false;
if (getTargetMachine().getCodeModel() == CodeModel::Large) {
if (auto *S = dyn_cast<GlobalAddressSDNode>(Callee))
Callee = getLargeGlobalAddress(S, DL, PtrVT, DAG);
else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee))
else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
Callee = getLargeExternalSymbol(S, DL, PtrVT, DAG);
CalleeIsLargeExternalSymbol = true;
}
} else if (GlobalAddressSDNode *S = dyn_cast<GlobalAddressSDNode>(Callee)) {
const GlobalValue *GV = S->getGlobal();
Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, RISCVII::MO_CALL);
Expand Down Expand Up @@ -19736,16 +19739,32 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
// Emit the call.
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);

// Use software guarded branch for large code model non-indirect calls
// Tail call to external symbol will have a null CLI.CB and we need another
// way to determine the callsite type
bool NeedSWGuarded = false;
if (getTargetMachine().getCodeModel() == CodeModel::Large &&
Subtarget.hasStdExtZicfilp() &&
((CLI.CB && !CLI.CB->isIndirectCall()) || CalleeIsLargeExternalSymbol))
NeedSWGuarded = true;

if (IsTailCall) {
MF.getFrameInfo().setHasTailCall();
SDValue Ret = DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
SDValue Ret;
if (NeedSWGuarded)
Ret = DAG.getNode(RISCVISD::SW_GUARDED_TAIL, DL, NodeTys, Ops);
else
Ret = DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
if (CLI.CFIType)
Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
return Ret;
}

Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
if (NeedSWGuarded)
Chain = DAG.getNode(RISCVISD::SW_GUARDED_CALL, DL, NodeTys, Ops);
else
Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
if (CLI.CFIType)
Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
Expand Down Expand Up @@ -20193,6 +20212,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(CZERO_EQZ)
NODE_NAME_CASE(CZERO_NEZ)
NODE_NAME_CASE(SW_GUARDED_BRIND)
NODE_NAME_CASE(SW_GUARDED_CALL)
NODE_NAME_CASE(SW_GUARDED_TAIL)
NODE_NAME_CASE(TUPLE_INSERT)
NODE_NAME_CASE(TUPLE_EXTRACT)
NODE_NAME_CASE(SF_VC_XV_SE)
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,12 @@ enum NodeType : unsigned {
CZERO_EQZ, // vt.maskc for XVentanaCondOps.
CZERO_NEZ, // vt.maskcn for XVentanaCondOps.

/// Software guarded BRIND node. Operand 0 is the chain operand and
/// operand 1 is the target address.
// Software guarded BRIND node. Operand 0 is the chain operand and
// operand 1 is the target address.
SW_GUARDED_BRIND,
// Software guarded calls for large code model
SW_GUARDED_CALL,
SW_GUARDED_TAIL,

SF_VC_XV_SE,
SF_VC_IV_SE,
Expand Down
22 changes: 19 additions & 3 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_CallSeqEnd,
def riscv_call : SDNode<"RISCVISD::CALL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_sw_guarded_call : SDNode<"RISCVISD::SW_GUARDED_CALL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_ret_glue : SDNode<"RISCVISD::RET_GLUE", SDTNone,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
def riscv_sret_glue : SDNode<"RISCVISD::SRET_GLUE", SDTNone,
Expand All @@ -69,6 +72,9 @@ def riscv_brcc : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC,
def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_sw_guarded_tail : SDNode<"RISCVISD::SW_GUARDED_TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_sw_guarded_brind : SDNode<"RISCVISD::SW_GUARDED_BRIND",
SDTBrind, [SDNPHasChain]>;
def riscv_sllw : SDNode<"RISCVISD::SLLW", SDT_RISCVIntBinOpW>;
Expand Down Expand Up @@ -1555,10 +1561,15 @@ let Predicates = [NoStdExtZicfilp] in
def PseudoCALLIndirect : Pseudo<(outs), (ins GPRJALR:$rs1),
[(riscv_call GPRJALR:$rs1)]>,
PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
let Predicates = [HasStdExtZicfilp] in
let Predicates = [HasStdExtZicfilp] in {
def PseudoCALLIndirectNonX7 : Pseudo<(outs), (ins GPRJALRNonX7:$rs1),
[(riscv_call GPRJALRNonX7:$rs1)]>,
[(riscv_call GPRJALRNonX7:$rs1)]>,
PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
// For large code model, non-indirect calls could be software-guarded
def PseudoCALLIndirectX7 : Pseudo<(outs), (ins GPRX7:$rs1),
[(riscv_sw_guarded_call GPRX7:$rs1)]>,
PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
}
}

let isBarrier = 1, isReturn = 1, isTerminator = 1 in
Expand All @@ -1579,10 +1590,15 @@ let Predicates = [NoStdExtZicfilp] in
def PseudoTAILIndirect : Pseudo<(outs), (ins GPRTC:$rs1),
[(riscv_tail GPRTC:$rs1)]>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, 0)>;
let Predicates = [HasStdExtZicfilp] in
let Predicates = [HasStdExtZicfilp] in {
def PseudoTAILIndirectNonX7 : Pseudo<(outs), (ins GPRTCNonX7:$rs1),
[(riscv_tail GPRTCNonX7:$rs1)]>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, 0)>;
// For large code model, non-indirect calls could be software-guarded
def PseudoTAILIndirectX7 : Pseudo<(outs), (ins GPRX7:$rs1),
[(riscv_sw_guarded_tail GPRX7:$rs1)]>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, 0)>;
}
}

def : Pat<(riscv_tail (iPTR tglobaladdr:$dst)),
Expand Down
Loading

0 comments on commit b96996d

Please sign in to comment.