Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Use lock when reading cur value in interpreter wait
Browse files Browse the repository at this point in the history
  • Loading branch information
anuraaga committed Jan 13, 2024
1 parent 7545578 commit fe7b8d3
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 41 deletions.
6 changes: 4 additions & 2 deletions internal/engine/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -1190,8 +1190,9 @@ func (ce *callEngine) builtinFunctionMemoryWait32(mem *wasm.MemoryInstance) {
base := uintptr(unsafe.Pointer(&mem.Buffer[0]))

offset := uint32(addr - base)
cur := atomic.LoadUint32((*uint32)(unsafe.Pointer(addr)))

ce.pushValue(mem.Wait32(offset, exp, timeout))
ce.pushValue(mem.Wait(offset, uint64(cur), uint64(exp), timeout))
}

func (ce *callEngine) builtinFunctionMemoryWait64(mem *wasm.MemoryInstance) {
Expand All @@ -1205,8 +1206,9 @@ func (ce *callEngine) builtinFunctionMemoryWait64(mem *wasm.MemoryInstance) {
base := uintptr(unsafe.Pointer(&mem.Buffer[0]))

offset := uint32(addr - base)
cur := atomic.LoadUint64((*uint64)(unsafe.Pointer(addr)))

ce.pushValue(mem.Wait64(offset, exp, timeout))
ce.pushValue(mem.Wait(offset, cur, exp, timeout))
}

func (ce *callEngine) builtinFunctionMemoryNotify(mem *wasm.MemoryInstance) {
Expand Down
19 changes: 14 additions & 5 deletions internal/engine/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3971,21 +3971,30 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance
if !memoryInst.Shared {
panic(wasmruntime.ErrRuntimeExpectedSharedMemory)
}
if int(offset) >= len(memoryInst.Buffer) {
panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess)
}

switch wazeroir.UnsignedType(op.B1) {
case wazeroir.UnsignedTypeI32:
if offset%4 != 0 {
panic(wasmruntime.ErrRuntimeUnalignedAtomic)
}
ce.pushValue(memoryInst.Wait32(offset, uint32(exp), timeout))
memoryInst.Mux.Lock()
cur, ok := memoryInst.ReadUint32Le(offset)
memoryInst.Mux.Unlock()
if !ok {
panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess)
}
ce.pushValue(memoryInst.Wait(offset, uint64(cur), uint64(uint32(exp)), timeout))
case wazeroir.UnsignedTypeI64:
if offset%8 != 0 {
panic(wasmruntime.ErrRuntimeUnalignedAtomic)
}
ce.pushValue(memoryInst.Wait64(offset, exp, timeout))
memoryInst.Mux.Lock()
cur, ok := memoryInst.ReadUint64Le(offset)
memoryInst.Mux.Unlock()
if !ok {
panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess)
}
ce.pushValue(memoryInst.Wait(offset, cur, exp, timeout))
}
frame.pc++
case wazeroir.OperationKindAtomicMemoryNotify:
Expand Down
6 changes: 5 additions & 1 deletion internal/integration_test/engine/threads_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@ func TestThreadsNotEnabled(t *testing.T) {
require.EqualError(t, err, "section memory: shared memory requested but threads feature not enabled")
}

func TestThreads(t *testing.T) {
func TestThreadsCompiler_hammer(t *testing.T) {
runAllTests(t, threadTests, wazero.NewRuntimeConfig().WithCoreFeatures(api.CoreFeaturesV2|experimental.CoreFeaturesThreads), false)
}

func TestThreadsInterpreter_hammer(t *testing.T) {
runAllTests(t, threadTests, wazero.NewRuntimeConfigInterpreter().WithCoreFeatures(api.CoreFeaturesV2|experimental.CoreFeaturesThreads), false)
}

func incrementGuardedByMutex(t *testing.T, r wazero.Runtime) {
P := 8 // max count of goroutines
if testing.Short() { // Adjust down if `-test.short`
Expand Down
27 changes: 2 additions & 25 deletions internal/wasm/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,37 +344,14 @@ func (m *MemoryInstance) writeUint64Le(offset uint32, v uint64) bool {
return true
}

// Wait32 suspends the caller until the offset is notified by a different agent.
func (m *MemoryInstance) Wait32(offset uint32, exp uint32, timeout int64) uint64 {
w := m.getWaiters(offset)
w.mux.Lock()

addr := unsafe.Add(unsafe.Pointer(&m.Buffer[0]), offset)
cur := atomic.LoadUint32((*uint32)(addr))
// Wait suspends the caller until the offset is notified by a different agent.
func (m *MemoryInstance) Wait(offset uint32, cur uint64, exp uint64, timeout int64) uint64 {
if cur != exp {
w.mux.Unlock()
return 1
}

return m.wait(w, timeout)
}

// Wait64 suspends the caller until the offset is notified by a different agent.
func (m *MemoryInstance) Wait64(offset uint32, exp uint64, timeout int64) uint64 {
w := m.getWaiters(offset)
w.mux.Lock()

addr := unsafe.Add(unsafe.Pointer(&m.Buffer[0]), offset)
cur := atomic.LoadUint64((*uint64)(addr))
if cur != exp {
w.mux.Unlock()
return 1
}

return m.wait(w, timeout)
}

func (m *MemoryInstance) wait(w *waiters, timeout int64) uint64 {
if w.l == nil {
w.l = list.New()
}
Expand Down
16 changes: 8 additions & 8 deletions internal/wasm/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ func TestMemoryInstance_WaitNotifyOnce(t *testing.T) {
// Reuse same offset 3 times to verify reuse
for i := 0; i < 3; i++ {
go func() {
res := mem.Wait32(0, 0, -1)
res := mem.Wait(0, 0, 0, -1)
propagateWaitResult(t, ch, res)
}()

Expand All @@ -830,11 +830,11 @@ func TestMemoryInstance_WaitNotifyOnce(t *testing.T) {

ch := make(chan string)
go func() {
res := mem.Wait32(0, 0, -1)
res := mem.Wait(0, 0, 0, -1)
propagateWaitResult(t, ch, res)
}()
go func() {
res := mem.Wait32(0, 0, -1)
res := mem.Wait(0, 0, 0, -1)
propagateWaitResult(t, ch, res)
}()

Expand All @@ -850,11 +850,11 @@ func TestMemoryInstance_WaitNotifyOnce(t *testing.T) {

ch := make(chan string)
go func() {
res := mem.Wait32(0, 0, -1)
res := mem.Wait(0, 0, 0, -1)
propagateWaitResult(t, ch, res)
}()
go func() {
res := mem.Wait32(0, 0, -1)
res := mem.Wait(0, 0, 0, -1)
propagateWaitResult(t, ch, res)
}()

Expand All @@ -871,11 +871,11 @@ func TestMemoryInstance_WaitNotifyOnce(t *testing.T) {

ch := make(chan string)
go func() {
res := mem.Wait32(0, 0, -1)
res := mem.Wait(0, 0, 0, -1)
propagateWaitResult(t, ch, res)
}()
go func() {
res := mem.Wait32(1, 268435456, -1)
res := mem.Wait(1, 268435456, 268435456, -1)
propagateWaitResult(t, ch, res)
}()

Expand All @@ -892,7 +892,7 @@ func TestMemoryInstance_WaitNotifyOnce(t *testing.T) {

ch := make(chan string)
go func() {
res := mem.Wait32(0, 0, 10 /* ns */)
res := mem.Wait(0, 0, 0, 10 /* ns */)
propagateWaitResult(t, ch, res)
}()

Expand Down

0 comments on commit fe7b8d3

Please sign in to comment.