Skip to content

Commit 92a434a

Browse files
committed
Add support for nested functions
We missed that stmts in rust can be items like functions. This adds support for resolution and compilation of nested functions. Rust allows nested functions which are distinct to closures. Nested functions are not allowed to encapsulate the enclosing scope so they can be extracted as normal functions.
1 parent 23e748d commit 92a434a

11 files changed

+299
-74
lines changed

gcc/rust/backend/rust-compile-base.h

+3
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ class HIRCompileBase : public HIR::HIRVisitor
210210
void compile_function_body (Bfunction *fndecl,
211211
std::unique_ptr<HIR::BlockExpr> &function_body,
212212
bool has_return_type);
213+
214+
bool compile_locals_for_block (Resolver::Rib &rib, Bfunction *fndecl,
215+
std::vector<Bvariable *> &locals);
213216
};
214217

215218
} // namespace Compile

gcc/rust/backend/rust-compile-implitem.h

+6-38
Original file line numberDiff line numberDiff line change
@@ -183,26 +183,10 @@ class CompileInherentImplItem : public HIRCompileBase
183183
}
184184

185185
std::vector<Bvariable *> locals;
186-
rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
187-
Resolver::Definition d;
188-
bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
189-
rust_assert (ok);
190-
191-
HIR::Stmt *decl = nullptr;
192-
ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
193-
rust_assert (ok);
194-
195-
Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
196-
locals.push_back (compiled);
197-
198-
return true;
199-
});
200-
201-
bool toplevel_item
202-
= function.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID;
203-
Bblock *enclosing_scope
204-
= toplevel_item ? NULL : ctx->peek_enclosing_scope ();
186+
bool ok = compile_locals_for_block (*rib, fndecl, locals);
187+
rust_assert (ok);
205188

189+
Bblock *enclosing_scope = NULL;
206190
HIR::BlockExpr *function_body = function.get_definition ().get ();
207191
Location start_location = function_body->get_locus ();
208192
Location end_location = function_body->get_closing_locus ();
@@ -409,26 +393,10 @@ class CompileInherentImplItem : public HIRCompileBase
409393
}
410394

411395
std::vector<Bvariable *> locals;
412-
rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
413-
Resolver::Definition d;
414-
bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
415-
rust_assert (ok);
416-
417-
HIR::Stmt *decl = nullptr;
418-
ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
419-
rust_assert (ok);
420-
421-
Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
422-
locals.push_back (compiled);
423-
424-
return true;
425-
});
426-
427-
bool toplevel_item
428-
= method.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID;
429-
Bblock *enclosing_scope
430-
= toplevel_item ? NULL : ctx->peek_enclosing_scope ();
396+
bool ok = compile_locals_for_block (*rib, fndecl, locals);
397+
rust_assert (ok);
431398

399+
Bblock *enclosing_scope = NULL;
432400
HIR::BlockExpr *function_body = method.get_function_body ().get ();
433401
Location start_location = function_body->get_locus ();
434402
Location end_location = function_body->get_closing_locus ();

gcc/rust/backend/rust-compile-item.h

+3-19
Original file line numberDiff line numberDiff line change
@@ -213,26 +213,10 @@ class CompileItem : public HIRCompileBase
213213
}
214214

215215
std::vector<Bvariable *> locals;
216-
rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
217-
Resolver::Definition d;
218-
bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
219-
rust_assert (ok);
220-
221-
HIR::Stmt *decl = nullptr;
222-
ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
223-
rust_assert (ok);
224-
225-
Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
226-
locals.push_back (compiled);
227-
228-
return true;
229-
});
230-
231-
bool toplevel_item
232-
= function.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID;
233-
Bblock *enclosing_scope
234-
= toplevel_item ? NULL : ctx->peek_enclosing_scope ();
216+
bool ok = compile_locals_for_block (*rib, fndecl, locals);
217+
rust_assert (ok);
235218

219+
Bblock *enclosing_scope = NULL;
236220
HIR::BlockExpr *function_body = function.get_definition ().get ();
237221
Location start_location = function_body->get_locus ();
238222
Location end_location = function_body->get_closing_locus ();

gcc/rust/backend/rust-compile.cc

+37-14
Original file line numberDiff line numberDiff line change
@@ -212,20 +212,8 @@ CompileBlock::visit (HIR::BlockExpr &expr)
212212
}
213213

214214
std::vector<Bvariable *> locals;
215-
rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
216-
Resolver::Definition d;
217-
bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
218-
rust_assert (ok);
219-
220-
HIR::Stmt *decl = nullptr;
221-
ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
222-
rust_assert (ok);
223-
224-
Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
225-
locals.push_back (compiled);
226-
227-
return true;
228-
});
215+
bool ok = compile_locals_for_block (*rib, fndecl, locals);
216+
rust_assert (ok);
229217

230218
Bblock *enclosing_scope = ctx->peek_enclosing_scope ();
231219
Bblock *new_block
@@ -415,6 +403,41 @@ HIRCompileBase::compile_function_body (
415403
}
416404
}
417405

406+
bool
407+
HIRCompileBase::compile_locals_for_block (Resolver::Rib &rib, Bfunction *fndecl,
408+
std::vector<Bvariable *> &locals)
409+
{
410+
rib.iterate_decls ([&] (NodeId n, Location) mutable -> bool {
411+
Resolver::Definition d;
412+
bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
413+
rust_assert (ok);
414+
415+
HIR::Stmt *decl = nullptr;
416+
ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
417+
rust_assert (ok);
418+
419+
// if its a function we extract this out side of this fn context
420+
// and it is not a local to this function
421+
bool is_item = ctx->get_mappings ()->lookup_hir_item (
422+
decl->get_mappings ().get_crate_num (),
423+
decl->get_mappings ().get_hirid ())
424+
!= nullptr;
425+
if (is_item)
426+
{
427+
HIR::Item *item = static_cast<HIR::Item *> (decl);
428+
CompileItem::compile (item, ctx, true);
429+
return true;
430+
}
431+
432+
Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
433+
locals.push_back (compiled);
434+
435+
return true;
436+
});
437+
438+
return true;
439+
}
440+
418441
// Mr Mangle time
419442

420443
static const std::string kMangledSymbolPrefix = "_ZN";

gcc/rust/hir/rust-ast-lower-stmt.h

+85
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,91 @@ class ASTLoweringStmt : public ASTLoweringBase
230230
empty.get_locus ());
231231
}
232232

233+
void visit (AST::Function &function) override
234+
{
235+
// ignore for now and leave empty
236+
std::vector<std::unique_ptr<HIR::WhereClauseItem> > where_clause_items;
237+
HIR::WhereClause where_clause (std::move (where_clause_items));
238+
HIR::FunctionQualifiers qualifiers (
239+
HIR::FunctionQualifiers::AsyncConstStatus::NONE, false);
240+
HIR::Visibility vis = HIR::Visibility::create_public ();
241+
242+
// need
243+
std::vector<std::unique_ptr<HIR::GenericParam> > generic_params;
244+
if (function.has_generics ())
245+
{
246+
generic_params = lower_generic_params (function.get_generic_params ());
247+
}
248+
249+
Identifier function_name = function.get_function_name ();
250+
Location locus = function.get_locus ();
251+
252+
std::unique_ptr<HIR::Type> return_type
253+
= function.has_return_type () ? std::unique_ptr<HIR::Type> (
254+
ASTLoweringType::translate (function.get_return_type ().get ()))
255+
: nullptr;
256+
257+
std::vector<HIR::FunctionParam> function_params;
258+
for (auto &param : function.get_function_params ())
259+
{
260+
auto translated_pattern = std::unique_ptr<HIR::Pattern> (
261+
ASTLoweringPattern::translate (param.get_pattern ().get ()));
262+
auto translated_type = std::unique_ptr<HIR::Type> (
263+
ASTLoweringType::translate (param.get_type ().get ()));
264+
265+
auto crate_num = mappings->get_current_crate ();
266+
Analysis::NodeMapping mapping (crate_num, param.get_node_id (),
267+
mappings->get_next_hir_id (crate_num),
268+
UNKNOWN_LOCAL_DEFID);
269+
270+
auto hir_param
271+
= HIR::FunctionParam (mapping, std::move (translated_pattern),
272+
std::move (translated_type),
273+
param.get_locus ());
274+
function_params.push_back (hir_param);
275+
}
276+
277+
bool terminated = false;
278+
std::unique_ptr<HIR::BlockExpr> function_body
279+
= std::unique_ptr<HIR::BlockExpr> (
280+
ASTLoweringBlock::translate (function.get_definition ().get (),
281+
&terminated));
282+
283+
auto crate_num = mappings->get_current_crate ();
284+
Analysis::NodeMapping mapping (crate_num, function.get_node_id (),
285+
mappings->get_next_hir_id (crate_num),
286+
UNKNOWN_LOCAL_DEFID);
287+
288+
mappings->insert_location (crate_num,
289+
function_body->get_mappings ().get_hirid (),
290+
function.get_locus ());
291+
292+
auto fn
293+
= new HIR::Function (mapping, std::move (function_name),
294+
std::move (qualifiers), std::move (generic_params),
295+
std::move (function_params), std::move (return_type),
296+
std::move (where_clause), std::move (function_body),
297+
std::move (vis), function.get_outer_attrs (), locus);
298+
299+
mappings->insert_hir_item (mapping.get_crate_num (), mapping.get_hirid (),
300+
fn);
301+
mappings->insert_hir_stmt (mapping.get_crate_num (), mapping.get_hirid (),
302+
fn);
303+
mappings->insert_location (crate_num, mapping.get_hirid (),
304+
function.get_locus ());
305+
306+
// add the mappings for the function params at the end
307+
for (auto &param : fn->get_function_params ())
308+
{
309+
mappings->insert_hir_param (mapping.get_crate_num (),
310+
param.get_mappings ().get_hirid (), &param);
311+
mappings->insert_location (crate_num, mapping.get_hirid (),
312+
param.get_locus ());
313+
}
314+
315+
translated = fn;
316+
}
317+
233318
private:
234319
ASTLoweringStmt () : translated (nullptr), terminated (false) {}
235320

gcc/rust/resolve/rust-ast-resolve-stmt.h

+55
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,61 @@ class ResolveStmt : public ResolverBase
129129
resolver->get_type_scope ().pop ();
130130
}
131131

132+
void visit (AST::Function &function) override
133+
{
134+
auto path = ResolveFunctionItemToCanonicalPath::resolve (function);
135+
resolver->get_name_scope ().insert (
136+
path, function.get_node_id (), function.get_locus (), false,
137+
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
138+
RichLocation r (function.get_locus ());
139+
r.add_range (locus);
140+
rust_error_at (r, "redefined multiple times");
141+
});
142+
resolver->insert_new_definition (function.get_node_id (),
143+
Definition{function.get_node_id (),
144+
function.get_node_id ()});
145+
146+
NodeId scope_node_id = function.get_node_id ();
147+
resolver->get_name_scope ().push (scope_node_id);
148+
resolver->get_type_scope ().push (scope_node_id);
149+
resolver->get_label_scope ().push (scope_node_id);
150+
resolver->push_new_name_rib (resolver->get_name_scope ().peek ());
151+
resolver->push_new_type_rib (resolver->get_type_scope ().peek ());
152+
resolver->push_new_label_rib (resolver->get_type_scope ().peek ());
153+
154+
if (function.has_generics ())
155+
{
156+
for (auto &generic : function.get_generic_params ())
157+
ResolveGenericParam::go (generic.get (), function.get_node_id ());
158+
}
159+
160+
if (function.has_return_type ())
161+
ResolveType::go (function.get_return_type ().get (),
162+
function.get_node_id ());
163+
164+
// we make a new scope so the names of parameters are resolved and shadowed
165+
// correctly
166+
for (auto &param : function.get_function_params ())
167+
{
168+
ResolveType::go (param.get_type ().get (), param.get_node_id ());
169+
PatternDeclaration::go (param.get_pattern ().get (),
170+
param.get_node_id ());
171+
172+
// the mutability checker needs to verify for immutable decls the number
173+
// of assignments are <1. This marks an implicit assignment
174+
resolver->mark_assignment_to_decl (param.get_pattern ()->get_node_id (),
175+
param.get_node_id ());
176+
}
177+
178+
// resolve the function body
179+
ResolveExpr::go (function.get_definition ().get (),
180+
function.get_node_id ());
181+
182+
resolver->get_name_scope ().pop ();
183+
resolver->get_type_scope ().pop ();
184+
resolver->get_label_scope ().pop ();
185+
}
186+
132187
private:
133188
ResolveStmt (NodeId parent) : ResolverBase (parent) {}
134189
};

gcc/rust/resolve/rust-ast-resolve.cc

+2-3
Original file line numberDiff line numberDiff line change
@@ -499,9 +499,8 @@ ResolvePath::resolve_path (AST::PathInExpression *expr)
499499
else
500500
{
501501
rust_error_at (expr->get_locus (),
502-
"unknown root segment in path %s lookup %s",
503-
expr->as_string ().c_str (),
504-
root_ident_seg.as_string ().c_str ());
502+
"Cannot find path %<%s%> in this scope",
503+
expr->as_string ().c_str ());
505504
return;
506505
}
507506

0 commit comments

Comments
 (0)