Skip to content

Commit

Permalink
Add metering to the VM (#49)
Browse files Browse the repository at this point in the history
* Add support for metering the virtual machine. The meter can be set when you invoke or resume the VM. In the preamble it will decrease the meter and if it runs out, return an error and save state so that the VM can resume from where it was interrupted.
* Includes running a unit test with metering to verify it works
* Fixed a memory leak related to function instances
  • Loading branch information
Southporter authored Jun 15, 2024
1 parent 4c59c40 commit 3cc31c0
Show file tree
Hide file tree
Showing 8 changed files with 618 additions and 462 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
- name: Run unit tests
run: |
zig build test-unit
zig build -Dmeter=true test-unit
- name: Run wasm testsuite
run: |
Expand Down
15 changes: 15 additions & 0 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ const ExeOpts = struct {
description: []const u8,
step_dependencies: ?[]*Build.Step = null,
should_emit_asm: bool = false,
options: *Build.Step.Options,
};

pub fn build(b: *Build) void {
const should_emit_asm = b.option(bool, "asm", "Emit asm for the bytebox binaries") orelse false;

const enable_metering = b.option(bool, "meter", "Enable metering") orelse false;

const options = b.addOptions();
options.addOption(bool, "enable_metering", enable_metering);

const target = b.standardTargetOptions(.{});
const optimize = b.standardOptimizeOption(.{});

Expand All @@ -37,6 +43,8 @@ pub fn build(b: *Build) void {
.imports = &[_]ModuleImport{stable_array_import},
});

bytebox_module.addOptions("config", options);

// exe.root_module.addImport(import.name, import.module);

const imports = [_]ModuleImport{
Expand All @@ -50,6 +58,7 @@ pub fn build(b: *Build) void {
.step_name = "run",
.description = "Run a wasm program",
.should_emit_asm = should_emit_asm,
.options = options,
});

var bench_steps = [_]*Build.Step{
Expand All @@ -63,6 +72,7 @@ pub fn build(b: *Build) void {
.step_name = "bench",
.description = "Run the benchmark suite",
.step_dependencies = &bench_steps,
.options = options,
});

const lib_bytebox: *Build.Step.Compile = b.addStaticLibrary(.{
Expand All @@ -72,6 +82,7 @@ pub fn build(b: *Build) void {
.optimize = optimize,
});
lib_bytebox.root_module.addImport(stable_array_import.name, stable_array_import.module);
lib_bytebox.root_module.addOptions("config", options);
lib_bytebox.installHeader(b.path("src/bytebox.h"), "bytebox.h");
b.installArtifact(lib_bytebox);

Expand All @@ -82,6 +93,7 @@ pub fn build(b: *Build) void {
.optimize = optimize,
});
unit_tests.root_module.addImport(stable_array_import.name, stable_array_import.module);
unit_tests.root_module.addOptions("config", options);
const run_unit_tests = b.addRunArtifact(unit_tests);
const unit_test_step = b.step("test-unit", "Run unit tests");
unit_test_step.dependOn(&run_unit_tests.step);
Expand All @@ -92,6 +104,7 @@ pub fn build(b: *Build) void {
.root_src = "test/wasm/main.zig",
.step_name = "test-wasm",
.description = "Run the wasm testsuite",
.options = options,
});

// wasi tests
Expand All @@ -109,6 +122,7 @@ pub fn build(b: *Build) void {
.root_src = "test/mem64/main.zig",
.step_name = "test-mem64",
.description = "Run the mem64 test",
.options = options,
});

// All tests
Expand All @@ -130,6 +144,7 @@ fn buildExeWithRunStep(b: *Build, target: Build.ResolvedTarget, optimize: std.bu
for (imports) |import| {
exe.root_module.addImport(import.name, import.module);
}
exe.root_module.addOptions("config", opts.options);

// exe.emit_asm = if (opts.should_emit_asm) .emit else .default;
b.installArtifact(exe);
Expand Down
2 changes: 1 addition & 1 deletion build.zig.zon
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.{
.name = "bytebox",
.version = "0.0.1",
.minimum_zig_version = "0.12.0",
.minimum_zig_version = "0.13.0",
.paths = .{
"src",
"test/mem64",
Expand Down
19 changes: 13 additions & 6 deletions src/instance.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ const AllocError = std.mem.Allocator.Error;

const builtin = @import("builtin");

const metering = @import("metering.zig");

const common = @import("common.zig");
const StableArray = common.StableArray;
const Logger = common.Logger;
Expand Down Expand Up @@ -46,6 +48,7 @@ pub const ExportError = error{

pub const TrapError = error{
TrapDebug,
TrapInvalidResume,
TrapUnreachable,
TrapIntegerDivisionByZero,
TrapIntegerOverflow,
Expand All @@ -57,7 +60,7 @@ pub const TrapError = error{
TrapOutOfBoundsTableAccess,
TrapStackExhausted,
TrapUnknown,
};
} || metering.MeteringTrapError;

pub const DebugTrace = struct {
pub const Mode = enum {
Expand Down Expand Up @@ -616,6 +619,10 @@ pub const ModuleInstantiateOpts = struct {

pub const InvokeOpts = struct {
trap_on_start: bool = false,
meter: metering.Meter = metering.initial_meter,
};
pub const ResumeInvokeOpts = struct {
meter: metering.Meter = metering.initial_meter,
};

pub const DebugTrapInstructionMode = enum {
Expand All @@ -629,7 +636,7 @@ pub const VM = struct {
const InstantiateFn = *const fn (vm: *VM, module: *ModuleInstance, opts: ModuleInstantiateOpts) anyerror!void;
const InvokeFn = *const fn (vm: *VM, module: *ModuleInstance, handle: FunctionHandle, params: [*]const Val, returns: [*]Val, opts: InvokeOpts) anyerror!void;
const InvokeWithIndexFn = *const fn (vm: *VM, module: *ModuleInstance, func_index: usize, params: [*]const Val, returns: [*]Val) anyerror!void;
const ResumeInvokeFn = *const fn (vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void;
const ResumeInvokeFn = *const fn (vm: *VM, module: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void;
const StepFn = *const fn (vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void;
const SetDebugTrapFn = *const fn (vm: *VM, module: *ModuleInstance, wasm_address: u32, mode: DebugTrapInstructionMode) anyerror!bool;
const FormatBacktraceFn = *const fn (vm: *VM, indent: u8, allocator: std.mem.Allocator) anyerror!std.ArrayList(u8);
Expand Down Expand Up @@ -699,8 +706,8 @@ pub const VM = struct {
try vm.invoke_with_index_fn(vm, module, func_index, params, returns);
}

pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void {
try vm.resume_invoke_fn(vm, module, returns);
pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void {
try vm.resume_invoke_fn(vm, module, returns, opts);
}

pub fn step(vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void {
Expand Down Expand Up @@ -1186,8 +1193,8 @@ pub const ModuleInstance = struct {
}

/// Use to resume an invoked function after it returned error.DebugTrap
pub fn resumeInvoke(self: *ModuleInstance, returns: []Val) anyerror!void {
try self.vm.resumeInvoke(self, returns);
pub fn resumeInvoke(self: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void {
try self.vm.resumeInvoke(self, returns, opts);
}

pub fn step(self: *ModuleInstance, returns: []Val) anyerror!void {
Expand Down
20 changes: 20 additions & 0 deletions src/metering.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
const config = @import("config");
const Instruction = @import("definition.zig").Instruction;

pub const enabled = config.enable_metering;

pub const Meter = if (enabled) usize else void;

pub const initial_meter = if (enabled) 0 else {};

pub const MeteringTrapError = if (enabled) error{TrapMeterExceeded} else error{};

pub fn reduce(fuel: Meter, instruction: Instruction) Meter {
if (fuel == 0) {
return fuel;
}
return switch (instruction.opcode) {
.Invalid, .Unreachable, .DebugTrap, .Noop, .Block, .Loop, .If, .IfNoElse, .Else, .End, .Branch, .Branch_If, .Branch_Table, .Drop => fuel,
else => fuel - 1,
};
}
45 changes: 45 additions & 0 deletions src/tests.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ const core = @import("core.zig");
const Limits = core.Limits;
const MemoryInstance = core.MemoryInstance;

const metering = @import("metering.zig");

test "StackVM.Integration" {
const wasm_filepath = "zig-out/bin/mandelbrot.wasm";

Expand All @@ -27,6 +29,49 @@ test "StackVM.Integration" {
defer module_inst.destroy();
}

test "StackVM.Metering" {
if (!metering.enabled) {
return;
}
const wasm_filepath = "zig-out/bin/fibonacci.wasm";

var allocator = std.testing.allocator;

var cwd = std.fs.cwd();
const wasm_data: []u8 = try cwd.readFileAlloc(allocator, wasm_filepath, 1024 * 1024 * 128);
defer allocator.free(wasm_data);

const module_def_opts = core.ModuleDefinitionOpts{
.debug_name = std.fs.path.basename(wasm_filepath),
};
var module_def = try core.createModuleDefinition(allocator, module_def_opts);
defer module_def.destroy();

try module_def.decode(wasm_data);

var module_inst = try core.createModuleInstance(.Stack, module_def, allocator);
defer module_inst.destroy();

try module_inst.instantiate(.{});

var returns = [1]core.Val{.{ .I64 = 5555 }};
var params = [1]core.Val{.{ .I32 = 10 }};

const handle = try module_inst.getFunctionHandle("run");
const res = module_inst.invoke(handle, &params, &returns, .{
.meter = 2,
});
try std.testing.expectError(metering.MeteringTrapError.TrapMeterExceeded, res);
try std.testing.expectEqual(5555, returns[0].I32);

const res2 = module_inst.resumeInvoke(&returns, .{ .meter = 5 });
try std.testing.expectError(metering.MeteringTrapError.TrapMeterExceeded, res2);
try std.testing.expectEqual(5555, returns[0].I32);

try module_inst.resumeInvoke(&returns, .{ .meter = 10000 });
try std.testing.expectEqual(89, returns[0].I32);
}

test "MemoryInstance.init" {
{
const limits = Limits{
Expand Down
4 changes: 3 additions & 1 deletion src/vm_register.zig
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ const inst = @import("instance.zig");
const VM = inst.VM;
const ModuleInstance = inst.ModuleInstance;
const InvokeOpts = inst.InvokeOpts;
const ResumeInvokeOpts = inst.ResumeInvokeOpts;
const DebugTrapInstructionMode = inst.DebugTrapInstructionMode;
const ModuleInstantiateOpts = inst.ModuleInstantiateOpts;

Expand Down Expand Up @@ -1131,10 +1132,11 @@ pub const RegisterVM = struct {
return error.Unimplemented;
}

pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void {
pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void {
_ = vm;
_ = module;
_ = returns;
_ = opts;
return error.Unimplemented;
}

Expand Down
Loading

0 comments on commit 3cc31c0

Please sign in to comment.