Skip to content

Commit 889a0b0

Browse files
committed
Perform scope discovery on match cases too
1 parent 7550e13 commit 889a0b0

File tree

2 files changed

+120
-13
lines changed

2 files changed

+120
-13
lines changed

compiler/hash-typecheck/src/new/passes/scope_discovery.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -495,14 +495,13 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
495495
}
496496
}
497497

498-
/// Add a declaration node `a := b` to the given `stack_id` (which is
498+
/// Add a pattern node to the given `stack_id` (which is
499499
/// "current").
500500
///
501-
/// This adds the declaration as a set of stack members, taking into account
502-
/// all of the pattern bindings. It adds a set of tuples `(AstNodeId,
503-
/// StackMemberData)`, one for each binding, where the `AstNodeId` is
504-
/// the `AstNodeId` of the binding pattern node.
505-
fn add_declaration_node_to_stack(&self, node: AstNodeRef<ast::Declaration>, stack_id: StackId) {
501+
/// This adds the pattern binds as a set of stack members. It adds a set of
502+
/// tuples `(AstNodeId, StackMemberData)`, one for each binding, where
503+
/// the `AstNodeId` is the `AstNodeId` of the binding pattern node.
504+
fn add_pat_node_binds_to_stack(&self, node: AstNodeRef<ast::Pat>, stack_id: StackId) {
506505
self.stack_members.modify_fast(stack_id, |members| {
507506
let members = match members {
508507
Some(members) => members,
@@ -513,7 +512,7 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
513512

514513
// Add each stack member to the stack_members vector
515514
let mut found_members = smallvec![];
516-
self.add_stack_members_in_pat_to_buf(node.pat.ast_ref(), &mut found_members);
515+
self.add_stack_members_in_pat_to_buf(node, &mut found_members);
517516
for (node_id, stack_member) in found_members {
518517
members.push((node_id, stack_member));
519518
}
@@ -533,7 +532,8 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
533532
FnDef,
534533
TyFnDef,
535534
BodyBlock,
536-
MergeDeclaration
535+
MergeDeclaration,
536+
MatchCase
537537
);
538538

539539
type DeclarationRet = ();
@@ -566,7 +566,7 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
566566
}
567567
DefId::Stack(stack_id) => {
568568
walk_with_name_hint()?;
569-
self.add_declaration_node_to_stack(node, stack_id)
569+
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id)
570570
}
571571
DefId::Fn(_) => {
572572
panic_on_span!(
@@ -580,6 +580,31 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
580580
Ok(())
581581
}
582582

583+
type MatchCaseRet = ();
584+
fn visit_match_case(
585+
&self,
586+
node: AstNodeRef<ast::MatchCase>,
587+
) -> Result<Self::MatchCaseRet, Self::Error> {
588+
match self.get_current_def() {
589+
DefId::Stack(_) => {
590+
// A match case creates its own stack scope.
591+
let stack_id = self.stack_ops().create_stack();
592+
self.enter_def(node, stack_id, || {
593+
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id);
594+
walk::walk_match_case(self, node)
595+
})?;
596+
Ok(())
597+
}
598+
_ => {
599+
panic_on_span!(
600+
self.node_location(node),
601+
self.source_map(),
602+
"found match in non-stack scope"
603+
)
604+
}
605+
}
606+
}
607+
583608
type ModuleRet = ();
584609
fn visit_module(
585610
&self,

compiler/hash-typecheck/src/new/passes/symbol_resolution.rs

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use hash_ast::{
66
ast, ast_visitor_default_impl,
77
visitor::{walk, AstVisitor},
88
};
9-
use hash_types::new::environment::{context::ScopeKind, env::AccessToEnv};
9+
use hash_types::new::{
10+
environment::{context::ScopeKind, env::AccessToEnv},
11+
scopes::StackMemberId,
12+
};
1013

1114
use super::ast_pass::AstPass;
1215
use crate::{
@@ -48,6 +51,66 @@ impl<'tc> SymbolResolutionPass<'tc> {
4851
}
4952
}
5053

54+
impl<'tc> SymbolResolutionPass<'tc> {
55+
/// Run a function for each stack member in the given pattern.
56+
///
57+
/// The stack members are found in the `AstInfo` store, specifically the
58+
/// `stack_members` map. They are looked up using the IDs of the pattern
59+
/// binds, as added by the `add_stack_members_in_pat_to_buf` method of the
60+
/// `ScopeDiscoveryPass`.
61+
fn for_each_stack_member_of_pat(
62+
&self,
63+
node: ast::AstNodeRef<ast::Pat>,
64+
f: impl Fn(StackMemberId) + Copy,
65+
) {
66+
match node.body() {
67+
ast::Pat::Binding(_) => {
68+
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
69+
{
70+
f(member_id);
71+
}
72+
}
73+
ast::Pat::Tuple(tuple_pat) => {
74+
for entry in tuple_pat.fields.ast_ref_iter() {
75+
self.for_each_stack_member_of_pat(entry.pat.ast_ref(), f);
76+
}
77+
}
78+
ast::Pat::Constructor(constructor_pat) => {
79+
for field in constructor_pat.fields.ast_ref_iter() {
80+
self.for_each_stack_member_of_pat(field.pat.ast_ref(), f);
81+
}
82+
}
83+
ast::Pat::List(list_pat) => {
84+
for pat in list_pat.fields.ast_ref_iter() {
85+
self.for_each_stack_member_of_pat(pat, f);
86+
}
87+
}
88+
ast::Pat::Or(or_pat) => {
89+
if let Some(pat) = or_pat.variants.get(0) {
90+
self.for_each_stack_member_of_pat(pat.ast_ref(), f)
91+
}
92+
}
93+
ast::Pat::Spread(spread_pat) => {
94+
if let Some(name) = spread_pat.name.as_ref() {
95+
if let Some(member_id) =
96+
self.ast_info().stack_members().get_data_by_node(name.ast_ref().id())
97+
{
98+
f(member_id);
99+
}
100+
}
101+
}
102+
ast::Pat::If(if_pat) => self.for_each_stack_member_of_pat(if_pat.pat.ast_ref(), f),
103+
ast::Pat::Wild(_) => {
104+
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
105+
{
106+
f(member_id);
107+
}
108+
}
109+
ast::Pat::Module(_) | ast::Pat::Access(_) | ast::Pat::Lit(_) | ast::Pat::Range(_) => {}
110+
}
111+
}
112+
}
113+
51114
/// @@Temporary: for now this visitor just walks the AST and enters scopes. The
52115
/// next step is to resolve symbols in these scopes!.
53116
impl ast::AstVisitor for SymbolResolutionPass<'_> {
@@ -61,6 +124,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
61124
FnDef,
62125
TyFnDef,
63126
BodyBlock,
127+
MatchCase,
64128
);
65129

66130
type ModuleRet = ();
@@ -157,12 +221,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
157221
node: ast::AstNodeRef<ast::Declaration>,
158222
) -> Result<Self::DeclarationRet, Self::Error> {
159223
// If we are in a stack, then we need to add the declaration to the
160-
// stack's scope.
224+
// stack's scope. Otherwise the declaration is handled higher up.
161225
if let ScopeKind::Stack(_) = self.context().get_scope_kind() {
162-
let member = self.ast_info().stack_members().get_data_by_node(node.pat.id()).unwrap();
163-
self.context_ops().add_stack_binding(member);
226+
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
227+
self.context_ops().add_stack_binding(member);
228+
});
164229
}
165230
walk::walk_declaration(self, node)?;
166231
Ok(())
167232
}
233+
234+
type MatchCaseRet = ();
235+
fn visit_match_case(
236+
&self,
237+
node: ast::AstNodeRef<ast::MatchCase>,
238+
) -> Result<Self::MatchCaseRet, Self::Error> {
239+
let stack_id = self.ast_info().stacks().get_data_by_node(node.id()).unwrap();
240+
// Each match case has its own scope, so we need to enter it, and add all the
241+
// pattern bindings to the context.
242+
self.context_ops().enter_scope(ScopeKind::Stack(stack_id), || {
243+
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
244+
self.context_ops().add_stack_binding(member);
245+
});
246+
walk::walk_match_case(self, node)?;
247+
Ok(())
248+
})
249+
}
168250
}

0 commit comments

Comments
 (0)