Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Implement snapshot for interpreter and add tests #2

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions experimental/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ type Snapshot interface {
}

// Snapshotter allows host functions to snapshot the WebAssembly execution environment.
// Currently, only the Wasm stack is captured, but in the future, this may be expanded
// to things like globals.
type Snapshotter interface {
// Snapshot captures the current execution state.
Snapshot() Snapshot
Expand Down
100 changes: 100 additions & 0 deletions experimental/checkpoint_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package experimental_test

import (
"context"
_ "embed"
"fmt"
"log"

wazero "github.com/wasilibs/wazerox"
"github.com/wasilibs/wazerox/api"
"github.com/wasilibs/wazerox/experimental"
)

// snapshotWasm was generated by the following:
//
// cd testdata; wat2wasm snapshot.wat
//
//go:embed testdata/snapshot.wasm
var snapshotWasm []byte

type snapshotsKey struct{}

func Example_enableSnapshotterKey() {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx) // This closes everything this Runtime created.

// Enable experimental snapshotting functionality by setting it to context. We use this
// context when invoking functions, indicating to wazero to enable it.
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

// Also place a mutable holder of snapshots to be referenced during restore.
var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)

// Register host functions using snapshot and restore. Generally snapshot is saved
// into a mutable location in context to be referenced during restore.
_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
// Because we set EnableSnapshotterKey to context, this is non-nil.
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()

// Get our mutable snapshots holder to be able to add to it. Our example only calls snapshot
// and restore once but real programs will often call them at multiple layers within a call
// stack with various e.g., try/catch statements.
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)

// Write a value to be passed back to restore. This is meant to be opaque to the guest
// and used to re-reference the snapshot.
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
if !ok {
log.Panicln("failed to write snapshot index")
}

return 0
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
// Read the value written by snapshot to re-reference the snapshot.
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
if !ok {
log.Panicln("failed to read snapshot index")
}

// Get the snapshot
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

// Restore! The invocation of this function will end as soon as we invoke
// Restore, so we also pass in our return value. The guest function run
// will finish with this return value.
snapshot.Restore([]uint64{5})
}).
Export("restore").
Instantiate(ctx)
if err != nil {
log.Panicln(err)
}

mod, err := rt.Instantiate(ctx, snapshotWasm) // Instantiate the actual code
if err != nil {
log.Panicln(err)
}

// Call the guest entrypoint.
res, err := mod.ExportedFunction("run").Call(ctx)
if err != nil {
log.Panicln(err)
}
// We restored and returned the restore value, so it's our result. If restore
// was instead a no-op, we would have returned 10 from normal code flow.
fmt.Println(res[0])
// Output:
// 5
}
121 changes: 121 additions & 0 deletions experimental/checkpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package experimental_test

import (
"context"
"testing"

wazero "github.com/wasilibs/wazerox"
"github.com/wasilibs/wazerox/api"
"github.com/wasilibs/wazerox/experimental"
"github.com/wasilibs/wazerox/internal/testing/require"
)

func TestSnapshotNestedWasmInvocation(t *testing.T) {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

sidechannel := 0

_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
defer func() {
sidechannel = 10
}()
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
require.True(t, ok)

_, err := mod.ExportedFunction("restore").Call(ctx, uint64(snapshotPtr))
require.NoError(t, err)

return 2
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
require.True(t, ok)
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

snapshot.Restore([]uint64{12})
}).
Export("restore").
Instantiate(ctx)
require.NoError(t, err)

mod, err := rt.Instantiate(ctx, snapshotWasm)
require.NoError(t, err)

var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

snapshotPtr := uint64(0)
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
require.NoError(t, err)
// return value from restore
require.Equal(t, uint64(12), res[0])
// Host function defers within the call stack work fine
require.Equal(t, 10, sidechannel)
}

func TestSnapshotMultipleWasmInvocations(t *testing.T) {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
require.True(t, ok)

return 0
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
require.True(t, ok)
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

snapshot.Restore([]uint64{12})
}).
Export("restore").
Instantiate(ctx)
require.NoError(t, err)

mod, err := rt.Instantiate(ctx, snapshotWasm)
require.NoError(t, err)

var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

snapshotPtr := uint64(0)
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
require.NoError(t, err)
// snapshot returned zero
require.Equal(t, uint64(0), res[0])

// Fails, snapshot and restore are called from different wasm invocations. Currently, this
// results in a panic.
err = require.CapturePanic(func() {
_, _ = mod.ExportedFunction("restore").Call(ctx, snapshotPtr)
})
require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+
"exported function invocation than snapshot")
}
Binary file added experimental/testdata/snapshot.wasm
Binary file not shown.
34 changes: 34 additions & 0 deletions experimental/testdata/snapshot.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
(module
(import "example" "snapshot" (func $snapshot (param i32) (result i32)))
(import "example" "restore" (func $restore (param i32)))

(func $helper (result i32)
(call $restore (i32.const 0))
;; Not executed
i32.const 10
)

(func (export "run") (result i32) (local i32)
(call $snapshot (i32.const 0))
local.set 0
local.get 0
(if (result i32)
(then ;; restore return, finish with the value returned by it
local.get 0
)
(else ;; snapshot return, call heloer
(call $helper)
)
)
)

(func (export "snapshot") (param i32) (result i32)
(call $snapshot (local.get 0))
)

(func (export "restore") (param i32)
(call $restore (local.get 0))
)

(memory (export "memory") 1 1)
)
11 changes: 11 additions & 0 deletions internal/engine/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,11 @@ func callFrameOffset(funcType *wasm.FunctionType) (ret int) {
//
// This is defined for testability.
func (ce *callEngine) deferredOnCall(ctx context.Context, m *wasm.ModuleInstance, recovered interface{}) (err error) {
if s, ok := recovered.(*snapshot); ok {
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
// let it propagate up to be handled by the caller.
panic(s)
}
if recovered != nil {
builder := wasmdebug.NewErrorBuilder()

Expand Down Expand Up @@ -1260,6 +1265,12 @@ func (s *snapshot) doRestore() {
copy(ce.stack[s.hostBase:], s.ret)
}

// Error implements the same method on error.
func (s *snapshot) Error() string {
return "unhandled snapshot restore, this generally indicates restore was called from a different " +
"exported function invocation than snapshot"
}

// stackIterator implements experimental.StackIterator.
type stackIterator struct {
stack []uint64
Expand Down
77 changes: 76 additions & 1 deletion internal/engine/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,53 @@ type function struct {
parent *compiledFunction
}

type snapshot struct {
stack []uint64
frames []*callFrame
pc uint64

ret []uint64

ce *callEngine
}

// Snapshot implements the same method as documented on experimental.Snapshotter.
func (ce *callEngine) Snapshot() experimental.Snapshot {
stack := make([]uint64, len(ce.stack))
copy(stack, ce.stack)

frames := make([]*callFrame, len(ce.frames))
copy(frames, ce.frames)

return &snapshot{
stack: stack,
frames: frames,
ce: ce,
}
}

// Restore implements the same method as documented on experimental.Snapshot.
func (s *snapshot) Restore(ret []uint64) {
s.ret = ret
panic(s)
}

func (s *snapshot) doRestore() {
ce := s.ce

ce.stack = s.stack
ce.frames = s.frames
ce.frames[len(ce.frames)-1].pc = s.pc

copy(ce.stack[len(ce.stack)-len(s.ret):], s.ret)
}

// Error implements the same method on error.
func (s *snapshot) Error() string {
return "unhandled snapshot restore, this generally indicates restore was called from a different " +
"exported function invocation than snapshot"
}

// functionFromUintptr resurrects the original *function from the given uintptr
// which comes from either funcref table or OpcodeRefFunc instruction.
func functionFromUintptr(ptr uintptr) *function {
Expand Down Expand Up @@ -512,6 +559,10 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
}
}

if ctx.Value(experimental.EnableSnapshotterKey{}) != nil {
ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce)
}

defer func() {
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
if err == nil {
Expand Down Expand Up @@ -555,6 +606,12 @@ type functionListenerInvocation struct {
// with the call frame stack traces. Also, reset the state of callEngine
// so that it can be used for the subsequent calls.
func (ce *callEngine) recoverOnCall(ctx context.Context, m *wasm.ModuleInstance, v interface{}) (err error) {
if s, ok := v.(*snapshot); ok {
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
// let it propagate up to be handled by the caller.
panic(s)
}

builder := wasmdebug.NewErrorBuilder()
frameCount := len(ce.frames)
functionListeners := make([]functionListenerInvocation, 0, 16)
Expand Down Expand Up @@ -669,7 +726,25 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance
ce.drop(op.Us[v+1])
frame.pc = op.Us[v]
case wazeroir.OperationKindCall:
ce.callFunction(ctx, f.moduleInstance, &functions[op.U1])
func() {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(*snapshot); ok {
if s.ce == ce {
s.doRestore()
frame = ce.frames[len(ce.frames)-1]
body = frame.f.parent.body
bodyLen = uint64(len(body))
} else {
panic(r)
}
} else {
panic(r)
}
}
}()
ce.callFunction(ctx, f.moduleInstance, &functions[op.U1])
}()
frame.pc++
case wazeroir.OperationKindCallIndirect:
offset := ce.popValue()
Expand Down
Loading