Skip to content

Commit d748cf5

Browse files
Merge pull request #639 from hash-org/fix-637
Fix: Perform scope discovery on match cases too
2 parents 01bc31d + 00a5a69 commit d748cf5

File tree

2 files changed

+129
-13
lines changed

2 files changed

+129
-13
lines changed

compiler/hash-typecheck/src/new/passes/scope_discovery.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,13 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
510510
}
511511
}
512512

513-
/// Add a declaration node `a := b` to the given `stack_id` (which is
513+
/// Add a pattern node to the given `stack_id` (which is
514514
/// "current").
515515
///
516-
/// This adds the declaration as a set of stack members, taking into account
517-
/// all of the pattern bindings. It adds a set of tuples `(AstNodeId,
518-
/// StackMemberData)`, one for each binding, where the `AstNodeId` is
519-
/// the `AstNodeId` of the binding pattern node.
520-
fn add_declaration_node_to_stack(&self, node: AstNodeRef<ast::Declaration>, stack_id: StackId) {
516+
/// This adds the pattern binds as a set of stack members. It adds a set of
517+
/// tuples `(AstNodeId, StackMemberData)`, one for each binding, where
518+
/// the `AstNodeId` is the `AstNodeId` of the binding pattern node.
519+
fn add_pat_node_binds_to_stack(&self, node: AstNodeRef<ast::Pat>, stack_id: StackId) {
521520
self.stack_members.modify_fast(stack_id, |members| {
522521
let members = match members {
523522
Some(members) => members,
@@ -528,7 +527,7 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
528527

529528
// Add each stack member to the stack_members vector
530529
let mut found_members = smallvec![];
531-
self.add_stack_members_in_pat_to_buf(node.pat.ast_ref(), &mut found_members);
530+
self.add_stack_members_in_pat_to_buf(node, &mut found_members);
532531
for (node_id, stack_member) in found_members {
533532
members.push((node_id, stack_member));
534533
}
@@ -548,7 +547,8 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
548547
FnDef,
549548
TyFnDef,
550549
BodyBlock,
551-
MergeDeclaration
550+
MergeDeclaration,
551+
MatchCase
552552
);
553553

554554
type DeclarationRet = ();
@@ -581,7 +581,7 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
581581
}
582582
DefId::Stack(stack_id) => {
583583
walk_with_name_hint()?;
584-
self.add_declaration_node_to_stack(node, stack_id)
584+
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id)
585585
}
586586
DefId::Fn(_) => {
587587
panic_on_span!(
@@ -595,6 +595,31 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
595595
Ok(())
596596
}
597597

598+
type MatchCaseRet = ();
599+
fn visit_match_case(
600+
&self,
601+
node: AstNodeRef<ast::MatchCase>,
602+
) -> Result<Self::MatchCaseRet, Self::Error> {
603+
match self.get_current_def() {
604+
DefId::Stack(_) => {
605+
// A match case creates its own stack scope.
606+
let stack_id = self.stack_ops().create_stack();
607+
self.enter_def(node, stack_id, || {
608+
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id);
609+
walk::walk_match_case(self, node)
610+
})?;
611+
Ok(())
612+
}
613+
_ => {
614+
panic_on_span!(
615+
self.node_location(node),
616+
self.source_map(),
617+
"found match in non-stack scope"
618+
)
619+
}
620+
}
621+
}
622+
598623
type ModuleRet = ();
599624
fn visit_module(
600625
&self,

compiler/hash-typecheck/src/new/passes/symbol_resolution.rs

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use hash_ast::{
66
ast, ast_visitor_default_impl,
77
visitor::{walk, AstVisitor},
88
};
9-
use hash_types::new::environment::{context::ScopeKind, env::AccessToEnv};
9+
use hash_types::new::{
10+
environment::{context::ScopeKind, env::AccessToEnv},
11+
scopes::StackMemberId,
12+
};
1013

1114
use super::ast_pass::AstPass;
1215
use crate::{
@@ -48,6 +51,75 @@ impl<'tc> SymbolResolutionPass<'tc> {
4851
}
4952
}
5053

54+
impl<'tc> SymbolResolutionPass<'tc> {
55+
/// Run a function for each stack member in the given pattern.
56+
///
57+
/// The stack members are found in the `AstInfo` store, specifically the
58+
/// `stack_members` map. They are looked up using the IDs of the pattern
59+
/// binds, as added by the `add_stack_members_in_pat_to_buf` method of the
60+
/// `ScopeDiscoveryPass`.
61+
fn for_each_stack_member_of_pat(
62+
&self,
63+
node: ast::AstNodeRef<ast::Pat>,
64+
f: impl Fn(StackMemberId) + Copy,
65+
) {
66+
let for_spread_pat = |spread: &ast::AstNode<ast::SpreadPat>| {
67+
if let Some(name) = &spread.name {
68+
if let Some(member_id) =
69+
self.ast_info().stack_members().get_data_by_node(name.ast_ref().id())
70+
{
71+
f(member_id);
72+
}
73+
}
74+
};
75+
match node.body() {
76+
ast::Pat::Binding(_) => {
77+
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
78+
{
79+
f(member_id);
80+
}
81+
}
82+
ast::Pat::Tuple(tuple_pat) => {
83+
for (index, entry) in tuple_pat.fields.ast_ref_iter().enumerate() {
84+
if let Some(spread_node) = &tuple_pat.spread && spread_node.position == index {
85+
for_spread_pat(spread_node);
86+
}
87+
self.for_each_stack_member_of_pat(entry.pat.ast_ref(), f);
88+
}
89+
}
90+
ast::Pat::Constructor(constructor_pat) => {
91+
for (index, field) in constructor_pat.fields.ast_ref_iter().enumerate() {
92+
if let Some(spread_node) = &constructor_pat.spread && spread_node.position == index {
93+
for_spread_pat(spread_node);
94+
}
95+
self.for_each_stack_member_of_pat(field.pat.ast_ref(), f);
96+
}
97+
}
98+
ast::Pat::List(list_pat) => {
99+
for (index, pat) in list_pat.fields.ast_ref_iter().enumerate() {
100+
if let Some(spread_node) = &list_pat.spread && spread_node.position == index {
101+
for_spread_pat(spread_node);
102+
}
103+
self.for_each_stack_member_of_pat(pat, f);
104+
}
105+
}
106+
ast::Pat::Or(or_pat) => {
107+
if let Some(pat) = or_pat.variants.get(0) {
108+
self.for_each_stack_member_of_pat(pat.ast_ref(), f)
109+
}
110+
}
111+
ast::Pat::If(if_pat) => self.for_each_stack_member_of_pat(if_pat.pat.ast_ref(), f),
112+
ast::Pat::Wild(_) => {
113+
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
114+
{
115+
f(member_id);
116+
}
117+
}
118+
ast::Pat::Module(_) | ast::Pat::Access(_) | ast::Pat::Lit(_) | ast::Pat::Range(_) => {}
119+
}
120+
}
121+
}
122+
51123
/// @@Temporary: for now this visitor just walks the AST and enters scopes. The
52124
/// next step is to resolve symbols in these scopes!.
53125
impl ast::AstVisitor for SymbolResolutionPass<'_> {
@@ -61,6 +133,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
61133
FnDef,
62134
TyFnDef,
63135
BodyBlock,
136+
MatchCase,
64137
);
65138

66139
type ModuleRet = ();
@@ -157,12 +230,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
157230
node: ast::AstNodeRef<ast::Declaration>,
158231
) -> Result<Self::DeclarationRet, Self::Error> {
159232
// If we are in a stack, then we need to add the declaration to the
160-
// stack's scope.
233+
// stack's scope. Otherwise the declaration is handled higher up.
161234
if let ScopeKind::Stack(_) = self.context().get_scope_kind() {
162-
let member = self.ast_info().stack_members().get_data_by_node(node.pat.id()).unwrap();
163-
self.context_ops().add_stack_binding(member);
235+
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
236+
self.context_ops().add_stack_binding(member);
237+
});
164238
}
165239
walk::walk_declaration(self, node)?;
166240
Ok(())
167241
}
242+
243+
type MatchCaseRet = ();
244+
fn visit_match_case(
245+
&self,
246+
node: ast::AstNodeRef<ast::MatchCase>,
247+
) -> Result<Self::MatchCaseRet, Self::Error> {
248+
let stack_id = self.ast_info().stacks().get_data_by_node(node.id()).unwrap();
249+
// Each match case has its own scope, so we need to enter it, and add all the
250+
// pattern bindings to the context.
251+
self.context_ops().enter_scope(ScopeKind::Stack(stack_id), || {
252+
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
253+
self.context_ops().add_stack_binding(member);
254+
});
255+
walk::walk_match_case(self, node)?;
256+
Ok(())
257+
})
258+
}
168259
}

0 commit comments

Comments
 (0)