Skip to content

Commit

Permalink
Cache instance instead of context in a saved register
Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
Zoltan Herczeg authored and clover2123 committed Sep 27, 2024
1 parent ba0ad52 commit 4e655cb
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 89 deletions.
41 changes: 20 additions & 21 deletions src/jit/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ extern "C" {

namespace Walrus {

static const uint8_t kContextReg = SLJIT_S0;
static const uint8_t kFrameReg = SLJIT_S1;
static const uint8_t kFrameReg = SLJIT_S0;
static const uint8_t kInstanceReg = SLJIT_S1;
static const sljit_sw kContextOffset = 0;

struct JITArg {
JITArg(Operand* operand)
Expand Down Expand Up @@ -223,7 +224,7 @@ CompileContext::CompileContext(Module* module, JITCompiler* compiler)
, shuffleOffset(0)
#endif /* SLJIT_CONFIG_X86 */
, stackTmpStart(0)
, stackMemoryStart(0)
, stackMemoryStart(sizeof(sljit_sw))
, nextTryBlock(0)
, currentTryBlock(InstanceConstData::globalTryBlock)
, trapBlocksStart(0)
Expand Down Expand Up @@ -874,8 +875,7 @@ static void emitGlobalGet32(sljit_compiler* compiler, Instruction* instr)
GlobalGet32* globalGet = reinterpret_cast<GlobalGet32*>(instr->byteCode());
JITArg dstArg(instr->operands());

sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_MEM_REG, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_MEM_REG, 0, SLJIT_MEM1(SLJIT_TMP_MEM_REG), context->globalsStart + globalGet->index() * sizeof(void*));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_MEM_REG, 0, SLJIT_MEM1(kInstanceReg), context->globalsStart + globalGet->index() * sizeof(void*));

if (instr->info() & Instruction::kHasFloatOperand) {
moveFloatToDest(compiler, SLJIT_MOV_F32, dstArg, JITFieldAccessor::globalValueOffset());
Expand All @@ -899,7 +899,7 @@ static void emitGlobalSet32(sljit_compiler* compiler, Instruction* instr)
baseReg = instr->requiredReg(0);
}

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kInstanceReg), context->globalsStart + globalSet->index() * sizeof(void*));

if (SLJIT_IS_MEM(src.arg)) {
if (instr->info() & Instruction::kHasFloatOperand) {
Expand All @@ -912,8 +912,6 @@ static void emitGlobalSet32(sljit_compiler* compiler, Instruction* instr)
src.argw = 0;
}

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(baseReg), context->globalsStart + globalSet->index() * sizeof(void*));

if (instr->info() & Instruction::kHasFloatOperand) {
sljit_emit_fop1(compiler, SLJIT_MOV_F32, SLJIT_MEM1(baseReg), JITFieldAccessor::globalValueOffset(), src.arg, src.argw);
} else {
Expand All @@ -927,7 +925,7 @@ static void emitRefFunc(sljit_compiler* compiler, Instruction* instr)

CompileContext* context = CompileContext::get(compiler);

sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_TMP_MEM_REG, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_TMP_MEM_REG, 0, kInstanceReg, 0);
moveIntToDest(compiler, SLJIT_MOV_P, dstArg, context->functionsStart + (sizeof(Function*) * (reinterpret_cast<RefFunc*>(instr->byteCode()))->funcIndex()));
}

Expand Down Expand Up @@ -1040,7 +1038,7 @@ void JITCompiler::compileFunction(JITFunction* jitFunc, bool isExternal)

m_functionList.push_back(FunctionList(jitFunc, isExternal, m_branchTableSize));

sljit_uw stackTmpStart = m_useMemory0 ? sizeof(Memory::TargetBuffer) : 0;
sljit_uw stackTmpStart = m_context.stackMemoryStart + (m_useMemory0 ? sizeof(Memory::TargetBuffer) : 0);
// Align data.
m_context.stackTmpStart = static_cast<sljit_sw>((stackTmpStart + sizeof(sljit_sw) - 1) & ~(sizeof(sljit_sw) - 1));

Expand All @@ -1051,10 +1049,11 @@ void JITCompiler::compileFunction(JITFunction* jitFunc, bool isExternal)

if (module()->m_jitModule == nullptr) {
// Follows the declaration of FunctionDescriptor::ExternalDecl().
// Context stored in SLJIT_S0 (kContextReg)
// Frame stored in SLJIT_S1 (kFrameReg)
sljit_emit_enter(m_compiler, 0, SLJIT_ARGS3(P, P, P, P_R), 3, 2, 0);
sljit_emit_icall(m_compiler, SLJIT_CALL_REG_ARG, SLJIT_ARGS0(P), SLJIT_R2, 0);
// Frame stored in SLJIT_S0 (kFrameReg)
// Instance stored in SLJIT_S1 (kInstanceReg)
sljit_emit_enter(m_compiler, 0, SLJIT_ARGS3(P, P_R, P, P_R), 3, 2, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, kInstanceReg, 0, SLJIT_MEM1(SLJIT_R0), OffsetOfContextField(instance));
sljit_emit_icall(m_compiler, SLJIT_CALL_REG_ARG, SLJIT_ARGS1(P, P), SLJIT_R2, 0);
sljit_label* returnToLabel = sljit_emit_label(m_compiler);
sljit_emit_return(m_compiler, SLJIT_MOV_P, SLJIT_R0, 0);

Expand Down Expand Up @@ -1526,16 +1525,16 @@ void JITCompiler::emitProlog()
ASSERT(m_stackTmpSize <= 16);
#endif /* SLJIT_CONFIG_ARM_32 */

sljit_emit_enter(m_compiler, options, SLJIT_ARGS0(P),
sljit_emit_enter(m_compiler, options, SLJIT_ARGS1(P, P_R),
SLJIT_NUMBER_OF_SCRATCH_REGISTERS | SLJIT_ENTER_FLOAT(SLJIT_NUMBER_OF_SCRATCH_FLOAT_REGISTERS),
(m_savedIntegerRegCount + 2) | SLJIT_ENTER_FLOAT(m_savedFloatRegCount), m_context.stackTmpStart + m_stackTmpSize);

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_SP), kContextOffset, SLJIT_R0, 0);
if (hasMemory0()) {
sljit_sw stackMemoryStart = m_context.stackMemoryStart;
ASSERT(m_context.stackTmpStart >= stackMemoryStart + static_cast<sljit_sw>(sizeof(Memory::TargetBuffer)));

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(SLJIT_R0), Instance::alignedSize());
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(kInstanceReg), Instance::alignedSize());

sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R1, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_targetBuffers));
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_sizeInByte) + WORD_LOW_OFFSET);
Expand Down Expand Up @@ -1577,9 +1576,8 @@ void JITCompiler::emitRestoreMemories()

sljit_sw stackMemoryStart = m_context.stackMemoryStart;

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_MEM1(kInstanceReg), Instance::alignedSize());
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, prev));
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_MEM1(SLJIT_R1), Instance::alignedSize());
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_R1), offsetof(Memory, m_targetBuffers), SLJIT_R2, 0);
}

Expand Down Expand Up @@ -1640,7 +1638,8 @@ void JITCompiler::emitEpilog()

if (trapJumpIndex > 0 || (trapJumps.size() > 0 && trapJumps[0].jumpType == ExecutionContext::GenericTrap)) {
lastLabel = sljit_emit_label(m_compiler);
sljit_emit_op1(m_compiler, SLJIT_MOV_U32, SLJIT_MEM1(kContextReg), OffsetOfContextField(error), SLJIT_R0, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_MEM1(SLJIT_SP), kContextOffset);
sljit_emit_op1(m_compiler, SLJIT_MOV_U32, SLJIT_MEM1(SLJIT_R1), OffsetOfContextField(error), SLJIT_R0, 0);

for (size_t i = 0; i < jumpCount; i++) {
sljit_set_label(jumps[i], lastLabel);
Expand All @@ -1658,8 +1657,8 @@ void JITCompiler::emitEpilog()
if (trapJumps.size() > 0 || m_tryBlockStart < m_tryBlocks.size()) {
lastLabel = sljit_emit_label(m_compiler);

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(SLJIT_SP), kContextOffset);
sljit_emit_op_dst(m_compiler, SLJIT_GET_RETURN_ADDRESS, SLJIT_R1, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, kContextReg, 0);
sljit_emit_icall(m_compiler, SLJIT_CALL, SLJIT_ARGS2(W, W, W), SLJIT_IMM, GET_FUNC_ADDR(sljit_sw, getTrapHandler));

emitRestoreMemories();
Expand Down
2 changes: 1 addition & 1 deletion src/jit/CallInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ static void emitCall(sljit_compiler* compiler, Instruction* instr)

sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_R0, 0, SLJIT_IMM, reinterpret_cast<sljit_sw>(instr->byteCode()));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, kFrameReg, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, kContextReg, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_SP), kContextOffset);

sljit_emit_icall(compiler, SLJIT_CALL, SLJIT_ARGS3(W, W, W, W), SLJIT_IMM, addr);

Expand Down
8 changes: 4 additions & 4 deletions src/jit/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,24 +550,24 @@ struct TrapJump {
};

struct MemoryInitArguments {
ExecutionContext* context;
Instance* instance;
uint32_t segmentIndex;
};

struct InitTableArguments {
ExecutionContext* context;
Instance* instance;
uint32_t tableIndex;
uint32_t segmentIndex;
};

struct TableCopyArguments {
ExecutionContext* context;
Instance* instance;
uint32_t srcIndex;
uint32_t dstIndex;
};

struct TableFillArguments {
ExecutionContext* context;
Instance* instance;
uint32_t tableIndex;
};

Expand Down
9 changes: 3 additions & 6 deletions src/jit/IntMath32Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1196,8 +1196,7 @@ static void emitGlobalGet64(sljit_compiler* compiler, Instruction* instr)
GlobalGet64* globalGet = reinterpret_cast<GlobalGet64*>(instr->byteCode());
sljit_s32 baseReg = (instr->info() & Instruction::kHasFloatOperand) ? SLJIT_TMP_MEM_REG : instr->requiredReg(0);

sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(baseReg), context->globalsStart + globalGet->index() * sizeof(void*));
sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kInstanceReg), context->globalsStart + globalGet->index() * sizeof(void*));

if (instr->info() & Instruction::kHasFloatOperand) {
JITArg dstArg(instr->operands());
Expand Down Expand Up @@ -1231,31 +1230,29 @@ static void emitGlobalSet64(sljit_compiler* compiler, Instruction* instr)
floatOperandToArg(compiler, instr->operands(), src, SLJIT_TMP_DEST_FREG);
sljit_s32 baseReg = SLJIT_TMP_MEM_REG;

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kInstanceReg), context->globalsStart + globalSet->index() * sizeof(void*));

if (SLJIT_IS_MEM(src.arg)) {
sljit_emit_fop1(compiler, SLJIT_MOV_F64, SLJIT_TMP_DEST_FREG, 0, src.arg, src.argw);
src.arg = SLJIT_TMP_DEST_FREG;
src.argw = 0;
}

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(baseReg), context->globalsStart + globalSet->index() * sizeof(void*));
sljit_emit_fop1(compiler, SLJIT_MOV_F64, SLJIT_MEM1(baseReg), JITFieldAccessor::globalValueOffset(), src.arg, src.argw);
return;
}

JITArgPair src(instr->operands());
sljit_s32 baseReg = instr->requiredReg(0);

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kInstanceReg), context->globalsStart + globalSet->index() * sizeof(void*));

if (SLJIT_IS_MEM(src.arg1)) {
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, src.arg1, src.arg1w);
src.arg1 = SLJIT_TMP_DEST_REG;
src.arg1w = 0;
}

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(baseReg), context->globalsStart + globalSet->index() * sizeof(void*));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(baseReg), JITFieldAccessor::globalValueOffset() + WORD_LOW_OFFSET, src.arg1, src.arg1w);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(baseReg), JITFieldAccessor::globalValueOffset() + WORD_HIGH_OFFSET, src.arg2, src.arg2w);
}
7 changes: 2 additions & 5 deletions src/jit/IntMath64Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,7 @@ static void emitGlobalGet64(sljit_compiler* compiler, Instruction* instr)
GlobalGet64* globalGet = reinterpret_cast<GlobalGet64*>(instr->byteCode());
JITArg dstArg(instr->operands());

sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_TMP_MEM_REG, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_MEM_REG, 0, SLJIT_MEM1(SLJIT_TMP_MEM_REG), context->globalsStart + globalGet->index() * sizeof(void*));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_MEM_REG, 0, SLJIT_MEM1(kInstanceReg), context->globalsStart + globalGet->index() * sizeof(void*));

if (instr->info() & Instruction::kHasFloatOperand) {
moveFloatToDest(compiler, SLJIT_MOV_F64, dstArg, JITFieldAccessor::globalValueOffset());
Expand All @@ -562,7 +561,7 @@ static void emitGlobalSet64(sljit_compiler* compiler, Instruction* instr)
baseReg = instr->requiredReg(0);
}

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(kInstanceReg), context->globalsStart + globalSet->index() * sizeof(void*));

if (SLJIT_IS_MEM(src.arg)) {
if (instr->info() & Instruction::kHasFloatOperand) {
Expand All @@ -575,8 +574,6 @@ static void emitGlobalSet64(sljit_compiler* compiler, Instruction* instr)
src.argw = 0;
}

sljit_emit_op1(compiler, SLJIT_MOV, baseReg, 0, SLJIT_MEM1(baseReg), context->globalsStart + globalSet->index() * sizeof(void*));

if (instr->info() & Instruction::kHasFloatOperand) {
sljit_emit_fop1(compiler, SLJIT_MOV_F64, SLJIT_MEM1(baseReg), JITFieldAccessor::globalValueOffset(), src.arg, src.argw);
} else {
Expand Down
11 changes: 5 additions & 6 deletions src/jit/MemoryInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1191,8 +1191,8 @@ static void emitAtomicRmwCmpxchg64(sljit_compiler* compiler, Instruction* instr)
sljit_emit_icall(compiler, SLJIT_CALL, type, SLJIT_IMM, functionAddr);

if (srcExpectedArgPair.arg1 != SLJIT_MEM1(kFrameReg)) {
sljit_emit_op1(compiler, SLJIT_MOV, dstArgPair.arg1, dstArgPair.arg1w, SLJIT_MEM1(kContextReg), stackTmpStart + WORD_LOW_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, dstArgPair.arg2, dstArgPair.arg2w, SLJIT_MEM1(kContextReg), stackTmpStart + WORD_HIGH_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, dstArgPair.arg1, dstArgPair.arg1w, SLJIT_MEM1(SLJIT_SP), stackTmpStart + WORD_LOW_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, dstArgPair.arg2, dstArgPair.arg2w, SLJIT_MEM1(SLJIT_SP), stackTmpStart + WORD_HIGH_OFFSET);
}
}

Expand Down Expand Up @@ -1613,7 +1613,7 @@ static void emitAtomicWait(sljit_compiler* compiler, Instruction* instr)
#endif /* SLJIT_32BIT_ARCHITECTURE */

sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_EXTRACT_REG(addr.memArg.arg), 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, kContextReg, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(SLJIT_SP), kContextOffset);
sljit_get_local_base(compiler, SLJIT_R2, 0, stackTmpStart);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R3, 0, SLJIT_IMM, size);

Expand All @@ -1626,9 +1626,8 @@ static void emitAtomicWait(sljit_compiler* compiler, Instruction* instr)
sljit_emit_op1(compiler, SLJIT_MOV, dst.arg, dst.argw, SLJIT_MEM1(SLJIT_SP), stackTmpStart + WORD_LOW_OFFSET);
}

static sljit_s32 atomicNotifyCallback(ExecutionContext* context, uint8_t* address, int32_t count)
static sljit_s32 atomicNotifyCallback(Instance* instance, uint8_t* address, int32_t count)
{
Instance* instance = context->instance;
uint32_t result = 0;
instance->memory(0)->atomicNotify(instance->module()->store(), address, count, &result);
return result;
Expand Down Expand Up @@ -1660,7 +1659,7 @@ static void emitAtomicNotify(sljit_compiler* compiler, Instruction* instr)
MOVE_TO_REG(compiler, SLJIT_MOV, SLJIT_R2, count.arg, count.argw);
}

sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, kContextReg, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, kInstanceReg, 0);
sljit_emit_icall(compiler, SLJIT_CALL, SLJIT_ARGS3(W, P, W, 32), SLJIT_IMM, GET_FUNC_ADDR(sljit_sw, atomicNotifyCallback));

MOVE_FROM_REG(compiler, SLJIT_MOV, dst.arg, dst.argw, SLJIT_R0);
Expand Down
Loading

0 comments on commit 4e655cb

Please sign in to comment.