Skip to content

Commit e442db5

Browse files
committed
Perform scope discovery on match cases too
1 parent 01bc31d commit e442db5

File tree

2 files changed

+120
-13
lines changed

2 files changed

+120
-13
lines changed

compiler/hash-typecheck/src/new/passes/scope_discovery.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,13 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
510510
}
511511
}
512512

513-
/// Add a declaration node `a := b` to the given `stack_id` (which is
513+
/// Add a pattern node to the given `stack_id` (which is
514514
/// "current").
515515
///
516-
/// This adds the declaration as a set of stack members, taking into account
517-
/// all of the pattern bindings. It adds a set of tuples `(AstNodeId,
518-
/// StackMemberData)`, one for each binding, where the `AstNodeId` is
519-
/// the `AstNodeId` of the binding pattern node.
520-
fn add_declaration_node_to_stack(&self, node: AstNodeRef<ast::Declaration>, stack_id: StackId) {
516+
/// This adds the pattern binds as a set of stack members. It adds a set of
517+
/// tuples `(AstNodeId, StackMemberData)`, one for each binding, where
518+
/// the `AstNodeId` is the `AstNodeId` of the binding pattern node.
519+
fn add_pat_node_binds_to_stack(&self, node: AstNodeRef<ast::Pat>, stack_id: StackId) {
521520
self.stack_members.modify_fast(stack_id, |members| {
522521
let members = match members {
523522
Some(members) => members,
@@ -528,7 +527,7 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
528527

529528
// Add each stack member to the stack_members vector
530529
let mut found_members = smallvec![];
531-
self.add_stack_members_in_pat_to_buf(node.pat.ast_ref(), &mut found_members);
530+
self.add_stack_members_in_pat_to_buf(node, &mut found_members);
532531
for (node_id, stack_member) in found_members {
533532
members.push((node_id, stack_member));
534533
}
@@ -548,7 +547,8 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
548547
FnDef,
549548
TyFnDef,
550549
BodyBlock,
551-
MergeDeclaration
550+
MergeDeclaration,
551+
MatchCase
552552
);
553553

554554
type DeclarationRet = ();
@@ -581,7 +581,7 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
581581
}
582582
DefId::Stack(stack_id) => {
583583
walk_with_name_hint()?;
584-
self.add_declaration_node_to_stack(node, stack_id)
584+
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id)
585585
}
586586
DefId::Fn(_) => {
587587
panic_on_span!(
@@ -595,6 +595,31 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
595595
Ok(())
596596
}
597597

598+
type MatchCaseRet = ();
599+
fn visit_match_case(
600+
&self,
601+
node: AstNodeRef<ast::MatchCase>,
602+
) -> Result<Self::MatchCaseRet, Self::Error> {
603+
match self.get_current_def() {
604+
DefId::Stack(_) => {
605+
// A match case creates its own stack scope.
606+
let stack_id = self.stack_ops().create_stack();
607+
self.enter_def(node, stack_id, || {
608+
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id);
609+
walk::walk_match_case(self, node)
610+
})?;
611+
Ok(())
612+
}
613+
_ => {
614+
panic_on_span!(
615+
self.node_location(node),
616+
self.source_map(),
617+
"found match in non-stack scope"
618+
)
619+
}
620+
}
621+
}
622+
598623
type ModuleRet = ();
599624
fn visit_module(
600625
&self,

compiler/hash-typecheck/src/new/passes/symbol_resolution.rs

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use hash_ast::{
66
ast, ast_visitor_default_impl,
77
visitor::{walk, AstVisitor},
88
};
9-
use hash_types::new::environment::{context::ScopeKind, env::AccessToEnv};
9+
use hash_types::new::{
10+
environment::{context::ScopeKind, env::AccessToEnv},
11+
scopes::StackMemberId,
12+
};
1013

1114
use super::ast_pass::AstPass;
1215
use crate::{
@@ -48,6 +51,66 @@ impl<'tc> SymbolResolutionPass<'tc> {
4851
}
4952
}
5053

54+
impl<'tc> SymbolResolutionPass<'tc> {
55+
/// Run a function for each stack member in the given pattern.
56+
///
57+
/// The stack members are found in the `AstInfo` store, specifically the
58+
/// `stack_members` map. They are looked up using the IDs of the pattern
59+
/// binds, as added by the `add_stack_members_in_pat_to_buf` method of the
60+
/// `ScopeDiscoveryPass`.
61+
fn for_each_stack_member_of_pat(
62+
&self,
63+
node: ast::AstNodeRef<ast::Pat>,
64+
f: impl Fn(StackMemberId) + Copy,
65+
) {
66+
match node.body() {
67+
ast::Pat::Binding(_) => {
68+
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
69+
{
70+
f(member_id);
71+
}
72+
}
73+
ast::Pat::Tuple(tuple_pat) => {
74+
for entry in tuple_pat.fields.ast_ref_iter() {
75+
self.for_each_stack_member_of_pat(entry.pat.ast_ref(), f);
76+
}
77+
}
78+
ast::Pat::Constructor(constructor_pat) => {
79+
for field in constructor_pat.fields.ast_ref_iter() {
80+
self.for_each_stack_member_of_pat(field.pat.ast_ref(), f);
81+
}
82+
}
83+
ast::Pat::List(list_pat) => {
84+
for pat in list_pat.fields.ast_ref_iter() {
85+
self.for_each_stack_member_of_pat(pat, f);
86+
}
87+
}
88+
ast::Pat::Or(or_pat) => {
89+
if let Some(pat) = or_pat.variants.get(0) {
90+
self.for_each_stack_member_of_pat(pat.ast_ref(), f)
91+
}
92+
}
93+
ast::Pat::Spread(spread_pat) => {
94+
if let Some(name) = spread_pat.name.as_ref() {
95+
if let Some(member_id) =
96+
self.ast_info().stack_members().get_data_by_node(name.ast_ref().id())
97+
{
98+
f(member_id);
99+
}
100+
}
101+
}
102+
ast::Pat::If(if_pat) => self.for_each_stack_member_of_pat(if_pat.pat.ast_ref(), f),
103+
ast::Pat::Wild(_) => {
104+
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
105+
{
106+
f(member_id);
107+
}
108+
}
109+
ast::Pat::Module(_) | ast::Pat::Access(_) | ast::Pat::Lit(_) | ast::Pat::Range(_) => {}
110+
}
111+
}
112+
}
113+
51114
/// @@Temporary: for now this visitor just walks the AST and enters scopes. The
52115
/// next step is to resolve symbols in these scopes!.
53116
impl ast::AstVisitor for SymbolResolutionPass<'_> {
@@ -61,6 +124,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
61124
FnDef,
62125
TyFnDef,
63126
BodyBlock,
127+
MatchCase,
64128
);
65129

66130
type ModuleRet = ();
@@ -157,12 +221,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
157221
node: ast::AstNodeRef<ast::Declaration>,
158222
) -> Result<Self::DeclarationRet, Self::Error> {
159223
// If we are in a stack, then we need to add the declaration to the
160-
// stack's scope.
224+
// stack's scope. Otherwise the declaration is handled higher up.
161225
if let ScopeKind::Stack(_) = self.context().get_scope_kind() {
162-
let member = self.ast_info().stack_members().get_data_by_node(node.pat.id()).unwrap();
163-
self.context_ops().add_stack_binding(member);
226+
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
227+
self.context_ops().add_stack_binding(member);
228+
});
164229
}
165230
walk::walk_declaration(self, node)?;
166231
Ok(())
167232
}
233+
234+
type MatchCaseRet = ();
235+
fn visit_match_case(
236+
&self,
237+
node: ast::AstNodeRef<ast::MatchCase>,
238+
) -> Result<Self::MatchCaseRet, Self::Error> {
239+
let stack_id = self.ast_info().stacks().get_data_by_node(node.id()).unwrap();
240+
// Each match case has its own scope, so we need to enter it, and add all the
241+
// pattern bindings to the context.
242+
self.context_ops().enter_scope(ScopeKind::Stack(stack_id), || {
243+
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
244+
self.context_ops().add_stack_binding(member);
245+
});
246+
walk::walk_match_case(self, node)?;
247+
Ok(())
248+
})
249+
}
168250
}

0 commit comments

Comments
 (0)