@@ -6,7 +6,10 @@ use hash_ast::{
6
6
ast, ast_visitor_default_impl,
7
7
visitor:: { walk, AstVisitor } ,
8
8
} ;
9
- use hash_types:: new:: environment:: { context:: ScopeKind , env:: AccessToEnv } ;
9
+ use hash_types:: new:: {
10
+ environment:: { context:: ScopeKind , env:: AccessToEnv } ,
11
+ scopes:: StackMemberId ,
12
+ } ;
10
13
11
14
use super :: ast_pass:: AstPass ;
12
15
use crate :: {
@@ -48,6 +51,75 @@ impl<'tc> SymbolResolutionPass<'tc> {
48
51
}
49
52
}
50
53
54
+ impl < ' tc > SymbolResolutionPass < ' tc > {
55
+ /// Run a function for each stack member in the given pattern.
56
+ ///
57
+ /// The stack members are found in the `AstInfo` store, specifically the
58
+ /// `stack_members` map. They are looked up using the IDs of the pattern
59
+ /// binds, as added by the `add_stack_members_in_pat_to_buf` method of the
60
+ /// `ScopeDiscoveryPass`.
61
+ fn for_each_stack_member_of_pat (
62
+ & self ,
63
+ node : ast:: AstNodeRef < ast:: Pat > ,
64
+ f : impl Fn ( StackMemberId ) + Copy ,
65
+ ) {
66
+ let for_spread_pat = |spread : & ast:: AstNode < ast:: SpreadPat > | {
67
+ if let Some ( name) = & spread. name {
68
+ if let Some ( member_id) =
69
+ self . ast_info ( ) . stack_members ( ) . get_data_by_node ( name. ast_ref ( ) . id ( ) )
70
+ {
71
+ f ( member_id) ;
72
+ }
73
+ }
74
+ } ;
75
+ match node. body ( ) {
76
+ ast:: Pat :: Binding ( _) => {
77
+ if let Some ( member_id) = self . ast_info ( ) . stack_members ( ) . get_data_by_node ( node. id ( ) )
78
+ {
79
+ f ( member_id) ;
80
+ }
81
+ }
82
+ ast:: Pat :: Tuple ( tuple_pat) => {
83
+ for ( index, entry) in tuple_pat. fields . ast_ref_iter ( ) . enumerate ( ) {
84
+ if let Some ( spread_node) = & tuple_pat. spread && spread_node. position == index {
85
+ for_spread_pat ( spread_node) ;
86
+ }
87
+ self . for_each_stack_member_of_pat ( entry. pat . ast_ref ( ) , f) ;
88
+ }
89
+ }
90
+ ast:: Pat :: Constructor ( constructor_pat) => {
91
+ for ( index, field) in constructor_pat. fields . ast_ref_iter ( ) . enumerate ( ) {
92
+ if let Some ( spread_node) = & constructor_pat. spread && spread_node. position == index {
93
+ for_spread_pat ( spread_node) ;
94
+ }
95
+ self . for_each_stack_member_of_pat ( field. pat . ast_ref ( ) , f) ;
96
+ }
97
+ }
98
+ ast:: Pat :: List ( list_pat) => {
99
+ for ( index, pat) in list_pat. fields . ast_ref_iter ( ) . enumerate ( ) {
100
+ if let Some ( spread_node) = & list_pat. spread && spread_node. position == index {
101
+ for_spread_pat ( spread_node) ;
102
+ }
103
+ self . for_each_stack_member_of_pat ( pat, f) ;
104
+ }
105
+ }
106
+ ast:: Pat :: Or ( or_pat) => {
107
+ if let Some ( pat) = or_pat. variants . get ( 0 ) {
108
+ self . for_each_stack_member_of_pat ( pat. ast_ref ( ) , f)
109
+ }
110
+ }
111
+ ast:: Pat :: If ( if_pat) => self . for_each_stack_member_of_pat ( if_pat. pat . ast_ref ( ) , f) ,
112
+ ast:: Pat :: Wild ( _) => {
113
+ if let Some ( member_id) = self . ast_info ( ) . stack_members ( ) . get_data_by_node ( node. id ( ) )
114
+ {
115
+ f ( member_id) ;
116
+ }
117
+ }
118
+ ast:: Pat :: Module ( _) | ast:: Pat :: Access ( _) | ast:: Pat :: Lit ( _) | ast:: Pat :: Range ( _) => { }
119
+ }
120
+ }
121
+ }
122
+
51
123
/// @@Temporary: for now this visitor just walks the AST and enters scopes. The
52
124
/// next step is to resolve symbols in these scopes!.
53
125
impl ast:: AstVisitor for SymbolResolutionPass < ' _ > {
@@ -61,6 +133,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
61
133
FnDef ,
62
134
TyFnDef ,
63
135
BodyBlock ,
136
+ MatchCase ,
64
137
) ;
65
138
66
139
type ModuleRet = ( ) ;
@@ -157,12 +230,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
157
230
node : ast:: AstNodeRef < ast:: Declaration > ,
158
231
) -> Result < Self :: DeclarationRet , Self :: Error > {
159
232
// If we are in a stack, then we need to add the declaration to the
160
- // stack's scope.
233
+ // stack's scope. Otherwise the declaration is handled higher up.
161
234
if let ScopeKind :: Stack ( _) = self . context ( ) . get_scope_kind ( ) {
162
- let member = self . ast_info ( ) . stack_members ( ) . get_data_by_node ( node. pat . id ( ) ) . unwrap ( ) ;
163
- self . context_ops ( ) . add_stack_binding ( member) ;
235
+ self . for_each_stack_member_of_pat ( node. pat . ast_ref ( ) , |member| {
236
+ self . context_ops ( ) . add_stack_binding ( member) ;
237
+ } ) ;
164
238
}
165
239
walk:: walk_declaration ( self , node) ?;
166
240
Ok ( ( ) )
167
241
}
242
+
243
+ type MatchCaseRet = ( ) ;
244
+ fn visit_match_case (
245
+ & self ,
246
+ node : ast:: AstNodeRef < ast:: MatchCase > ,
247
+ ) -> Result < Self :: MatchCaseRet , Self :: Error > {
248
+ let stack_id = self . ast_info ( ) . stacks ( ) . get_data_by_node ( node. id ( ) ) . unwrap ( ) ;
249
+ // Each match case has its own scope, so we need to enter it, and add all the
250
+ // pattern bindings to the context.
251
+ self . context_ops ( ) . enter_scope ( ScopeKind :: Stack ( stack_id) , || {
252
+ self . for_each_stack_member_of_pat ( node. pat . ast_ref ( ) , |member| {
253
+ self . context_ops ( ) . add_stack_binding ( member) ;
254
+ } ) ;
255
+ walk:: walk_match_case ( self , node) ?;
256
+ Ok ( ( ) )
257
+ } )
258
+ }
168
259
}
0 commit comments