Skip to content

Commit

Permalink
Optimize shuffle on x64
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
Zoltan Herczeg authored and ksh8281 committed Jun 13, 2024
1 parent 96989a5 commit 31ec7d6
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 65 deletions.
1 change: 0 additions & 1 deletion src/jit/Analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "Walrus.h"

#include "jit/Compiler.h"
#include "jit/SljitLir.h"

#include <set>

Expand Down
37 changes: 28 additions & 9 deletions src/jit/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "runtime/Table.h"
#include "runtime/Tag.h"
#include "jit/Compiler.h"
#include "jit/SljitLir.h"
#ifdef WALRUS_JITPERF
#include "jit/PerfDump.h"
#endif
Expand Down Expand Up @@ -216,6 +215,9 @@ class SlowCase {

CompileContext::CompileContext(Module* module, JITCompiler* compiler)
: compiler(compiler)
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
, shuffleOffset(0)
#endif /* SLJIT_CONFIG_X86 */
, nextTryBlock(0)
, currentTryBlock(InstanceConstData::globalTryBlock)
, trapBlocksStart(0)
Expand Down Expand Up @@ -1369,7 +1371,7 @@ void JITCompiler::generateCode()
it.jitFunc->m_exportEntry = reinterpret_cast<void*>(sljit_get_label_addr(it.exportEntryLabel));

if (it.branchTableSize > 0) {
sljit_up* branchList = reinterpret_cast<sljit_up*>(it.jitFunc->m_branchList);
sljit_up* branchList = reinterpret_cast<sljit_up*>(it.jitFunc->m_constData);
ASSERT(branchList != nullptr);

sljit_up* end = branchList + it.branchTableSize;
Expand Down Expand Up @@ -1447,6 +1449,9 @@ void JITCompiler::clear()
m_first = nullptr;
m_last = nullptr;
m_branchTableSize = 0;
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
m_context.shuffleOffset = 0;
#endif /* SLJIT_CONFIG_X86 */

while (item != nullptr) {
InstructionListItem* next = item->next();
Expand Down Expand Up @@ -1475,7 +1480,12 @@ void JITCompiler::emitProlog()
func.exportEntryLabel = sljit_emit_label(m_compiler);
}

sljit_emit_enter(m_compiler, SLJIT_ENTER_REG_ARG | SLJIT_ENTER_KEEP(2), SLJIT_ARGS0(P),
sljit_s32 options = SLJIT_ENTER_REG_ARG | SLJIT_ENTER_KEEP(2);
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
options |= SLJIT_ENTER_USE_VEX;
#endif /* SLJIT_CONFIG_X86 */

sljit_emit_enter(m_compiler, options, SLJIT_ARGS0(P),
SLJIT_NUMBER_OF_SCRATCH_REGISTERS, m_savedIntegerRegCount + 2,
SLJIT_NUMBER_OF_SCRATCH_FLOAT_REGISTERS, m_savedFloatRegCount, sizeof(ExecutionContext::CallFrame));

Expand All @@ -1488,20 +1498,29 @@ void JITCompiler::emitProlog()
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_SP), offsetof(ExecutionContext::CallFrame, prevFrame), SLJIT_R0, 0);

m_context.branchTableOffset = 0;
size_t size = func.branchTableSize * sizeof(sljit_up);
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
size += m_context.shuffleOffset;
#endif /* SLJIT_CONFIG_X86 */

if (func.branchTableSize > 0) {
void* branchList = malloc(func.branchTableSize * sizeof(sljit_up));
if (size > 0) {
void* constData = malloc(size);

func.jitFunc->m_branchList = branchList;
m_context.branchTableOffset = reinterpret_cast<uintptr_t>(branchList);
func.jitFunc->m_constData = constData;
m_context.branchTableOffset = reinterpret_cast<uintptr_t>(constData);

#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
// Requires 16 byte alignment.
m_context.shuffleOffset = (reinterpret_cast<uintptr_t>(constData) + size - m_context.shuffleOffset + 0xf) & ~(uintptr_t)0xf;
#endif /* SLJIT_CONFIG_X86 */
}
}

void JITCompiler::emitEpilog()
{
FunctionList& func = m_functionList.back();

ASSERT(m_context.branchTableOffset == reinterpret_cast<sljit_uw>(func.jitFunc->m_branchList) + func.branchTableSize * sizeof(sljit_sw));
ASSERT(m_context.branchTableOffset == reinterpret_cast<sljit_uw>(func.jitFunc->m_constData) + func.branchTableSize * sizeof(sljit_sw));
ASSERT(m_context.currentTryBlock == InstanceConstData::globalTryBlock);

if (!m_context.earlyReturns.empty()) {
Expand Down Expand Up @@ -1587,7 +1606,7 @@ void JITCompiler::emitEpilog()
}

if (func.branchTableSize > 0) {
sljit_label** branchList = reinterpret_cast<sljit_label**>(func.jitFunc->m_branchList);
sljit_label** branchList = reinterpret_cast<sljit_label**>(func.jitFunc->m_constData);
ASSERT(branchList != nullptr);

sljit_label** end = branchList + func.branchTableSize;
Expand Down
22 changes: 14 additions & 8 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "Walrus.h"

#include "jit/Compiler.h"
#include "jit/SljitLir.h"
#include "runtime/JITExec.h"
#include "runtime/Module.h"

Expand Down Expand Up @@ -253,13 +252,13 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)

#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)

#define OPERAND_TYPE_LIST_SIMD_ARCH \
OL3(OTOp2V128, /* SSD */ V128 | NOTMP, V128 | TMP, V128 | TMP | S0) \
OL3(OTOp1V128Tmp, /* SDT */ V128 | NOTMP, V128 | TMP | S0, V128) \
OL4(OTOp2V128Tmp, /* SSDT */ V128 | NOTMP, V128 | TMP, V128 | TMP | S0, V128) \
OL3(OTOp2V128Rev, /* SSD */ V128 | TMP, V128 | NOTMP, V128 | TMP | S1) \
OL4(OTShuffleV128, /* SSDT */ V128 | TMP, V128 | NOTMP, V128 | TMP | S1, V128) \
OL3(OTShiftV128, /* SSD */ V128 | NOTMP, I32, V128 | TMP | S0) \
// Operand type lists used only when targeting x86 (see the enclosing #if).
// NOTE(review): the S/D/T tags presumably stand for source, destination and
// temporary operands, in order -- confirm against the OL3/OL4 definitions.
#define OPERAND_TYPE_LIST_SIMD_ARCH \
OL3(OTOp2V128, /* SSD */ V128 | NOTMP, V128 | TMP, V128 | TMP | S0) \
OL3(OTOp1V128Tmp, /* SDT */ V128 | NOTMP, V128 | TMP | S0, V128) \
OL4(OTOp2V128Tmp, /* SSDT */ V128 | NOTMP, V128 | TMP, V128 | TMP | S0, V128) \
OL3(OTOp2V128Rev, /* SSD */ V128 | TMP, V128 | NOTMP, V128 | TMP | S1) \
OL3(OTShuffleV128, /* SSD */ V128 | NOTMP, V128 | NOTMP, V128 | TMP | S0) \
OL3(OTShiftV128, /* SSD */ V128 | NOTMP, I32, V128 | TMP | S0) \
OL4(OTShiftV128Tmp, /* SSDT */ V128 | NOTMP, I32, V128 | TMP | S0, V128)

// List of aliases.
Expand Down Expand Up @@ -1860,6 +1859,13 @@ static void compileFunction(JITCompiler* compiler)
operands[0] = STACK_OFFSET(shuffle->srcOffsets()[0]);
operands[1] = STACK_OFFSET(shuffle->srcOffsets()[1]);
operands[2] = STACK_OFFSET(shuffle->dstOffset());

#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
if (compiler->context().shuffleOffset == 0) {
compiler->context().shuffleOffset = 16 - sizeof(sljit_up);
}
compiler->context().shuffleOffset += 32;
#endif /* SLJIT_CONFIG_X86 */
break;
}
default: {
Expand Down
5 changes: 5 additions & 0 deletions src/jit/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define __WalrusJITCompiler__

#include "interpreter/ByteCode.h"
#include "jit/SljitLir.h"
#include "runtime/Module.h"

// Backend compiler structures.
Expand Down Expand Up @@ -548,6 +549,9 @@ struct CompileContext {

JITCompiler* compiler;
uintptr_t branchTableOffset;
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
uintptr_t shuffleOffset;
#endif /* SLJIT_CONFIG_X86 */
size_t globalsStart;
size_t tableStart;
size_t functionsStart;
Expand Down Expand Up @@ -690,6 +694,7 @@ class JITCompiler {

Module* module() { return m_module; }
ModuleFunction* moduleFunction() { return m_moduleFunction; }
CompileContext& context() { return m_context; }
uint32_t JITFlags() { return m_JITFlags; }
uint32_t options() { return m_options; }
InstructionListItem* first() { return m_first; }
Expand Down
1 change: 0 additions & 1 deletion src/jit/InstList.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "Walrus.h"

#include "jit/Compiler.h"
#include "jit/SljitLir.h"
#include "runtime/ObjectType.h"

#include <map>
Expand Down
1 change: 0 additions & 1 deletion src/jit/PerfDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

#if defined WALRUS_JITPERF
#include "jit/PerfDump.h"
#include "jit/SljitLir.h"

#include <fcntl.h>
#include <sys/mman.h>
Expand Down
1 change: 0 additions & 1 deletion src/jit/RegisterAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "Walrus.h"

#include "jit/Compiler.h"
#include "jit/SljitLir.h"

#include <set>

Expand Down
71 changes: 32 additions & 39 deletions src/jit/SimdX86Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1826,63 +1826,56 @@ static void emitSelectSIMD(sljit_compiler* compiler, Instruction* instr)
static void emitShuffleSIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
sljit_s32 tmp1 = instr->requiredReg(2);
sljit_s32 tmp2 = SLJIT_TMP_DEST_FREG;
const sljit_s32 type = SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_8;
I8X16Shuffle* shuffle = reinterpret_cast<I8X16Shuffle*>(instr->byteCode());
sljit_s32 type = SLJIT_SIMD_OP2_SHUFFLE | SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_8 | SLJIT_SIMD_MEM_ALIGNED_128;
CompileContext* context = CompileContext::get(compiler);
JITArg args[3];

args[2].set(operands + 2);
sljit_s32 dst = GET_TARGET_REG(args[2].arg, instr->requiredReg(1));
sljit_s32 dst = GET_TARGET_REG(args[2].arg, instr->requiredReg(0));

if (operands[0] == operands[1]) {
simdOperandToArg(compiler, operands, args[0], SLJIT_SIMD_ELEM_128, dst);
sljit_emit_simd_mov(compiler, SLJIT_SIMD_LOAD | type, SLJIT_TMP_DEST_FREG, SLJIT_MEM0(), reinterpret_cast<sljit_sw>(shuffle->value()));

if (args[0].arg != dst) {
if (sljit_has_cpu_feature(SLJIT_HAS_AVX)) {
simdEmitVexOp(compiler, SimdOp::pshufb, dst, args[1].arg, tmp1);
} else {
sljit_emit_simd_mov(compiler, SLJIT_SIMD_LOAD | type, dst, args[0].arg, 0);
args[1].arg = dst;
}
}

if (dst == args[0].arg) {
simdEmitSSEOp(compiler, SimdOp::pshufb, dst, SLJIT_TMP_DEST_FREG);
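// Both sources are the same vector: fold every lane index into the 0..15
// range and store the 16 index bytes in the 16-byte-aligned constant buffer.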
const uint8_t* source = shuffle->value();
uint8_t* destination = reinterpret_cast<uint8_t*>(context->shuffleOffset);

for (size_t i = 0; i < 16; i++) {
*destination++ = (*source >= 16) ? (*source - 16) : (*source);
source++;
}
} else {
simdOperandToArg(compiler, operands, args[0], SLJIT_SIMD_ELEM_128, instr->requiredReg(0));
simdOperandToArg(compiler, operands + 1, args[1], SLJIT_SIMD_ELEM_128, instr->requiredReg(1));

sljit_emit_simd_mov(compiler, SLJIT_SIMD_LOAD | type, tmp1, SLJIT_MEM0(), reinterpret_cast<sljit_sw>(shuffle->value()));
sljit_emit_simd_op2(compiler, type, args[2].arg, args[0].arg, SLJIT_MEM0(), static_cast<sljit_sw>(context->shuffleOffset));
context->shuffleOffset += 16;
} else {
ASSERT(context->shuffleOffset > 0 && (context->shuffleOffset & 0xf) == 0);

sljit_emit_simd_replicate(compiler, SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_8, tmp2, SLJIT_IMM, 0xf0);
simdEmitSSEOp(compiler, SimdOp::paddb, tmp1, tmp2);
simdOperandToArg(compiler, operands, args[0], SLJIT_SIMD_ELEM_128, dst);
simdOperandToArg(compiler, operands + 1, args[1], SLJIT_SIMD_ELEM_128, SLJIT_TMP_DEST_FREG);

if (dst != args[1].arg) {
if (sljit_has_cpu_feature(SLJIT_HAS_AVX)) {
simdEmitVexOp(compiler, SimdOp::pshufb, dst, args[1].arg, tmp1);
} else {
sljit_emit_simd_mov(compiler, SLJIT_SIMD_LOAD | type, dst, args[1].arg, 0);
args[1].arg = dst;
}
}
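// Precompute two 16-byte index vectors in the 16-byte-aligned constant buffer:
// the first picks lanes from the second source (indices 16..31, rebased to
// 0..15), the second picks lanes from the first source (indices 0..15).
// A lane that belongs to the other source gets index 128, so the two shuffle
// results can be merged with a bitwise OR afterwards.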
const uint8_t* source = shuffle->value();
uint8_t* destination = reinterpret_cast<uint8_t*>(context->shuffleOffset);

if (dst == args[1].arg) {
simdEmitSSEOp(compiler, SimdOp::pshufb, dst, tmp1);
for (size_t i = 0; i < 16; i++) {
*destination++ = (*source >= 16) ? (*source - 16) : 128;
source++;
}

simdEmitSSEOp(compiler, SimdOp::pxor, tmp1, tmp2);
source = shuffle->value();

if (sljit_has_cpu_feature(SLJIT_HAS_AVX)) {
simdEmitVexOp(compiler, SimdOp::pshufb, tmp2, args[0].arg, tmp1);
} else {
sljit_emit_simd_mov(compiler, SLJIT_SIMD_LOAD | type, tmp2, args[0].arg, 0);
simdEmitSSEOp(compiler, SimdOp::pshufb, tmp2, tmp1);
for (size_t i = 0; i < 16; i++) {
*destination++ = (*source < 16) ? *source : 128;
source++;
}

simdEmitSSEOp(compiler, SimdOp::por, dst, tmp2);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_MEM_REG, 0, SLJIT_IMM, static_cast<sljit_sw>(context->shuffleOffset));
context->shuffleOffset += 32;

sljit_emit_simd_op2(compiler, type, SLJIT_TMP_DEST_FREG, args[1].arg, SLJIT_MEM1(SLJIT_TMP_MEM_REG), 0);
sljit_emit_simd_op2(compiler, type, args[2].arg, args[0].arg, SLJIT_MEM1(SLJIT_TMP_MEM_REG), 16);
simdEmitSSEOp(compiler, SimdOp::por, args[2].arg, SLJIT_TMP_DEST_FREG);
}

if (SLJIT_IS_MEM(args[2].arg)) {
Expand Down
8 changes: 4 additions & 4 deletions src/runtime/JITExec.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ class JITFunction {
public:
JITFunction()
: m_exportEntry(nullptr)
, m_branchList(nullptr)
, m_constData(nullptr)
, m_module(nullptr)
{
}

~JITFunction()
{
if (m_branchList != nullptr) {
free(m_branchList);
if (m_constData != nullptr) {
free(m_constData);
}
}

Expand All @@ -132,7 +132,7 @@ class JITFunction {

private:
void* m_exportEntry;
void* m_branchList;
void* m_constData;
JITModule* m_module;
};

Expand Down

0 comments on commit 31ec7d6

Please sign in to comment.