Skip to content

Commit

Permalink
Reduce stack usage of jit
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
Zoltan Herczeg committed Sep 23, 2024
1 parent e5e2c64 commit 43350de
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 125 deletions.
26 changes: 13 additions & 13 deletions src/jit/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ extern "C" {
#define OffsetOfContextField(field) \
(static_cast<sljit_sw>(offsetof(ExecutionContext, field)))

#define OffsetOfStackTmp(type, field) \
(stackTmpStart + static_cast<sljit_sw>(offsetof(type, field)))

#if !(defined SLJIT_INDIRECT_CALL && SLJIT_INDIRECT_CALL)
#define GET_FUNC_ADDR(type, func) (reinterpret_cast<type>(func))
#else
Expand Down Expand Up @@ -219,6 +222,7 @@ CompileContext::CompileContext(Module* module, JITCompiler* compiler)
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
, shuffleOffset(0)
#endif /* SLJIT_CONFIG_X86 */
, stackTmpStart(0)
, nextTryBlock(0)
, currentTryBlock(InstanceConstData::globalTryBlock)
, trapBlocksStart(0)
Expand Down Expand Up @@ -1016,6 +1020,7 @@ JITCompiler::JITCompiler(Module* module, uint32_t JITFlags)
, m_options(0)
, m_savedIntegerRegCount(0)
, m_savedFloatRegCount(0)
, m_stackTmpSize(0)
{
if (module->m_jitModule != nullptr) {
ASSERT(module->m_jitModule->m_instanceConstData != nullptr);
Expand Down Expand Up @@ -1471,6 +1476,7 @@ void JITCompiler::clear()
m_first = nullptr;
m_last = nullptr;
m_branchTableSize = 0;
m_stackTmpSize = 0;
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
m_context.shuffleOffset = 0;
#endif /* SLJIT_CONFIG_X86 */
Expand Down Expand Up @@ -1507,17 +1513,15 @@ void JITCompiler::emitProlog()
options |= SLJIT_ENTER_USE_VEX;
#endif /* !SLJIT_CONFIG_X86 */

#if (defined SLJIT_CONFIG_ARM_32 && SLJIT_CONFIG_ARM_32)
ASSERT(m_stackTmpSize <= 32);
#else /* !SLJIT_CONFIG_ARM_32 */
ASSERT(m_stackTmpSize <= 16);
#endif /* SLJIT_CONFIG_ARM_32 */

sljit_emit_enter(m_compiler, options, SLJIT_ARGS0(P),
SLJIT_NUMBER_OF_SCRATCH_REGISTERS | SLJIT_ENTER_FLOAT(SLJIT_NUMBER_OF_SCRATCH_FLOAT_REGISTERS),
(m_savedIntegerRegCount + 2) | SLJIT_ENTER_FLOAT(m_savedFloatRegCount), sizeof(ExecutionContext::CallFrame));

// Setup new frame.
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R0, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(lastFrame));

sljit_get_local_base(m_compiler, SLJIT_R1, 0, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(kContextReg), OffsetOfContextField(lastFrame), SLJIT_R1, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_SP), offsetof(ExecutionContext::CallFrame, frameStart), kFrameReg, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_SP), offsetof(ExecutionContext::CallFrame, prevFrame), SLJIT_R0, 0);
(m_savedIntegerRegCount + 2) | SLJIT_ENTER_FLOAT(m_savedFloatRegCount), m_stackTmpSize);

m_context.branchTableOffset = 0;
size_t size = func.branchTableSize * sizeof(sljit_up);
Expand Down Expand Up @@ -1555,10 +1559,6 @@ void JITCompiler::emitEpilog()
m_context.earlyReturns.clear();
}

// Restore previous frame.
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R1, 0, SLJIT_MEM1(SLJIT_SP), offsetof(ExecutionContext::CallFrame, prevFrame));
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(kContextReg), OffsetOfContextField(lastFrame), SLJIT_R1, 0);

sljit_emit_return(m_compiler, SLJIT_MOV_P, SLJIT_R0, 0);

m_context.emitSlowCases(m_compiler);
Expand Down
24 changes: 24 additions & 0 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ static void compileFunction(JITCompiler* compiler)
group = Instruction::Binary;
paramType = ParamTypes::ParamSrc2Dst;
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
compiler->increaseStackTmpSize(16);
info = Instruction::kIsCallback;
requiredInit = OTDivRemI64;
#else /* !SLJIT_32BIT_ARCHITECTURE */
Expand Down Expand Up @@ -790,6 +791,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt32FromFloat32Callback;
compiler->increaseStackTmpSize(4);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt32FromFloat32;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -808,6 +810,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt32FromFloat64Callback;
compiler->increaseStackTmpSize(8);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt32FromFloat64;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -819,6 +822,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt64FromFloat32Callback;
compiler->increaseStackTmpSize(8);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt64FromFloat32;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -830,6 +834,7 @@ static void compileFunction(JITCompiler* compiler)
paramType = ParamTypes::ParamSrcDst;
info = Instruction::kIsCallback;
requiredInit = OTConvertInt64FromFloat32Callback;
compiler->increaseStackTmpSize(8);
break;
}
case ByteCode::I64TruncF64SOpcode: {
Expand All @@ -838,6 +843,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt64FromFloat64Callback;
compiler->increaseStackTmpSize(8);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt64FromFloat64;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -849,6 +855,7 @@ static void compileFunction(JITCompiler* compiler)
paramType = ParamTypes::ParamSrcDst;
info = Instruction::kIsCallback;
requiredInit = OTConvertInt64FromFloat64Callback;
compiler->increaseStackTmpSize(8);
break;
}
case ByteCode::I32TruncSatF32UOpcode: {
Expand All @@ -857,6 +864,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt32FromFloat32Callback;
compiler->increaseStackTmpSize(4);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt32FromFloat32;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -868,6 +876,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt32FromFloat64Callback;
compiler->increaseStackTmpSize(8);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt32FromFloat64;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -879,6 +888,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt64FromFloat32Callback;
compiler->increaseStackTmpSize(8);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt64FromFloat32;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -890,6 +900,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
requiredInit = OTConvertInt64FromFloat64Callback;
compiler->increaseStackTmpSize(8);
#else /* !SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertInt64FromFloat64;
#endif /* SLJIT_32BIT_ARCHITECTURE */
Expand All @@ -908,6 +919,7 @@ static void compileFunction(JITCompiler* compiler)
paramType = ParamTypes::ParamSrcDst;
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(8);
#endif /* SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertFloat32FromInt64;
break;
Expand All @@ -925,6 +937,7 @@ static void compileFunction(JITCompiler* compiler)
paramType = ParamTypes::ParamSrcDst;
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(8);
#endif /* SLJIT_32BIT_ARCHITECTURE */
requiredInit = OTConvertFloat64FromInt64;
break;
Expand Down Expand Up @@ -1235,6 +1248,7 @@ static void compileFunction(JITCompiler* compiler)
Instruction* instr = compiler->append(byteCode, Instruction::Table, opcode, 3, 0);
instr->addInfo(Instruction::kIsCallback);
instr->setRequiredRegsDescriptor(OTCallback3Arg);
compiler->increaseStackTmpSize(sizeof(InitTableArguments));

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(tableInit->srcOffsets()[0]);
Expand All @@ -1254,6 +1268,7 @@ static void compileFunction(JITCompiler* compiler)
Instruction* instr = compiler->append(byteCode, Instruction::Table, opcode, 3, 0);
instr->addInfo(Instruction::kIsCallback);
instr->setRequiredRegsDescriptor(OTCallback3Arg);
compiler->increaseStackTmpSize(sizeof(TableCopyArguments));

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(tableCopy->srcOffsets()[0]);
Expand All @@ -1267,6 +1282,7 @@ static void compileFunction(JITCompiler* compiler)
Instruction* instr = compiler->append(byteCode, Instruction::Table, opcode, 3, 0);
instr->addInfo(Instruction::kIsCallback);
instr->setRequiredRegsDescriptor(OTCallback3Arg);
compiler->increaseStackTmpSize(sizeof(TableFillArguments));

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(tableFill->srcOffsets()[0]);
Expand Down Expand Up @@ -1314,6 +1330,7 @@ static void compileFunction(JITCompiler* compiler)
Instruction* instr = compiler->append(byteCode, Instruction::Memory, opcode, 3, 0);
instr->addInfo(Instruction::kIsCallback);
instr->setRequiredRegsDescriptor(OTCallback3Arg);
compiler->increaseStackTmpSize(sizeof(MemoryInitArguments));

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(memoryInit->srcOffsets()[0]);
Expand Down Expand Up @@ -1663,6 +1680,7 @@ static void compileFunction(JITCompiler* compiler)
paramType = ParamTypes::ParamSrc2Dst;
#if (defined SLJIT_CONFIG_ARM_32 && SLJIT_CONFIG_ARM_32)
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(32);
#endif /* SLJIT_CONFIG_ARM_32 */
requiredInit = OTMinMaxV128;
break;
Expand Down Expand Up @@ -1741,6 +1759,7 @@ static void compileFunction(JITCompiler* compiler)
paramType = ParamTypes::ParamSrcDst;
#if (defined SLJIT_CONFIG_ARM_32 && SLJIT_CONFIG_ARM_32)
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(16);
#endif /* SLJIT_CONFIG_ARM_32 */
requiredInit = OTOp1V128CB;
break;
Expand Down Expand Up @@ -1841,6 +1860,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
if (opcode == ByteCode::I64AtomicLoadOpcode) {
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(8);
}
#endif /* SLJIT_32BIT_ARCHITECTURE */
if (requiredInit == OTNone) {
Expand All @@ -1864,6 +1884,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
if (opcode == ByteCode::I64AtomicStoreOpcode) {
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(8);
}
#endif /* SLJIT_32BIT_ARCHITECTURE */
if (requiredInit == OTNone) {
Expand Down Expand Up @@ -1902,6 +1923,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
if (info == 0) {
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(16);
}
#endif /* SLJIT_32BIT_ARCHITECTURE */
FALLTHROUGH;
Expand Down Expand Up @@ -1947,6 +1969,7 @@ static void compileFunction(JITCompiler* compiler)
#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
if (info == 0) {
info = Instruction::kIsCallback;
compiler->increaseStackTmpSize(16);
}
#endif /* SLJIT_32BIT_ARCHITECTURE */
FALLTHROUGH;
Expand Down Expand Up @@ -1978,6 +2001,7 @@ static void compileFunction(JITCompiler* compiler)
ByteCodeOffset4Value* memoryAtomicWait = reinterpret_cast<ByteCodeOffset4Value*>(byteCode);
Operand* operands = instr->operands();
instr->setRequiredRegsDescriptor(requiredInit != OTNone ? requiredInit : OTAtomicWaitI32);
compiler->increaseStackTmpSize(16);

operands[0] = STACK_OFFSET(memoryAtomicWait->src0Offset());
operands[1] = STACK_OFFSET(memoryAtomicWait->src1Offset());
Expand Down
32 changes: 32 additions & 0 deletions src/jit/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class BrTableInstruction;
class Label;
class JITModule;
struct CompileContext;
struct ExecutionContext;

// Defined in ObjectType.h.
class FunctionType;
Expand Down Expand Up @@ -548,6 +549,28 @@ struct TrapJump {
sljit_jump* jump;
};

struct MemoryInitArguments {
ExecutionContext* context;
uint32_t segmentIndex;
};

struct InitTableArguments {
ExecutionContext* context;
uint32_t tableIndex;
uint32_t segmentIndex;
};

struct TableCopyArguments {
ExecutionContext* context;
uint32_t srcIndex;
uint32_t dstIndex;
};

struct TableFillArguments {
ExecutionContext* context;
uint32_t tableIndex;
};

struct CompileContext {
CompileContext(Module* module, JITCompiler* compiler);

Expand All @@ -565,6 +588,7 @@ struct CompileContext {
size_t globalsStart;
size_t tableStart;
size_t functionsStart;
sljit_sw stackTmpStart;
size_t nextTryBlock;
size_t currentTryBlock;
size_t trapBlocksStart;
Expand Down Expand Up @@ -729,6 +753,13 @@ class JITCompiler {
m_branchTableSize += value;
}

void increaseStackTmpSize(uint8_t value)
{
if (m_stackTmpSize < value) {
m_stackTmpSize = value;
}
}

void setModuleFunction(ModuleFunction* moduleFunction)
{
m_moduleFunction = moduleFunction;
Expand Down Expand Up @@ -797,6 +828,7 @@ class JITCompiler {
uint32_t m_options;
uint8_t m_savedIntegerRegCount;
uint8_t m_savedFloatRegCount;
uint8_t m_stackTmpSize;

std::vector<TryBlock> m_tryBlocks;
std::vector<FunctionList> m_functionList;
Expand Down
12 changes: 7 additions & 5 deletions src/jit/FloatConvInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -708,10 +708,12 @@ static void emitConvertFloat(sljit_compiler* compiler, Instruction* instr)
}
#endif /* SLJIT_64BIT_ARCHITECTURE */

sljit_sw stackTmpStart = CompileContext::get(compiler)->stackTmpStart;

if (arg.arg == SLJIT_MEM1(kFrameReg)) {
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, kFrameReg, 0, SLJIT_IMM, arg.argw);
} else {
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, kContextReg, 0, SLJIT_IMM, OffsetOfContextField(tmp1));
sljit_get_local_base(compiler, SLJIT_R0, 0, stackTmpStart);
}

#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
Expand All @@ -737,14 +739,14 @@ static void emitConvertFloat(sljit_compiler* compiler, Instruction* instr)

#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
sljit_s32 movOp = (flags & DestinationIs64Bit) ? SLJIT_MOV : SLJIT_MOV32;
sljit_emit_op1(compiler, movOp, arg.arg, arg.argw, SLJIT_MEM1(kContextReg), OffsetOfContextField(tmp1));
sljit_emit_op1(compiler, movOp, arg.arg, arg.argw, SLJIT_MEM1(SLJIT_SP), stackTmpStart);
#else /* !SLJIT_64BIT_ARCHITECTURE */
if (!(flags & DestinationIs64Bit)) {
sljit_emit_op1(compiler, SLJIT_MOV, arg.arg, arg.argw, SLJIT_MEM1(kContextReg), OffsetOfContextField(tmp1));
sljit_emit_op1(compiler, SLJIT_MOV, arg.arg, arg.argw, SLJIT_MEM1(SLJIT_SP), stackTmpStart);
return;
}

sljit_emit_op1(compiler, SLJIT_MOV, argPair.arg1, argPair.arg1w, SLJIT_MEM1(kContextReg), OffsetOfContextField(tmp1) + WORD_LOW_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, argPair.arg2, argPair.arg2w, SLJIT_MEM1(kContextReg), OffsetOfContextField(tmp1) + WORD_HIGH_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, argPair.arg1, argPair.arg1w, SLJIT_MEM1(SLJIT_SP), stackTmpStart + WORD_LOW_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, argPair.arg2, argPair.arg2w, SLJIT_MEM1(SLJIT_SP), stackTmpStart + WORD_HIGH_OFFSET);
#endif /* SLJIT_32BIT_ARCHITECTURE */
}
Loading

0 comments on commit 43350de

Please sign in to comment.