From 9313f7b53fd10da43694d1eed302f079c806a7da Mon Sep 17 00:00:00 2001 From: Reuben Dunnington Date: Sat, 23 Mar 2024 12:44:39 -0700 Subject: [PATCH] wip: fleshing out function instantiation --- src/definition.zig | 20 +- src/instance.zig | 4 +- src/vm_register.zig | 706 +++++++++++++++++++++++++++++--------------- src/vm_stack.zig | 14 +- 4 files changed, 484 insertions(+), 260 deletions(-) diff --git a/src/definition.zig b/src/definition.zig index f80ed0e..3270cd8 100644 --- a/src/definition.zig +++ b/src/definition.zig @@ -539,10 +539,10 @@ pub const FunctionTypeDefinition = struct { }; pub const FunctionDefinition = struct { - type_index: u32, - instructions_begin: u32, - instructions_end: u32, - continuation: u32, + type_index: usize, + instructions_begin: usize, + instructions_end: usize, + continuation: usize, locals: std.ArrayList(ValType), // TODO use a slice of a large contiguous array instead pub fn instructions(func: FunctionDefinition, module_def: ModuleDefinition) []Instruction { @@ -1253,7 +1253,7 @@ const CustomSection = struct { pub const NameCustomSection = struct { const NameAssoc = struct { name: []const u8, - func_index: u32, + func_index: usize, fn cmp(_: void, a: NameAssoc, b: NameAssoc) bool { return a.func_index < b.func_index; @@ -1355,7 +1355,7 @@ pub const NameCustomSection = struct { return self.module_name; } - pub fn findFunctionName(self: *const NameCustomSection, func_index: u32) []const u8 { + pub fn findFunctionName(self: *const NameCustomSection, func_index: usize) []const u8 { if (func_index < self.function_names.items.len) { if (self.function_names.items[func_index].func_index == func_index) { return self.function_names.items[func_index].name; @@ -1629,7 +1629,7 @@ const ModuleValidator = struct { frame.is_unreachable = true; } - fn popPushFuncTypes(validator: *ModuleValidator, type_index: u32, module_: *const ModuleDefinition) !void { + fn popPushFuncTypes(validator: *ModuleValidator, type_index: usize, module_: *const ModuleDefinition) !void { const func_type: *const FunctionTypeDefinition = &module_.types.items[type_index]; try popReturnTypes(validator, func_type.getParams()); @@ -1740,7 +1740,7 @@ const ModuleValidator = struct { return error.ValidationUnknownFunction; } - var type_index: u32 = module.getFuncTypeIndex(func_index); + var type_index: usize = module.getFuncTypeIndex(func_index); try Helpers.popPushFuncTypes(self, type_index, module); }, .Call_Indirect => { @@ -3166,7 +3166,7 @@ pub const ModuleDefinition = struct { func_def.instructions_begin = @intCast(instructions.items.len); try block_stack.append(BlockData{ - .begin_index = func_def.instructions_begin, + .begin_index = @intCast(func_def.instructions_begin), .opcode = .Block, }); @@ -3543,7 +3543,7 @@ pub const ModuleDefinition = struct { } } - fn getFuncTypeIndex(self: *const ModuleDefinition, func_index: usize) u32 { + fn getFuncTypeIndex(self: *const ModuleDefinition, func_index: usize) usize { if (func_index < self.imports.functions.items.len) { const func_def: *const FunctionImportDefinition = &self.imports.functions.items[func_index]; return func_def.type_index; diff --git a/src/instance.zig b/src/instance.zig index 52da09d..9b37a56 100644 --- a/src/instance.zig +++ b/src/instance.zig @@ -99,9 +99,9 @@ pub const DebugTrace = struct { } } - pub fn traceFunction(module_instance: *const ModuleInstance, indent: u32, func_index: u32) void { + pub fn traceFunction(module_instance: *const ModuleInstance, indent: u32, func_index: usize) void { if (shouldTraceFunctions()) { - const func_name_index: u32 = func_index + @as(u32, @intCast(module_instance.module_def.imports.functions.items.len)); + const func_name_index: usize = func_index + module_instance.module_def.imports.functions.items.len; const name_section: *const NameCustomSection = &module_instance.module_def.name_section; const module_name = name_section.getModuleName(); diff --git a/src/vm_register.zig b/src/vm_register.zig index 6647cfa..b988692 100644 --- a/src/vm_register.zig +++ b/src/vm_register.zig @@ -53,6 +53,7 @@ const ValType = def.ValType; const TaggedVal = def.TaggedVal; const inst = @import("instance.zig"); +const TrapError = inst.TrapError; const VM = inst.VM; const ModuleInstance = inst.ModuleInstance; const InvokeOpts = inst.InvokeOpts; @@ -77,10 +78,10 @@ const IRNode = struct { edges_out: ?[*]*IRNode, edges_out_count: u32, - fn createWithInstruction(mir: *ModuleIR, instruction_index: u32) AllocError!*IRNode { - var node: *IRNode = mir.ir.addOne() catch return AllocError.OutOfMemory; + fn createWithInstruction(compiler: *FunctionCompiler, instruction_index: u32) AllocError!*IRNode { + var node: *IRNode = compiler.ir.addOne() catch return AllocError.OutOfMemory; node.* = IRNode{ - .opcode = mir.module_def.code.instructions.items[instruction_index].opcode, + .opcode = compiler.module_def.code.instructions.items[instruction_index].opcode, .is_phi = false, .instruction_index = instruction_index, .edges_in = null, @@ -91,8 +92,8 @@ const IRNode = struct { return node; } - fn createStandalone(mir: *ModuleIR, opcode: Opcode) AllocError!*IRNode { - var node: *IRNode = mir.ir.addOne() catch return AllocError.OutOfMemory; + fn createStandalone(compiler: *FunctionCompiler, opcode: Opcode) AllocError!*IRNode { + var node: *IRNode = compiler.ir.addOne() catch return AllocError.OutOfMemory; node.* = IRNode{ .opcode = opcode, .is_phi = false, @@ -105,8 +106,8 @@ const IRNode = struct { return node; } - fn createPhi(mir: *ModuleIR) AllocError!*IRNode { - var node: *IRNode = mir.ir.addOne() catch return AllocError.OutOfMemory; + fn createPhi(compiler: *FunctionCompiler) AllocError!*IRNode { + var node: *IRNode = compiler.ir.addOne() catch return AllocError.OutOfMemory; node.* = IRNode{ .opcode = .Invalid, .is_phi = true, @@ -302,37 +303,27 @@ const RegisterSlots = struct { } }; -const IRFunction = struct { - definition_index: usize, - ir_root: *IRNode, +const FunctionIR = struct { + def_index: usize = 0, + type_def_index: usize = 0, + ir_root: ?*IRNode = null, - register_map: std.AutoHashMap(*const IRNode, u32), - - fn init(definition_index: u32, ir_root: *IRNode, allocator: std.mem.Allocator) IRFunction { - return IRFunction{ - .definition_index = definition_index, - .ir_root = ir_root, - .register_map = std.AutoHashMap(*const IRNode, u32).init(allocator), - }; - } + // fn definition(func: FunctionIR, module_def: ModuleDefinition) *FunctionDefinition { + // return &module_def.functions.items[func.def_index]; + // } - fn deinit(self: *IRFunction) void { - self.register_map.deinit(); - } - - fn definition(func: IRFunction, module_def: ModuleDefinition) *FunctionDefinition { - return &module_def.functions.items[func.definition_index]; - } + fn regalloc(func: *FunctionIR, compile_data: *IntermediateCompileData, allocator: std.mem.Allocator) AllocError!void { + std.debug.assert(func.ir_root != null); - fn regalloc(func: *IRFunction, allocator: std.mem.Allocator) AllocError!void { - std.debug.assert(func.ir_root.opcode == .Return); // TODO need to update other places in the code to ensure this is a thing + var ir_root = func.ir_root.?; + std.debug.assert(ir_root.opcode == .Return); // TODO need to update other places in the code to ensure this is a thing var slots = RegisterSlots.init(allocator); defer slots.deinit(); var visit_queue = std.ArrayList(*IRNode).init(allocator); defer visit_queue.deinit(); - try visit_queue.append(func.ir_root); + try visit_queue.append(ir_root); var visited = std.AutoHashMap(*IRNode, void).init(allocator); defer visited.deinit(); @@ -344,7 +335,7 @@ const IRFunction = struct { // mark output node slots as free - this is safe because the dataflow graph flows one way and the // output can't be reused higher up in the graph for (node.edgesOut()) |output_node| { - if (func.register_map.get(output_node)) |index| { + if (compile_data.register_map.get(output_node)) |index| { slots.freeAt(output_node, index); } } @@ -353,7 +344,7 @@ const IRFunction = struct { // TODO handle multiple output slots (e.g. results of a function call) if (node.needsRegisterSlot()) { const index: u32 = try slots.alloc(node); - try func.register_map.put(node, index); + try compile_data.register_map.put(node, index); } // add inputs to the FIFO visit queue @@ -365,17 +356,20 @@ const IRFunction = struct { } } - fn codegen(func: *IRFunction, instructions: *std.ArrayList(RegInstruction), module_def: ModuleDefinition, allocator: std.mem.Allocator) AllocError!void { - // walk the graph in breadth-first order + // TODO call this from the compiler compile function, have the compile function take instructions and local_types arrays passed down from module instantiate + // TODO inline regalloc into this function + fn codegen(func: FunctionIR, store: *FunctionStore, compile_data: IntermediateCompileData, module_def: ModuleDefinition, allocator: std.mem.Allocator) AllocError!void { + std.debug.assert(func.ir_root != null); + // walk the graph in breadth-first order, starting from the last Return node // when a node is visited, emit its instruction // reverse the instructions array when finished (alternatively just emit in reverse order if we have the node count from regalloc) - const start_instruction_offset = instructions.items.len; + const start_instruction_offset = store.instructions.items.len; var visit_queue = std.ArrayList(*IRNode).init(allocator); defer visit_queue.deinit(); - try visit_queue.append(func.ir_root); + try visit_queue.append(func.ir_root.?); var visited = std.AutoHashMap(*IRNode, void).init(allocator); defer visited.deinit(); @@ -396,8 +390,8 @@ const IRFunction = struct { if (all_out_edges_visited) { try visited.put(node, {}); - instructions.append(RegInstruction{ - .registerSlotOffset = if (func.register_map.get(node)) |slot_index| slot_index else 0, + try store.instructions.append(RegInstruction{ + .registerSlotOffset = if (compile_data.register_map.get(node)) |slot_index| slot_index else 0, .opcode = node.opcode, .immediate = node.instruction(module_def).?.immediate, }); @@ -410,13 +404,32 @@ const IRFunction = struct { } } - const end_instruction_offset = instructions.items.len; - var emitted_instructions = instructions.items[start_instruction_offset..end_instruction_offset]; + const end_instruction_offset = store.instructions.items.len; + var emitted_instructions = store.instructions.items[start_instruction_offset..end_instruction_offset]; std.mem.reverse(RegInstruction, emitted_instructions); + + const func_def: *const FunctionDefinition = &module_def.functions.items[func.def_index]; + const func_type: *const FunctionTypeDefinition = &module_def.types.items[func.type_def_index]; + const param_types: []const ValType = func_type.getParams(); + try store.local_types.ensureTotalCapacity(store.local_types.items.len + param_types.len + func_def.locals.items.len); + + const types_index_begin = store.local_types.items.len; + store.local_types.appendSliceAssumeCapacity(param_types); + store.local_types.appendSliceAssumeCapacity(func_def.locals.items); + const types_index_end = store.local_types.items.len; + + try store.instances.append(FunctionInstance{ + .type_def_index = func.type_def_index, + .def_index = func.def_index, + .instructions_begin = start_instruction_offset, + .instructions_end = end_instruction_offset, + .local_types_begin = types_index_begin, + .local_types_end = types_index_end, + }); } - fn dumpVizGraph(func: IRFunction, path: []u8, module_def: ModuleDefinition, allocator: std.mem.Allocator) !void { + fn dumpVizGraph(func: FunctionIR, path: []u8, module_def: ModuleDefinition, allocator: std.mem.Allocator) !void { var graph_txt = std.ArrayList(u8).init(allocator); defer graph_txt.deinit(); try graph_txt.ensureTotalCapacity(1024 * 16); @@ -483,11 +496,18 @@ const IRFunction = struct { } }; -const ModuleIR = struct { +const IntermediateCompileData = struct { + const UniqueValueToIRNodeMap = std.HashMap(TaggedVal, *IRNode, TaggedVal.HashMapContext, std.hash_map.default_max_load_percentage); + + const PendingContinuationEdge = struct { + continuation: usize, + node: *IRNode, + }; + const BlockStack = struct { const Block = struct { node_start_index: u32, - continuation: u32, // in instruction index space + continuation: usize, // in instruction index space phi_nodes: []*IRNode, }; @@ -513,7 +533,7 @@ const ModuleIR = struct { self.blocks.deinit(); } - fn pushBlock(self: *BlockStack, continuation: u32) AllocError!void { + fn pushBlock(self: *BlockStack, continuation: usize) AllocError!void { try self.blocks.append(Block{ .node_start_index = @intCast(self.nodes.items.len), .continuation = continuation, @@ -558,193 +578,190 @@ const ModuleIR = struct { } }; - const IntermediateCompileData = struct { - const UniqueValueToIRNodeMap = std.HashMap(TaggedVal, *IRNode, TaggedVal.HashMapContext, std.hash_map.default_max_load_percentage); - - const PendingContinuationEdge = struct { - continuation: u32, - node: *IRNode, - }; + allocator: std.mem.Allocator, - allocator: std.mem.Allocator, + // all_nodes: std.ArrayList(*IRNode), - // all_nodes: std.ArrayList(*IRNode), + blocks: BlockStack, - blocks: BlockStack, + // This stack is a record of the nodes to push values onto the stack. If an instruction would push + // multiple values onto the stack, it would be in this list as many times as values it pushed. Note + // that we don't have to do any type checking here because the module has already been validated. + value_stack: std.ArrayList(*IRNode), - // This stack is a record of the nodes to push values onto the stack. If an instruction would push - // multiple values onto the stack, it would be in this list as many times as values it pushed. Note - // that we don't have to do any type checking here because the module has already been validated. - value_stack: std.ArrayList(*IRNode), + // records the current block continuation + // label_continuations: std.ArrayList(u32), - // records the current block continuation - // label_continuations: std.ArrayList(u32), + pending_continuation_edges: std.ArrayList(PendingContinuationEdge), - pending_continuation_edges: std.ArrayList(PendingContinuationEdge), + // when hitting an unconditional control transfer, we need to mark the rest of the stack values as unreachable just like in validation + is_unreachable: bool, - // when hitting an unconditional control transfer, we need to mark the rest of the stack values as unreachable just like in validation - is_unreachable: bool, + // This is a bit weird - since the Local_* instructions serve to just manipulate the locals into the stack, + // we need a way to represent what's in the locals slot as an SSA node. This array lets us do that. We also + // reuse the Local_Get instructions to indicate the "initial value" of the slot. Since our IRNode only stores + // indices to instructions, we'll just lazily set these when they're fetched for the first time. + locals: std.ArrayList(?*IRNode), - // This is a bit weird - since the Local_* instructions serve to just manipulate the locals into the stack, - // we need a way to represent what's in the locals slot as an SSA node. This array lets us do that. We also - // reuse the Local_Get instructions to indicate the "initial value" of the slot. Since our IRNode only stores - // indices to instructions, we'll just lazily set these when they're fetched for the first time. - locals: std.ArrayList(?*IRNode), + // Lets us collapse multiple const IR nodes with the same type/value into a single one + unique_constants: UniqueValueToIRNodeMap, - // Lets us collapse multiple const IR nodes with the same type/value into a single one - unique_constants: UniqueValueToIRNodeMap, - - scratch_node_list_1: std.ArrayList(*IRNode), - scratch_node_list_2: std.ArrayList(*IRNode), - - fn init(allocator: std.mem.Allocator) IntermediateCompileData { - return IntermediateCompileData{ - .allocator = allocator, - // .all_nodes = std.ArrayList(*IRNode).init(allocator), - .blocks = BlockStack.init(allocator), - .value_stack = std.ArrayList(*IRNode).init(allocator), - // .label_continuations = std.ArrayList(u32).init(allocator), - .pending_continuation_edges = std.ArrayList(PendingContinuationEdge).init(allocator), - .is_unreachable = false, - .locals = std.ArrayList(?*IRNode).init(allocator), - .unique_constants = UniqueValueToIRNodeMap.init(allocator), - .scratch_node_list_1 = std.ArrayList(*IRNode).init(allocator), - .scratch_node_list_2 = std.ArrayList(*IRNode).init(allocator), - }; - } + // + register_map: std.AutoHashMap(*const IRNode, u32), - fn warmup(self: *IntermediateCompileData, func_def: FunctionDefinition, module_def: ModuleDefinition) AllocError!void { - try self.locals.appendNTimes(null, func_def.numParamsAndLocals(module_def)); - try self.scratch_node_list_1.ensureTotalCapacity(4096); - try self.scratch_node_list_2.ensureTotalCapacity(4096); - // try self.label_continuations.append(func_def.continuation); - self.is_unreachable = false; - } + scratch_node_list_1: std.ArrayList(*IRNode), + scratch_node_list_2: std.ArrayList(*IRNode), - fn reset(self: *IntermediateCompileData) void { - // self.all_nodes.clearRetainingCapacity(); - self.blocks.reset(); - self.value_stack.clearRetainingCapacity(); - // self.label_continuations.clearRetainingCapacity(); - self.pending_continuation_edges.clearRetainingCapacity(); - self.locals.clearRetainingCapacity(); - self.unique_constants.clearRetainingCapacity(); - self.scratch_node_list_1.clearRetainingCapacity(); - self.scratch_node_list_2.clearRetainingCapacity(); - } + fn init(allocator: std.mem.Allocator) IntermediateCompileData { + return IntermediateCompileData{ + .allocator = allocator, + // .all_nodes = std.ArrayList(*IRNode).init(allocator), + .blocks = BlockStack.init(allocator), + .value_stack = std.ArrayList(*IRNode).init(allocator), + // .label_continuations = std.ArrayList(u32).init(allocator), + .pending_continuation_edges = std.ArrayList(PendingContinuationEdge).init(allocator), + .is_unreachable = false, + .locals = std.ArrayList(?*IRNode).init(allocator), + .unique_constants = UniqueValueToIRNodeMap.init(allocator), + .register_map = std.AutoHashMap(*const IRNode, u32).init(allocator), + .scratch_node_list_1 = std.ArrayList(*IRNode).init(allocator), + .scratch_node_list_2 = std.ArrayList(*IRNode).init(allocator), + }; + } - fn deinit(self: *IntermediateCompileData) void { - // self.all_nodes.deinit(); - self.blocks.deinit(); - self.value_stack.deinit(); - // self.label_continuations.deinit(); - self.pending_continuation_edges.deinit(); - self.locals.deinit(); - self.unique_constants.deinit(); - self.scratch_node_list_1.deinit(); - self.scratch_node_list_2.deinit(); - } + fn warmup(self: *IntermediateCompileData, func_def: FunctionDefinition, module_def: ModuleDefinition) AllocError!void { + try self.locals.appendNTimes(null, func_def.numParamsAndLocals(module_def)); + try self.scratch_node_list_1.ensureTotalCapacity(4096); + try self.scratch_node_list_2.ensureTotalCapacity(4096); + try self.register_map.ensureTotalCapacity(1024); + // try self.label_continuations.append(func_def.continuation); + self.is_unreachable = false; + } - fn popPushValueStackNodes(self: *IntermediateCompileData, node: *IRNode, num_consumed: usize, num_pushed: usize) AllocError!void { - if (self.is_unreachable) { - return; - } + fn reset(self: *IntermediateCompileData) void { + // self.all_nodes.clearRetainingCapacity(); + self.blocks.reset(); + self.value_stack.clearRetainingCapacity(); + // self.label_continuations.clearRetainingCapacity(); + self.pending_continuation_edges.clearRetainingCapacity(); + self.locals.clearRetainingCapacity(); + self.unique_constants.clearRetainingCapacity(); + self.register_map.clearRetainingCapacity(); + self.scratch_node_list_1.clearRetainingCapacity(); + self.scratch_node_list_2.clearRetainingCapacity(); + } - var edges_buffer: [8]*IRNode = undefined; // 8 should be more stack slots than any one instruction can pop - std.debug.assert(num_consumed <= edges_buffer.len); + fn deinit(self: *IntermediateCompileData) void { + // self.all_nodes.deinit(); + self.blocks.deinit(); + self.value_stack.deinit(); + // self.label_continuations.deinit(); + self.pending_continuation_edges.deinit(); + self.locals.deinit(); + self.unique_constants.deinit(); + self.register_map.deinit(); + self.scratch_node_list_1.deinit(); + self.scratch_node_list_2.deinit(); + } - var edges = edges_buffer[0..num_consumed]; - for (edges) |*e| { - e.* = self.value_stack.pop(); - } - try node.pushEdges(.In, edges, self.allocator); - for (edges) |e| { - var consumer_edges = [_]*IRNode{node}; - try e.pushEdges(.Out, &consumer_edges, self.allocator); - } - try self.value_stack.appendNTimes(node, num_pushed); + fn popPushValueStackNodes(self: *IntermediateCompileData, node: *IRNode, num_consumed: usize, num_pushed: usize) AllocError!void { + if (self.is_unreachable) { + return; } - fn foldConstant(self: *IntermediateCompileData, mir: *ModuleIR, comptime valtype: ValType, instruction_index: u32, instruction: Instruction) AllocError!*IRNode { - var val: TaggedVal = undefined; - val.type = valtype; - val.val = switch (valtype) { - .I32 => Val{ .I32 = instruction.immediate.ValueI32 }, - .I64 => Val{ .I64 = instruction.immediate.ValueI64 }, - .F32 => Val{ .F32 = instruction.immediate.ValueF32 }, - .F64 => Val{ .F64 = instruction.immediate.ValueF64 }, - .V128 => Val{ .V128 = instruction.immediate.ValueVec }, - else => @compileError("Unsupported const instruction"), - }; + var edges_buffer: [8]*IRNode = undefined; // 8 should be more stack slots than any one instruction can pop + std.debug.assert(num_consumed <= edges_buffer.len); - var res = try self.unique_constants.getOrPut(val); - if (res.found_existing == false) { - var node = try IRNode.createWithInstruction(mir, instruction_index); - res.value_ptr.* = node; - } - if (self.is_unreachable == false) { - try self.value_stack.append(res.value_ptr.*); - } - return res.value_ptr.*; + var edges = edges_buffer[0..num_consumed]; + for (edges) |*e| { + e.* = self.value_stack.pop(); } - - fn addPendingEdgeLabel(self: *IntermediateCompileData, node: *IRNode, label_id: u32) !void { - const last_block_index = self.blocks.blocks.items.len - 1; - var continuation: u32 = self.blocks.blocks.items[last_block_index - label_id].continuation; - try self.pending_continuation_edges.append(PendingContinuationEdge{ - .node = node, - .continuation = continuation, - }); + try node.pushEdges(.In, edges, self.allocator); + for (edges) |e| { + var consumer_edges = [_]*IRNode{node}; + try e.pushEdges(.Out, &consumer_edges, self.allocator); } + try self.value_stack.appendNTimes(node, num_pushed); + } - fn addPendingEdgeContinuation(self: *IntermediateCompileData, node: *IRNode, continuation: u32) !void { - try self.pending_continuation_edges.append(PendingContinuationEdge{ - .node = node, - .continuation = continuation, - }); + fn foldConstant(self: *IntermediateCompileData, compiler: *FunctionCompiler, comptime valtype: ValType, instruction_index: u32, instruction: Instruction) AllocError!*IRNode { + var val: TaggedVal = undefined; + val.type = valtype; + val.val = switch (valtype) { + .I32 => Val{ .I32 = instruction.immediate.ValueI32 }, + .I64 => Val{ .I64 = instruction.immediate.ValueI64 }, + .F32 => Val{ .F32 = instruction.immediate.ValueF32 }, + .F64 => Val{ .F64 = instruction.immediate.ValueF64 }, + .V128 => Val{ .V128 = instruction.immediate.ValueVec }, + else => @compileError("Unsupported const instruction"), + }; + + var res = try self.unique_constants.getOrPut(val); + if (res.found_existing == false) { + var node = try IRNode.createWithInstruction(compiler, instruction_index); + res.value_ptr.* = node; } - }; + if (self.is_unreachable == false) { + try self.value_stack.append(res.value_ptr.*); + } + return res.value_ptr.*; + } + + fn addPendingEdgeLabel(self: *IntermediateCompileData, node: *IRNode, label_id: u32) !void { + const last_block_index = self.blocks.blocks.items.len - 1; + var continuation: usize = self.blocks.blocks.items[last_block_index - label_id].continuation; + try self.pending_continuation_edges.append(PendingContinuationEdge{ + .node = node, + .continuation = continuation, + }); + } + + fn addPendingEdgeContinuation(self: *IntermediateCompileData, node: *IRNode, continuation: u32) !void { + try self.pending_continuation_edges.append(PendingContinuationEdge{ + .node = node, + .continuation = continuation, + }); + } +}; +const FunctionCompiler = struct { allocator: std.mem.Allocator, module_def: *const ModuleDefinition, - functions: std.ArrayList(IRFunction), ir: StableArray(IRNode), - // instructions: std.ArrayList(RegInstruction), - - fn init(allocator: std.mem.Allocator, module_def: *const ModuleDefinition) ModuleIR { - return ModuleIR{ + fn init(allocator: std.mem.Allocator, module_def: *const ModuleDefinition) FunctionCompiler { + return FunctionCompiler{ .allocator = allocator, .module_def = module_def, - .functions = std.ArrayList(IRFunction).init(allocator), .ir = StableArray(IRNode).init(1024 * 1024 * 8), }; } - fn deinit(mir: *ModuleIR) void { - for (mir.functions.items) |*func| { - func.deinit(); - } - mir.functions.deinit(); - for (mir.ir.items) |node| { - node.deinit(mir.allocator); + fn deinit(compiler: *FunctionCompiler) void { + for (compiler.ir.items) |node| { + node.deinit(compiler.allocator); } - mir.ir.deinit(); + compiler.ir.deinit(); } - fn compile(mir: *ModuleIR) AllocError!void { - var compile_data = IntermediateCompileData.init(mir.allocator); + fn compile(compiler: *FunctionCompiler, store: *FunctionStore) AllocError!void { + var compile_data = IntermediateCompileData.init(compiler.allocator); defer compile_data.deinit(); - for (0..mir.module_def.functions.items.len) |i| { - std.debug.print("mir.module_def.functions.items.len: {}, i: {}\n\n", .{ mir.module_def.functions.items.len, i }); - try mir.compileFunc(i, &compile_data); + // TODO could + for (0..compiler.module_def.functions.items.len) |i| { + std.debug.print("compiler.module_def.functions.items.len: {}, i: {}\n\n", .{ compiler.module_def.functions.items.len, i }); + var function_ir = try compiler.compileFunc(i, &compile_data); + if (function_ir.ir_root != null) { + try function_ir.regalloc(&compile_data, compiler.allocator); + try function_ir.codegen(store, compile_data, compiler.module_def.*, compiler.allocator); + } compile_data.reset(); } } - fn compileFunc(mir: *ModuleIR, index: usize, compile_data: *IntermediateCompileData) AllocError!void { + fn compileFunc(compiler: *FunctionCompiler, index: usize, compile_data: *IntermediateCompileData) AllocError!FunctionIR { const UniqueValueToIRNodeMap = std.HashMap(TaggedVal, *IRNode, TaggedVal.HashMapContext, std.hash_map.default_max_load_percentage); const Helpers = struct { @@ -768,25 +785,25 @@ const ModuleIR = struct { } }; - const func: *const FunctionDefinition = &mir.module_def.functions.items[index]; - const func_type: *const FunctionTypeDefinition = func.typeDefinition(mir.module_def.*); + const func: *const FunctionDefinition = &compiler.module_def.functions.items[index]; + const func_type: *const FunctionTypeDefinition = func.typeDefinition(compiler.module_def.*); std.debug.print("compiling func index {}\n", .{index}); - try compile_data.warmup(func.*, mir.module_def.*); + try compile_data.warmup(func.*, compiler.module_def.*); try compile_data.blocks.pushBlock(func.continuation); var locals = compile_data.locals.items; // for convenience later // Lets us collapse multiple const IR nodes with the same type/value into a single one - var unique_constants = UniqueValueToIRNodeMap.init(mir.allocator); + var unique_constants = UniqueValueToIRNodeMap.init(compiler.allocator); defer unique_constants.deinit(); - const instructions: []Instruction = func.instructions(mir.module_def.*); + const instructions: []Instruction = func.instructions(compiler.module_def.*); if (instructions.len == 0) { std.log.warn("Skipping function with no instructions (index {}).", .{index}); - return; + return FunctionIR{}; } var ir_root: ?*IRNode = null; @@ -796,7 +813,7 @@ const ModuleIR = struct { var node: ?*IRNode = null; if (Helpers.opcodeHasDefaultIRMapping(instruction.opcode)) { - node = try IRNode.createWithInstruction(mir, instruction_index); + node = try IRNode.createWithInstruction(compiler, instruction_index); } std.debug.print("opcode: {}\n", .{instruction.opcode}); @@ -826,7 +843,7 @@ const ModuleIR = struct { std.debug.assert(phi_nodes.items.len == 0); for (0..instruction.immediate.If.num_returns) |_| { - try phi_nodes.append(try IRNode.createPhi(mir)); + try phi_nodes.append(try IRNode.createPhi(compiler)); } try compile_data.blocks.pushBlockWithPhi(instruction.immediate.If.end_continuation, phi_nodes.items[0..]); @@ -859,7 +876,7 @@ const ModuleIR = struct { // the last End opcode returns the values on the stack // if (compile_data.label_continuations.items.len == 1) { if (compile_data.blocks.blocks.items.len == 1) { - node = try IRNode.createStandalone(mir, .Return); + node = try IRNode.createStandalone(compiler, .Return); try compile_data.popPushValueStackNodes(node.?, func_type.getReturns().len, 0); // _ = compile_data.label_continuations.pop(); } @@ -913,7 +930,7 @@ const ModuleIR = struct { // var continuation_edges: std.ArrayList(*IRNode).init(allocator); // defer continuation_edges.deinit(); - const immediates: *const BranchTableImmediates = &mir.module_def.code.branch_table.items[instruction.immediate.Index]; + const immediates: *const BranchTableImmediates = &compiler.module_def.code.branch_table.items[instruction.immediate.Index]; try compile_data.addPendingEdgeLabel(node.?, immediates.fallback_id); for (immediates.label_ids.items) |continuation| { @@ -933,8 +950,8 @@ const ModuleIR = struct { compile_data.is_unreachable = true; }, .Call => { - const calling_func_def: *const FunctionDefinition = &mir.module_def.functions.items[index]; - const calling_func_type: *const FunctionTypeDefinition = calling_func_def.typeDefinition(mir.module_def.*); + const calling_func_def: *const FunctionDefinition = &compiler.module_def.functions.items[index]; + const calling_func_type: *const FunctionTypeDefinition = calling_func_def.typeDefinition(compiler.module_def.*); const num_returns: usize = calling_func_type.getReturns().len; const num_params: usize = calling_func_type.getParams().len; @@ -948,19 +965,19 @@ const ModuleIR = struct { }, .I32_Const => { assert(node == null); - node = try compile_data.foldConstant(mir, .I32, instruction_index, instruction); + node = try compile_data.foldConstant(compiler, .I32, instruction_index, instruction); }, .I64_Const => { assert(node == null); - node = try compile_data.foldConstant(mir, .I64, instruction_index, instruction); + node = try compile_data.foldConstant(compiler, .I64, instruction_index, instruction); }, .F32_Const => { assert(node == null); - node = try compile_data.foldConstant(mir, .F32, instruction_index, instruction); + node = try compile_data.foldConstant(compiler, .F32, instruction_index, instruction); }, .F64_Const => { assert(node == null); - node = try compile_data.foldConstant(mir, .F64, instruction_index, instruction); + node = try compile_data.foldConstant(compiler, .F64, instruction_index, instruction); }, .I32_Eq, .I32_NE, @@ -1011,7 +1028,7 @@ const ModuleIR = struct { if (compile_data.is_unreachable == false) { const local: *?*IRNode = &locals[instruction.immediate.Index]; if (local.* == null) { - local.* = try IRNode.createWithInstruction(mir, instruction_index); + local.* = try IRNode.createWithInstruction(compiler, instruction_index); } node = local.*; try compile_data.value_stack.append(node.?); @@ -1086,29 +1103,252 @@ const ModuleIR = struct { // } // } - try mir.functions.append(IRFunction.init( - @intCast(index), - ir_root.?, - mir.allocator, - )); + return FunctionIR{ + .def_index = index, + .type_def_index = func.type_index, + .ir_root = ir_root, + }; + + // return FunctionIR.init( + // index, + // func.type_index, + // ir_root.?, + // compiler.allocator, + // ); + + // try compiler.functions.append(FunctionIR.init( + // index, + // func.type_index, + // ir_root.?, + // compiler.allocator, + // )); + + // try compiler.functions.items[compiler.functions.items.len - 1].regalloc(compiler.allocator); + } +}; + +const FunctionInstance = struct { + type_def_index: usize, + def_index: usize, + instructions_begin: usize, + instructions_end: usize, + local_types_begin: usize, + local_types_end: usize, + + fn instructions(func: FunctionInstance, store: FunctionStore) []RegInstruction { + return store.instructions.items[func.instructions_begin..func.instructions_end]; + } + + fn localTypes(func: FunctionInstance, store: FunctionStore) []ValType { + return store.local_types.items[func.local_types_begin..func.local_types_end]; + } - try mir.functions.items[mir.functions.items.len - 1].regalloc(mir.allocator); + fn typeDefinition(func: FunctionInstance, module_def: ModuleDefinition) *const FunctionTypeDefinition { + return &module_def.types.items[func.type_def_index]; } + + fn definition(func: FunctionInstance, module_def: ModuleDefinition) *const FunctionDefinition { + return &module_def.functions.items[func.def_index]; + } +}; + +const CompiledFunctions = struct { + local_types: std.ArrayList(ValType), + instructions: std.ArrayList(RegInstruction), + instances: std.ArrayList(FunctionInstance), +}; + +const Label = struct { + // TODO figure out what this struct should be + // num_returns: u32, + continuation: u32, + // start_offset_values: u32, +}; + +const CallFrame = struct { + func: *const FunctionInstance, + module_instance: *ModuleInstance, + num_returns: u32, + registers_begin: u32, // offset into registers + labels_begin: u32, // offset into labels +}; + +const MachineState = struct { + const AllocOpts = struct { + max_registers: usize, + max_labels: usize, + max_frames: usize, + }; + + registers: []Val, + labels: []Label, + frames: []CallFrame, + num_registers: u32, + num_labels: u16, + num_frames: u16, + mem: []u8, + allocator: std.mem.Allocator, + + fn init(allocator: std.mem.Allocator) MachineState { + return MachineState{ + .registers = &[_]Val{}, + .labels = &[_]Label{}, + .frames = &[_]CallFrame{}, + .num_registers = 0, + .num_labels = 0, + .num_frames = 0, + .mem = &[_]u8{}, + .allocator = allocator, + }; + } + + fn deinit(ms: *MachineState) void { + if (ms.mem.len > 0) { + ms.allocator.free(ms.mem); + } + } + + fn allocMemory(ms: *MachineState, opts: AllocOpts) AllocError!void { + const alignment = @max(@alignOf(Val), @alignOf(Label), @alignOf(CallFrame)); + const values_alloc_size = std.mem.alignForward(usize, @as(usize, @intCast(opts.max_registers)) * @sizeOf(Val), alignment); + const labels_alloc_size = std.mem.alignForward(usize, @as(usize, @intCast(opts.max_labels)) * @sizeOf(Label), alignment); + const frames_alloc_size = std.mem.alignForward(usize, @as(usize, @intCast(opts.max_frames)) * @sizeOf(CallFrame), alignment); + const total_alloc_size: usize = values_alloc_size + labels_alloc_size + frames_alloc_size; + + const begin_labels = values_alloc_size; + const begin_frames = values_alloc_size + labels_alloc_size; + + ms.mem = try ms.allocator.alloc(u8, total_alloc_size); + ms.registers.ptr = @as([*]Val, @alignCast(@ptrCast(ms.mem.ptr))); + ms.registers.len = opts.max_registers; + ms.labels.ptr = @as([*]Label, @alignCast(@ptrCast(ms.mem[begin_labels..].ptr))); + ms.labels.len = opts.max_labels; + ms.frames.ptr = @as([*]CallFrame, @alignCast(@ptrCast(ms.mem[begin_frames..].ptr))); + ms.frames.len = opts.max_frames; + } + + fn checkExhausted(ms: MachineState, extra_registers: u32) TrapError!void { + if (ms.num_registers + extra_registers >= ms.registers.len) { + return error.TrapStackExhausted; + } + } + + fn reset(ms: *MachineState) void { + ms.num_registers = 0; + ms.num_labels = 0; + ms.num_frames = 0; + } + + fn get(ms: MachineState, register_local: u32) Val { + var frame: *CallFrame = topFrame(); + var slot = frame.registers_begin + register_local; + return ms.registers[slot]; + } + + fn getI32(ms: MachineState, register_local: u32) i32 { + return ms.get(register_local).I32; + } + + fn getI64(ms: MachineState, register_local: u32) i64 { + return ms.get(register_local).I64; + } + + fn getF32(ms: MachineState, register_local: u32) f32 { + return ms.get(register_local).F32; + } + + fn getF64(ms: MachineState, register_local: u32) f64 { + return ms.get(register_local).F64; + } + + fn set(ms: *MachineState, register_local: u32, val: Val) void { + var frame: *CallFrame = topFrame(); + var slot = frame.registers_begin + register_local; + ms.registers[slot] = val; + } + + fn setI32(ms: *MachineState, register_local: u32, val: i32) void { + var frame: *CallFrame = topFrame(); + var slot = frame.registers_begin + register_local; + ms.registers[slot].I32 = val; + } + + fn setI64(ms: *MachineState, register_local: u32, val: i64) void { + var frame: *CallFrame = topFrame(); + var slot = frame.registers_begin + register_local; + ms.registers[slot].I64 = val; + } + + fn setF32(ms: *MachineState, register_local: u32, val: f32) void { + var frame: *CallFrame = topFrame(); + var slot = frame.registers_begin + register_local; + ms.registers[slot].F32 = val; + } + + fn setF64(ms: *MachineState, register_local: u32, val: f64) void { + var frame: *CallFrame = topFrame(); + var slot = frame.registers_begin + register_local; + ms.registers[slot].F64 = val; + } + + fn topFrame(ms: MachineState) *CallFrame { + return &ms.frames[ms.num_frames - 1]; + } +}; + +const FunctionStore = struct { + local_types: std.ArrayList(ValType), + instructions: std.ArrayList(RegInstruction), + instances: std.ArrayList(FunctionInstance), }; pub const RegisterVM = struct { + functions: FunctionStore, + ms: MachineState, + + fn fromVM(vm: *VM) *RegisterVM { + return @as(*RegisterVM, @alignCast(@ptrCast(vm.impl))); + } + pub fn init(vm: *VM) void { - _ = vm; + var self: *RegisterVM = fromVM(vm); + + self.functions.local_types = std.ArrayList(ValType).init(vm.allocator); + self.functions.instructions = std.ArrayList(RegInstruction).init(vm.allocator); + self.functions.instances = std.ArrayList(FunctionInstance).init(vm.allocator); + self.ms = MachineState.init(vm.allocator); } pub fn deinit(vm: *VM) void { - _ = vm; + var self: *RegisterVM = fromVM(vm); + + self.functions.local_types.deinit(); + self.functions.instructions.deinit(); + self.functions.instances.deinit(); + self.ms.deinit(); } pub fn instantiate(vm: *VM, module: *ModuleInstance, opts: ModuleInstantiateOpts) anyerror!void { - _ = vm; - _ = module; - _ = opts; + var self: *RegisterVM = fromVM(vm); + + const stack_size = if (opts.stack_size > 0) opts.stack_size else 1024 * 128; + const stack_size_f = @as(f64, @floatFromInt(stack_size)); + + try self.ms.allocMemory(.{ + .max_registers = @as(usize, @intFromFloat(stack_size_f * 0.85)), + .max_labels = @as(usize, @intFromFloat(stack_size_f * 0.14)), + .max_frames = @as(usize, @intFromFloat(stack_size_f * 0.01)), + }); + + var compiler = FunctionCompiler.init(vm.allocator, module.module_def); + defer compiler.deinit(); + + try compiler.compile(&self.functions); + + // wasm bytecode -> IR graph -> register-assigned IR graph -> + + // TODO create functions? + return error.Unimplemented; } @@ -1161,25 +1401,9 @@ pub const RegisterVM = struct { } pub fn findFuncTypeDef(vm: *VM, module: *ModuleInstance, local_func_index: usize) *const FunctionTypeDefinition { - _ = vm; - _ = module; - _ = local_func_index; - return &dummy_func_type_def; + var self: *RegisterVM = fromVM(vm); + return self.functions.instances.items[local_func_index].typeDefinition(module.module_def.*); } - - pub fn compile(vm: *RegisterVM, module_def: ModuleDefinition) AllocError!void { - var mir = ModuleIR.init(vm.allocator, module_def); - defer mir.deinit(); - - try mir.compile(); - - // wasm bytecode -> IR graph -> register-assigned IR graph -> - } -}; - -const dummy_func_type_def = FunctionTypeDefinition{ - .types = undefined, - .num_params = 0, }; // register instructions get a slice of the overall set of register slots, which are pointers to actual @@ -1214,10 +1438,10 @@ fn runTestWithViz(wasm_filepath: []const u8, viz_dir: []const u8) !void { try module_def.decode(wasm_data); - var mir = ModuleIR.init(allocator, module_def); - defer mir.deinit(); - try mir.compile(); - for (mir.functions.items, 0..) |func, i| { + var compiler = FunctionCompiler.init(allocator, module_def); + defer compiler.deinit(); + try compiler.compile(); + for (compiler.functions.items, 0..) |func, i| { var viz_path_buffer: [256]u8 = undefined; const viz_path = std.fmt.bufPrint(&viz_path_buffer, "{s}\\viz_{}.txt", .{ viz_dir, i }) catch unreachable; std.debug.print("gen graph for func {}\n", .{i}); @@ -1251,10 +1475,10 @@ fn runTestWithViz(wasm_filepath: []const u8, viz_dir: []const u8) !void { // // try module_def.decode(wasm_data); -// // var mir = ModuleIR.init(allocator, &module_def); -// // defer mir.deinit(); -// // try mir.compile(); -// // for (mir.functions.items, 0..) |func, i| { +// // var compiler = FunctionCompiler.init(allocator, &module_def); +// // defer compiler.deinit(); +// // try compiler.compile(); +// // for (compiler.functions.items, 0..) |func, i| { // // var viz_path_buffer: [256]u8 = undefined; // // const path_format = // // \\E:\Dev\zig_projects\bytebox\viz\viz_{}.txt diff --git a/src/vm_stack.zig b/src/vm_stack.zig index ddeb8e6..d1912ec 100644 --- a/src/vm_stack.zig +++ b/src/vm_stack.zig @@ -83,9 +83,9 @@ const DebugTraceStackVM = struct { }; const FunctionInstance = struct { - type_def_index: u32, - def_index: u32, - instructions_begin: u32, + type_def_index: usize, + def_index: usize, + instructions_begin: usize, local_types: std.ArrayList(ValType), }; @@ -1032,7 +1032,7 @@ const InstructionFuncs = struct { return FuncCallData{ .code = module_instance.module_def.code.instructions.items.ptr, - .continuation = func.instructions_begin, + .continuation = @intCast(func.instructions_begin), }; } @@ -5384,7 +5384,7 @@ pub const StackVM = struct { const name_section: *const NameCustomSection = &frame.module_instance.module_def.name_section; const module_name = name_section.getModuleName(); - const func_name_index: u32 = frame.func.def_index + @as(u32, @intCast(frame.module_instance.module_def.imports.functions.items.len)); + const func_name_index: usize = frame.func.def_index + frame.module_instance.module_def.imports.functions.items.len; const function_name = name_section.findFunctionName(func_name_index); try writer.print("{}: {s}!{s}\n", .{ reverse_index, module_name, function_name }); @@ -5421,11 +5421,11 @@ pub const StackVM = struct { } try self.stack.pushFrame(&func, module, param_types, func.local_types.items, func_type.calcNumReturns()); - try self.stack.pushLabel(@as(u32, @intCast(return_types.len)), func_def.continuation); + try self.stack.pushLabel(@as(u32, @intCast(return_types.len)), @intCast(func_def.continuation)); DebugTrace.traceFunction(module, self.stack.num_frames, func.def_index); - try InstructionFuncs.run(func.instructions_begin, module.module_def.code.instructions.items.ptr, &self.stack); + try InstructionFuncs.run(@intCast(func.instructions_begin), module.module_def.code.instructions.items.ptr, &self.stack); if (returns_slice.len > 0) { var index: i32 = @as(i32, @intCast(returns_slice.len - 1));