diff --git a/compiler/hash-typecheck/src/new/passes/scope_discovery.rs b/compiler/hash-typecheck/src/new/passes/scope_discovery.rs index 6ca5d813d..19ec1105d 100644 --- a/compiler/hash-typecheck/src/new/passes/scope_discovery.rs +++ b/compiler/hash-typecheck/src/new/passes/scope_discovery.rs @@ -510,14 +510,13 @@ impl<'tc> ScopeDiscoveryPass<'tc> { } } - /// Add a declaration node `a := b` to the given `stack_id` (which is + /// Add a pattern node to the given `stack_id` (which is /// "current"). /// - /// This adds the declaration as a set of stack members, taking into account - /// all of the pattern bindings. It adds a set of tuples `(AstNodeId, - /// StackMemberData)`, one for each binding, where the `AstNodeId` is - /// the `AstNodeId` of the binding pattern node. - fn add_declaration_node_to_stack(&self, node: AstNodeRef, stack_id: StackId) { + /// This adds the pattern binds as a set of stack members. It adds a set of + /// tuples `(AstNodeId, StackMemberData)`, one for each binding, where + /// the `AstNodeId` is the `AstNodeId` of the binding pattern node. + fn add_pat_node_binds_to_stack(&self, node: AstNodeRef, stack_id: StackId) { self.stack_members.modify_fast(stack_id, |members| { let members = match members { Some(members) => members, @@ -528,7 +527,7 @@ impl<'tc> ScopeDiscoveryPass<'tc> { // Add each stack member to the stack_members vector let mut found_members = smallvec![]; - self.add_stack_members_in_pat_to_buf(node.pat.ast_ref(), &mut found_members); + self.add_stack_members_in_pat_to_buf(node, &mut found_members); for (node_id, stack_member) in found_members { members.push((node_id, stack_member)); } @@ -548,7 +547,8 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> { FnDef, TyFnDef, BodyBlock, - MergeDeclaration + MergeDeclaration, + MatchCase ); type DeclarationRet = (); @@ -581,7 +581,7 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> { } DefId::Stack(stack_id) => { walk_with_name_hint()?; - self.add_declaration_node_to_stack(node, stack_id) + self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id) } DefId::Fn(_) => { panic_on_span!( @@ -595,6 +595,31 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> { Ok(()) } + type MatchCaseRet = (); + fn visit_match_case( + &self, + node: AstNodeRef, + ) -> Result { + match self.get_current_def() { + DefId::Stack(_) => { + // A match case creates its own stack scope. + let stack_id = self.stack_ops().create_stack(); + self.enter_def(node, stack_id, || { + self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id); + walk::walk_match_case(self, node) + })?; + Ok(()) + } + _ => { + panic_on_span!( + self.node_location(node), + self.source_map(), + "found match in non-stack scope" + ) + } + } + } + type ModuleRet = (); fn visit_module( &self, diff --git a/compiler/hash-typecheck/src/new/passes/symbol_resolution.rs b/compiler/hash-typecheck/src/new/passes/symbol_resolution.rs index b056140ee..1a37e2f37 100644 --- a/compiler/hash-typecheck/src/new/passes/symbol_resolution.rs +++ b/compiler/hash-typecheck/src/new/passes/symbol_resolution.rs @@ -6,7 +6,10 @@ use hash_ast::{ ast, ast_visitor_default_impl, visitor::{walk, AstVisitor}, }; -use hash_types::new::environment::{context::ScopeKind, env::AccessToEnv}; +use hash_types::new::{ + environment::{context::ScopeKind, env::AccessToEnv}, + scopes::StackMemberId, +}; use super::ast_pass::AstPass; use crate::{ @@ -48,6 +51,75 @@ impl<'tc> SymbolResolutionPass<'tc> { } } +impl<'tc> SymbolResolutionPass<'tc> { + /// Run a function for each stack member in the given pattern. + /// + /// The stack members are found in the `AstInfo` store, specifically the + /// `stack_members` map. They are looked up using the IDs of the pattern + /// binds, as added by the `add_stack_members_in_pat_to_buf` method of the + /// `ScopeDiscoveryPass`. + fn for_each_stack_member_of_pat( + &self, + node: ast::AstNodeRef, + f: impl Fn(StackMemberId) + Copy, + ) { + let for_spread_pat = |spread: &ast::AstNode| { + if let Some(name) = &spread.name { + if let Some(member_id) = + self.ast_info().stack_members().get_data_by_node(name.ast_ref().id()) + { + f(member_id); + } + } + }; + match node.body() { + ast::Pat::Binding(_) => { + if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id()) + { + f(member_id); + } + } + ast::Pat::Tuple(tuple_pat) => { + for (index, entry) in tuple_pat.fields.ast_ref_iter().enumerate() { + if let Some(spread_node) = &tuple_pat.spread && spread_node.position == index { + for_spread_pat(spread_node); + } + self.for_each_stack_member_of_pat(entry.pat.ast_ref(), f); + } + } + ast::Pat::Constructor(constructor_pat) => { + for (index, field) in constructor_pat.fields.ast_ref_iter().enumerate() { + if let Some(spread_node) = &constructor_pat.spread && spread_node.position == index { + for_spread_pat(spread_node); + } + self.for_each_stack_member_of_pat(field.pat.ast_ref(), f); + } + } + ast::Pat::List(list_pat) => { + for (index, pat) in list_pat.fields.ast_ref_iter().enumerate() { + if let Some(spread_node) = &list_pat.spread && spread_node.position == index { + for_spread_pat(spread_node); + } + self.for_each_stack_member_of_pat(pat, f); + } + } + ast::Pat::Or(or_pat) => { + if let Some(pat) = or_pat.variants.get(0) { + self.for_each_stack_member_of_pat(pat.ast_ref(), f) + } + } + ast::Pat::If(if_pat) => self.for_each_stack_member_of_pat(if_pat.pat.ast_ref(), f), + ast::Pat::Wild(_) => { + if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id()) + { + f(member_id); + } + } + ast::Pat::Module(_) | ast::Pat::Access(_) | ast::Pat::Lit(_) | ast::Pat::Range(_) => {} + } + } +} + /// @@Temporary: for now this visitor just walks the AST and enters scopes. The /// next step is to resolve symbols in these scopes!. impl ast::AstVisitor for SymbolResolutionPass<'_> { @@ -61,6 +133,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> { FnDef, TyFnDef, BodyBlock, + MatchCase, ); type ModuleRet = (); @@ -157,12 +230,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> { node: ast::AstNodeRef, ) -> Result { // If we are in a stack, then we need to add the declaration to the - // stack's scope. + // stack's scope. Otherwise the declaration is handled higher up. if let ScopeKind::Stack(_) = self.context().get_scope_kind() { - let member = self.ast_info().stack_members().get_data_by_node(node.pat.id()).unwrap(); - self.context_ops().add_stack_binding(member); + self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| { + self.context_ops().add_stack_binding(member); + }); } walk::walk_declaration(self, node)?; Ok(()) } + + type MatchCaseRet = (); + fn visit_match_case( + &self, + node: ast::AstNodeRef, + ) -> Result { + let stack_id = self.ast_info().stacks().get_data_by_node(node.id()).unwrap(); + // Each match case has its own scope, so we need to enter it, and add all the + // pattern bindings to the context. + self.context_ops().enter_scope(ScopeKind::Stack(stack_id), || { + self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| { + self.context_ops().add_stack_binding(member); + }); + walk::walk_match_case(self, node)?; + Ok(()) + }) + } }