@@ -6,7 +6,10 @@ use hash_ast::{
6
6
ast, ast_visitor_default_impl,
7
7
visitor:: { walk, AstVisitor } ,
8
8
} ;
9
- use hash_types:: new:: environment:: { context:: ScopeKind , env:: AccessToEnv } ;
9
+ use hash_types:: new:: {
10
+ environment:: { context:: ScopeKind , env:: AccessToEnv } ,
11
+ scopes:: StackMemberId ,
12
+ } ;
10
13
11
14
use super :: ast_pass:: AstPass ;
12
15
use crate :: {
@@ -48,6 +51,66 @@ impl<'tc> SymbolResolutionPass<'tc> {
48
51
}
49
52
}
50
53
54
+ impl < ' tc > SymbolResolutionPass < ' tc > {
55
+ /// Run a function for each stack member in the given pattern.
56
+ ///
57
+ /// The stack members are found in the `AstInfo` store, specifically the
58
+ /// `stack_members` map. They are looked up using the IDs of the pattern
59
+ /// binds, as added by the `add_stack_members_in_pat_to_buf` method of the
60
+ /// `ScopeDiscoveryPass`.
61
+ fn for_each_stack_member_of_pat (
62
+ & self ,
63
+ node : ast:: AstNodeRef < ast:: Pat > ,
64
+ f : impl Fn ( StackMemberId ) + Copy ,
65
+ ) {
66
+ match node. body ( ) {
67
+ ast:: Pat :: Binding ( _) => {
68
+ if let Some ( member_id) = self . ast_info ( ) . stack_members ( ) . get_data_by_node ( node. id ( ) )
69
+ {
70
+ f ( member_id) ;
71
+ }
72
+ }
73
+ ast:: Pat :: Tuple ( tuple_pat) => {
74
+ for entry in tuple_pat. fields . ast_ref_iter ( ) {
75
+ self . for_each_stack_member_of_pat ( entry. pat . ast_ref ( ) , f) ;
76
+ }
77
+ }
78
+ ast:: Pat :: Constructor ( constructor_pat) => {
79
+ for field in constructor_pat. fields . ast_ref_iter ( ) {
80
+ self . for_each_stack_member_of_pat ( field. pat . ast_ref ( ) , f) ;
81
+ }
82
+ }
83
+ ast:: Pat :: List ( list_pat) => {
84
+ for pat in list_pat. fields . ast_ref_iter ( ) {
85
+ self . for_each_stack_member_of_pat ( pat, f) ;
86
+ }
87
+ }
88
+ ast:: Pat :: Or ( or_pat) => {
89
+ if let Some ( pat) = or_pat. variants . get ( 0 ) {
90
+ self . for_each_stack_member_of_pat ( pat. ast_ref ( ) , f)
91
+ }
92
+ }
93
+ ast:: Pat :: Spread ( spread_pat) => {
94
+ if let Some ( name) = spread_pat. name . as_ref ( ) {
95
+ if let Some ( member_id) =
96
+ self . ast_info ( ) . stack_members ( ) . get_data_by_node ( name. ast_ref ( ) . id ( ) )
97
+ {
98
+ f ( member_id) ;
99
+ }
100
+ }
101
+ }
102
+ ast:: Pat :: If ( if_pat) => self . for_each_stack_member_of_pat ( if_pat. pat . ast_ref ( ) , f) ,
103
+ ast:: Pat :: Wild ( _) => {
104
+ if let Some ( member_id) = self . ast_info ( ) . stack_members ( ) . get_data_by_node ( node. id ( ) )
105
+ {
106
+ f ( member_id) ;
107
+ }
108
+ }
109
+ ast:: Pat :: Module ( _) | ast:: Pat :: Access ( _) | ast:: Pat :: Lit ( _) | ast:: Pat :: Range ( _) => { }
110
+ }
111
+ }
112
+ }
113
+
51
114
/// @@Temporary: for now this visitor just walks the AST and enters scopes. The
52
115
/// next step is to resolve symbols in these scopes!.
53
116
impl ast:: AstVisitor for SymbolResolutionPass < ' _ > {
@@ -61,6 +124,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
61
124
FnDef ,
62
125
TyFnDef ,
63
126
BodyBlock ,
127
+ MatchCase ,
64
128
) ;
65
129
66
130
type ModuleRet = ( ) ;
@@ -157,12 +221,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
157
221
node : ast:: AstNodeRef < ast:: Declaration > ,
158
222
) -> Result < Self :: DeclarationRet , Self :: Error > {
159
223
// If we are in a stack, then we need to add the declaration to the
160
- // stack's scope.
224
+ // stack's scope. Otherwise the declaration is handled higher up.
161
225
if let ScopeKind :: Stack ( _) = self . context ( ) . get_scope_kind ( ) {
162
- let member = self . ast_info ( ) . stack_members ( ) . get_data_by_node ( node. pat . id ( ) ) . unwrap ( ) ;
163
- self . context_ops ( ) . add_stack_binding ( member) ;
226
+ self . for_each_stack_member_of_pat ( node. pat . ast_ref ( ) , |member| {
227
+ self . context_ops ( ) . add_stack_binding ( member) ;
228
+ } ) ;
164
229
}
165
230
walk:: walk_declaration ( self , node) ?;
166
231
Ok ( ( ) )
167
232
}
233
+
234
+ type MatchCaseRet = ( ) ;
235
+ fn visit_match_case (
236
+ & self ,
237
+ node : ast:: AstNodeRef < ast:: MatchCase > ,
238
+ ) -> Result < Self :: MatchCaseRet , Self :: Error > {
239
+ let stack_id = self . ast_info ( ) . stacks ( ) . get_data_by_node ( node. id ( ) ) . unwrap ( ) ;
240
+ // Each match case has its own scope, so we need to enter it, and add all the
241
+ // pattern bindings to the context.
242
+ self . context_ops ( ) . enter_scope ( ScopeKind :: Stack ( stack_id) , || {
243
+ self . for_each_stack_member_of_pat ( node. pat . ast_ref ( ) , |member| {
244
+ self . context_ops ( ) . add_stack_binding ( member) ;
245
+ } ) ;
246
+ walk:: walk_match_case ( self , node) ?;
247
+ Ok ( ( ) )
248
+ } )
249
+ }
168
250
}
0 commit comments