Skip to content

Commit

Permalink
Improve memory cache stack usage
Browse files Browse the repository at this point in the history
Create the cache only when necessary.

Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
Zoltan Herczeg committed Sep 18, 2024
1 parent 46dd399 commit c7b3a0f
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 48 deletions.
46 changes: 45 additions & 1 deletion src/jit/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ CompileContext::CompileContext(Module* module, JITCompiler* compiler)
, shuffleOffset(0)
#endif /* SLJIT_CONFIG_X86 */
, stackTmpStart(0)
, stackMemoryStart(0)
, nextTryBlock(0)
, currentTryBlock(InstanceConstData::globalTryBlock)
, trapBlocksStart(0)
Expand Down Expand Up @@ -1021,6 +1022,7 @@ JITCompiler::JITCompiler(Module* module, uint32_t JITFlags)
, m_savedIntegerRegCount(0)
, m_savedFloatRegCount(0)
, m_stackTmpSize(0)
, m_useMemory0(false)
{
if (module->m_jitModule != nullptr) {
ASSERT(module->m_jitModule->m_instanceConstData != nullptr);
Expand All @@ -1038,6 +1040,10 @@ void JITCompiler::compileFunction(JITFunction* jitFunc, bool isExternal)

m_functionList.push_back(FunctionList(jitFunc, isExternal, m_branchTableSize));

sljit_uw stackTmpStart = m_useMemory0 ? sizeof(Memory::TargetBuffer) : 0;
// Align data.
m_context.stackTmpStart = static_cast<sljit_sw>((stackTmpStart + sizeof(sljit_sw) - 1) & ~(sizeof(sljit_sw) - 1));

if (m_compiler == nullptr) {
// First compiled function.
m_compiler = sljit_create_compiler(nullptr);
Expand Down Expand Up @@ -1465,6 +1471,7 @@ void JITCompiler::clear()
m_last = nullptr;
m_branchTableSize = 0;
m_stackTmpSize = 0;
m_useMemory0 = false;
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
m_context.shuffleOffset = 0;
#endif /* SLJIT_CONFIG_X86 */
Expand Down Expand Up @@ -1509,7 +1516,27 @@ void JITCompiler::emitProlog()

sljit_emit_enter(m_compiler, options, SLJIT_ARGS0(P),
SLJIT_NUMBER_OF_SCRATCH_REGISTERS | SLJIT_ENTER_FLOAT(SLJIT_NUMBER_OF_SCRATCH_FLOAT_REGISTERS),
(m_savedIntegerRegCount + 2) | SLJIT_ENTER_FLOAT(m_savedFloatRegCount), m_stackTmpSize);
(m_savedIntegerRegCount + 2) | SLJIT_ENTER_FLOAT(m_savedFloatRegCount), m_context.stackTmpStart + m_stackTmpSize);

if (hasMemory0()) {
sljit_sw stackMemoryStart = m_context.stackMemoryStart;
ASSERT(m_context.stackTmpStart >= stackMemoryStart + static_cast<sljit_sw>(sizeof(Memory::TargetBuffer)));

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(SLJIT_R0), Instance::alignedSize());

sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R1, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_targetBuffers));
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_sizeInByte) + WORD_LOW_OFFSET);
sljit_get_local_base(m_compiler, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_targetBuffers), stackMemoryStart);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R0, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_buffer));

#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_HIGH_OFFSET, SLJIT_IMM, 0);
#endif /* SLJIT_32BIT_ARCHITECTURE */
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, prev), SLJIT_R1, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET, SLJIT_R2, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, buffer), SLJIT_R0, 0);
}

m_context.branchTableOffset = 0;
size_t size = func.branchTableSize * sizeof(sljit_up);
Expand All @@ -1530,6 +1557,20 @@ void JITCompiler::emitProlog()
}
}

// Emits code that unlinks this function's stack-allocated Memory::TargetBuffer
// from memory0's target-buffer chain before the function returns. This is the
// inverse of the prolog code that linked the stack buffer into the chain.
void JITCompiler::emitRestoreMemories()
{
// The prolog only links the stack buffer when memory0 is actually used,
// so there is nothing to restore otherwise.
if (!hasMemory0()) {
return;
}

sljit_sw stackMemoryStart = m_context.stackMemoryStart;

// R1 = context->instance
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_MEM1(kContextReg), OffsetOfContextField(instance));
// R2 = previous head of the target-buffer list (saved in the stack buffer by the prolog)
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, prev));
// R1 = the Memory object stored right after the aligned Instance data
// (same addressing as the prolog; presumably memory0 — confirm against Instance layout).
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_MEM1(SLJIT_R1), Instance::alignedSize());
// memory0->m_targetBuffers = prev, removing the stack buffer from the chain.
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_R1), offsetof(Memory, m_targetBuffers), SLJIT_R2, 0);
}

void JITCompiler::emitEpilog()
{
FunctionList& func = m_functionList.back();
Expand All @@ -1547,6 +1588,7 @@ void JITCompiler::emitEpilog()
m_context.earlyReturns.clear();
}

emitRestoreMemories();
sljit_emit_return(m_compiler, SLJIT_MOV_P, SLJIT_R0, 0);

m_context.emitSlowCases(m_compiler);
Expand Down Expand Up @@ -1607,6 +1649,8 @@ void JITCompiler::emitEpilog()
sljit_emit_op_dst(m_compiler, SLJIT_GET_RETURN_ADDRESS, SLJIT_R1, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, kContextReg, 0);
sljit_emit_icall(m_compiler, SLJIT_CALL, SLJIT_ARGS2(W, W, W), SLJIT_IMM, GET_FUNC_ADDR(sljit_sw, getTrapHandler));

emitRestoreMemories();
sljit_emit_return_to(m_compiler, SLJIT_R0, 0);

while (trapJumpIndex < trapJumps.size()) {
Expand Down
11 changes: 11 additions & 0 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,12 +1019,14 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::Load32Opcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDst;
compiler->useMemory0();
requiredInit = OTLoadI32;
break;
}
case ByteCode::Load64Opcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDst;
compiler->useMemory0();
requiredInit = OTLoadI64;
break;
}
Expand All @@ -1044,6 +1046,7 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::I64Load32UOpcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDstValue;
compiler->useMemory0();
if (requiredInit == OTNone) {
requiredInit = OTLoadI64;
}
Expand All @@ -1066,6 +1069,7 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::V128Load64ZeroOpcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDstValue;
compiler->useMemory0();

if (opcode == ByteCode::F32LoadOpcode)
requiredInit = OTLoadF32;
Expand All @@ -1082,6 +1086,7 @@ static void compileFunction(JITCompiler* compiler)
SIMDMemoryLoad* loadOperation = reinterpret_cast<SIMDMemoryLoad*>(byteCode);
Instruction* instr = compiler->append(byteCode, Instruction::LoadLaneSIMD, opcode, 2, 1);
instr->setRequiredRegsDescriptor(OTLoadLaneV128);
compiler->useMemory0();

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(loadOperation->src0Offset());
Expand All @@ -1092,12 +1097,14 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::Store32Opcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2;
compiler->useMemory0();
requiredInit = OTStoreI32;
break;
}
case ByteCode::Store64Opcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2;
compiler->useMemory0();
requiredInit = OTStoreI64;
break;
}
Expand All @@ -1117,6 +1124,7 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::I64StoreOpcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2Value;
compiler->useMemory0();
if (requiredInit == OTNone) {
requiredInit = OTStoreI64;
}
Expand All @@ -1127,6 +1135,7 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::V128StoreOpcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2Value;
compiler->useMemory0();

if (opcode == ByteCode::F32StoreOpcode)
requiredInit = OTStoreF32;
Expand All @@ -1143,6 +1152,7 @@ static void compileFunction(JITCompiler* compiler)
SIMDMemoryStore* storeOperation = reinterpret_cast<SIMDMemoryStore*>(byteCode);
Instruction* instr = compiler->append(byteCode, Instruction::Store, opcode, 2, 0);
instr->setRequiredRegsDescriptor(OTStoreV128);
compiler->useMemory0();

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(storeOperation->src0Offset());
Expand Down Expand Up @@ -1317,6 +1327,7 @@ static void compileFunction(JITCompiler* compiler)

Instruction* instr = compiler->append(byteCode, Instruction::Memory, opcode, 0, 1);
instr->setRequiredRegsDescriptor(OTPutI32);
compiler->useMemory0();

*instr->operands() = STACK_OFFSET(memorySize->dstOffset());
break;
Expand Down
13 changes: 13 additions & 0 deletions src/jit/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ struct CompileContext {
size_t tableStart;
size_t functionsStart;
sljit_sw stackTmpStart;
sljit_sw stackMemoryStart;
size_t nextTryBlock;
size_t currentTryBlock;
size_t trapBlocksStart;
Expand Down Expand Up @@ -756,6 +757,16 @@ class JITCompiler {
}
}

// Marks that the function currently being compiled accesses memory 0.
// Called by the byte-code parser for load/store/memory-size opcodes; makes
// compileFunction() reserve stack space for a Memory::TargetBuffer and makes
// the prolog/epilog set up and tear down the stack memory cache.
void useMemory0()
{
m_useMemory0 = true;
}

bool hasMemory0()
{
return m_useMemory0;
}

void setModuleFunction(ModuleFunction* moduleFunction)
{
m_moduleFunction = moduleFunction;
Expand Down Expand Up @@ -802,6 +813,7 @@ class JITCompiler {
// Backend operations.
void emitProlog();
void emitEpilog();
void emitRestoreMemories();

#if !defined(NDEBUG)
static const char* m_byteCodeNames[];
Expand All @@ -825,6 +837,7 @@ class JITCompiler {
uint8_t m_savedIntegerRegCount;
uint8_t m_savedFloatRegCount;
uint8_t m_stackTmpSize;
bool m_useMemory0;

std::vector<TryBlock> m_tryBlocks;
std::vector<FunctionList> m_functionList;
Expand Down
2 changes: 2 additions & 0 deletions src/jit/IntMath64Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

/* Only included by jit-backend.cc */

#define WORD_LOW_OFFSET 0

enum DivRemOptions : sljit_s32 {
DivRem32 = 1 << 1,
DivRemSigned = 1 << 0,
Expand Down
32 changes: 12 additions & 20 deletions src/jit/MemoryInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ struct MemAddress {
void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_uw offset, sljit_u32 size)
{
CompileContext* context = CompileContext::get(compiler);
sljit_sw stackMemoryStart = context->stackMemoryStart;

ASSERT(context->compiler->hasMemory0());
ASSERT(!(options & LoadInteger) || baseReg != sourceReg);
ASSERT(!(options & LoadInteger) || offsetReg != sourceReg);
#if defined(ENABLE_EXTENDED_FEATURES)
Expand Down Expand Up @@ -105,8 +107,8 @@ void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_u

if (offset + size <= context->initialMemorySize) {
ASSERT(baseReg != 0);
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(kContextReg),
OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, buffer));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, buffer));
memArg.arg = SLJIT_MEM1(baseReg);
memArg.argw = offset;
load(compiler);
Expand All @@ -121,18 +123,13 @@ void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_u
}

ASSERT(baseReg != 0 && offsetReg != 0);
#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(kContextReg),
OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, sizeInByte));
#else /* !SLJIT_64BIT_ARCHITECTURE */
/* The sizeInByte is always a 32 bit number on 32 bit systems. */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(kContextReg),
OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);
#endif /* SLJIT_64BIT_ARCHITECTURE */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);

sljit_emit_op1(compiler, SLJIT_MOV, offsetReg, 0, SLJIT_IMM, static_cast<sljit_sw>(offset + size));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(kContextReg),
OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, buffer));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, buffer));

load(compiler);

Expand Down Expand Up @@ -164,19 +161,14 @@ void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_u
sljit_emit_op1(compiler, SLJIT_MOV_U32, offsetReg, 0, offsetArg.arg, offsetArg.argw);

if (context->initialMemorySize != context->maximumMemorySize) {
#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(kContextReg),
OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, sizeInByte));
#else /* !SLJIT_64BIT_ARCHITECTURE */
/* The sizeInByte is always a 32 bit number on 32 bit systems. */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(kContextReg),
OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);
#endif /* SLJIT_64BIT_ARCHITECTURE */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);
offset += size;
}

sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(kContextReg),
OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, buffer));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, buffer));

load(compiler);

Expand Down
6 changes: 4 additions & 2 deletions src/jit/MemoryUtilInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ static void emitMemory(sljit_compiler* compiler, Instruction* instr)
switch (opcode) {
case ByteCode::MemorySizeOpcode: {
ASSERT(!(instr->info() & Instruction::kIsCallback));
ASSERT(context->compiler->hasMemory0());

JITArg dstArg(params);

sljit_emit_op2(compiler, SLJIT_LSHR32, dstArg.arg, dstArg.argw,
SLJIT_MEM1(kContextReg), OffsetOfContextField(memory0) + offsetof(Memory::TargetBuffer, sizeInByte), SLJIT_IMM, 16);
/* The sizeInByte is always a 32 bit number on 32 bit systems. */
sljit_emit_op2(compiler, SLJIT_LSHR, dstArg.arg, dstArg.argw, SLJIT_MEM1(SLJIT_SP),
context->stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET, SLJIT_IMM, 16);
return;
}
case ByteCode::MemoryInitOpcode:
Expand Down
9 changes: 0 additions & 9 deletions src/runtime/JITExec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,8 @@ ByteCodeStackOffset* JITFunction::call(ExecutionState& state, Instance* instance
ExecutionContext context(m_module->instanceConstData(), state, instance);
Memory* memory0 = nullptr;

if (instance->module()->numberOfMemoryTypes() > 0) {
memory0 = instance->memory(0);
memory0->push(&context.memory0);
}

ByteCodeStackOffset* resultOffsets = m_module->exportCall()(&context, bp, m_exportEntry);

if (memory0 != nullptr) {
memory0->pop(&context.memory0);
}

if (context.error != ExecutionContext::NoError) {
switch (context.error) {
case ExecutionContext::CapturedException:
Expand Down
1 change: 0 additions & 1 deletion src/runtime/JITExec.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ struct ExecutionContext {
ExecutionState& state;
Instance* instance;
Exception* capturedException;
Memory::TargetBuffer memory0;
ErrorCodes error;
};

Expand Down
19 changes: 4 additions & 15 deletions src/runtime/Memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,23 @@ class Store;
class DataSegment;

class Memory : public Extern {
friend class JITCompiler;

public:
static const uint32_t s_memoryPageSize = 1024 * 64;

// Caching memory target for fast access.
struct TargetBuffer {
TargetBuffer()
: prev(nullptr)
, sizeInByte(0)
, buffer(nullptr)
, sizeInByte(0)
{
}

TargetBuffer* prev;
uint64_t sizeInByte;
uint8_t* buffer;
uint64_t sizeInByte;
};

static Memory* createMemory(Store* store, uint64_t initialSizeInByte, uint64_t maximumSizeInByte, bool isShared);
Expand Down Expand Up @@ -300,19 +302,6 @@ class Memory : public Extern {
void copy(ExecutionState& state, uint32_t dstStart, uint32_t srcStart, uint32_t size);
void fill(ExecutionState& state, uint32_t start, uint8_t value, uint32_t size);

// Snapshots the current buffer pointer and size into targetBuffer and links
// it at the head of the m_targetBuffers list, making it the active cache entry.
inline void push(TargetBuffer* targetBuffer)
{
targetBuffer->prev = m_targetBuffers;
targetBuffer->sizeInByte = sizeInByte();
targetBuffer->buffer = buffer();
m_targetBuffers = targetBuffer;
}

// Unlinks targetBuffer from the m_targetBuffers list by restoring its saved
// predecessor as the head. Assumes targetBuffer is the most recently pushed
// entry — callers must pop in LIFO order (TODO confirm with call sites).
inline void pop(TargetBuffer* targetBuffer)
{
m_targetBuffers = targetBuffer->prev;
}

inline bool checkAccess(uint32_t offset, uint32_t size, uint32_t addend = 0) const
{
return !UNLIKELY(!((uint64_t)offset + (uint64_t)addend + (uint64_t)size <= m_sizeInByte));
Expand Down
Loading

0 comments on commit c7b3a0f

Please sign in to comment.