Skip to content

Commit 5d5396d

Browse files
committed
Refactor operator overloading code into cc file
1 parent 3b3079e commit 5d5396d

File tree

2 files changed

+187
-185
lines changed

2 files changed

+187
-185
lines changed

gcc/rust/typecheck/rust-hir-type-check-expr.cc

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,5 +286,191 @@ TypeCheckExpr::visit (HIR::ArrayIndexExpr &expr)
286286
infered = array_type->get_element_type ()->clone ();
287287
}
288288

289+
bool
290+
TypeCheckExpr::resolve_operator_overload (
291+
Analysis::RustLangItem::ItemType lang_item_type, HIR::OperatorExpr &expr,
292+
TyTy::BaseType *lhs, TyTy::BaseType *rhs)
293+
{
294+
// look up lang item for arithmetic type
295+
std::string associated_item_name
296+
= Analysis::RustLangItem::ToString (lang_item_type);
297+
DefId respective_lang_item_id = UNKNOWN_DEFID;
298+
bool lang_item_defined
299+
= mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id);
300+
301+
// probe for the lang-item
302+
if (!lang_item_defined)
303+
return false;
304+
305+
auto segment = HIR::PathIdentSegment (associated_item_name);
306+
auto candidate
307+
= MethodResolver::Probe (lhs, HIR::PathIdentSegment (associated_item_name));
308+
309+
bool have_implementation_for_lang_item = !candidate.is_error ();
310+
if (!have_implementation_for_lang_item)
311+
return false;
312+
313+
// Get the adjusted self
314+
Adjuster adj (lhs);
315+
TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);
316+
317+
// is this the case we are recursive
318+
// handle the case where we are within the impl block for this lang_item
319+
// otherwise we end up with a recursive operator overload such as the i32
320+
// operator overload trait
321+
TypeCheckContextItem &fn_context = context->peek_context ();
322+
if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM)
323+
{
324+
auto &impl_item = fn_context.get_impl_item ();
325+
HIR::ImplBlock *parent = impl_item.first;
326+
HIR::Function *fn = impl_item.second;
327+
328+
if (parent->has_trait_ref ()
329+
&& fn->get_function_name ().compare (associated_item_name) == 0)
330+
{
331+
TraitReference *trait_reference
332+
= TraitResolver::Lookup (*parent->get_trait_ref ().get ());
333+
if (!trait_reference->is_error ())
334+
{
335+
TyTy::BaseType *lookup = nullptr;
336+
bool ok = context->lookup_type (fn->get_mappings ().get_hirid (),
337+
&lookup);
338+
rust_assert (ok);
339+
rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF);
340+
341+
TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup);
342+
rust_assert (fntype->is_method ());
343+
344+
bool is_lang_item_impl
345+
= trait_reference->get_mappings ().get_defid ()
346+
== respective_lang_item_id;
347+
bool self_is_lang_item_self
348+
= fntype->get_self_type ()->is_equal (*adjusted_self);
349+
bool recursive_operator_overload
350+
= is_lang_item_impl && self_is_lang_item_self;
351+
352+
if (recursive_operator_overload)
353+
return false;
354+
}
355+
}
356+
}
357+
358+
// store the adjustments for code-generation to know what to do
359+
context->insert_autoderef_mappings (expr.get_mappings ().get_hirid (),
360+
std::move (candidate.adjustments));
361+
362+
// now its just like a method-call-expr
363+
context->insert_receiver (expr.get_mappings ().get_hirid (), lhs);
364+
365+
PathProbeCandidate &resolved_candidate = candidate.candidate;
366+
TyTy::BaseType *lookup_tyty = candidate.candidate.ty;
367+
NodeId resolved_node_id
368+
= resolved_candidate.is_impl_candidate ()
369+
? resolved_candidate.item.impl.impl_item->get_impl_mappings ()
370+
.get_nodeid ()
371+
: resolved_candidate.item.trait.item_ref->get_mappings ().get_nodeid ();
372+
373+
rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF);
374+
TyTy::BaseType *lookup = lookup_tyty;
375+
TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
376+
rust_assert (fn->is_method ());
377+
378+
auto root = lhs->get_root ();
379+
bool receiver_is_type_param = root->get_kind () == TyTy::TypeKind::PARAM;
380+
if (root->get_kind () == TyTy::TypeKind::ADT)
381+
{
382+
const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (root);
383+
if (adt->has_substitutions () && fn->needs_substitution ())
384+
{
385+
// consider the case where we have:
386+
//
387+
// struct Foo<X,Y>(X,Y);
388+
//
389+
// impl<T> Foo<T, i32> {
390+
// fn test<X>(self, a:X) -> (T,X) { (self.0, a) }
391+
// }
392+
//
393+
// In this case we end up with an fn type of:
394+
//
395+
// fn <T,X> test(self:Foo<T,i32>, a:X) -> (T,X)
396+
//
397+
// This means the instance or self we are calling this method for
398+
// will be substituted such that we can get the inherited type
399+
// arguments but then need to use the turbo fish if available or
400+
// infer the remaining arguments. Luckily rust does not allow for
401+
// default types GenericParams on impl blocks since these must
402+
// always be at the end of the list
403+
404+
auto s = fn->get_self_type ()->get_root ();
405+
rust_assert (s->can_eq (adt, false));
406+
rust_assert (s->get_kind () == TyTy::TypeKind::ADT);
407+
const TyTy::ADTType *self_adt
408+
= static_cast<const TyTy::ADTType *> (s);
409+
410+
// we need to grab the Self substitutions as the inherit type
411+
// parameters for this
412+
if (self_adt->needs_substitution ())
413+
{
414+
rust_assert (adt->was_substituted ());
415+
416+
TyTy::SubstitutionArgumentMappings used_args_in_prev_segment
417+
= GetUsedSubstArgs::From (adt);
418+
419+
TyTy::SubstitutionArgumentMappings inherit_type_args
420+
= self_adt->solve_mappings_from_receiver_for_self (
421+
used_args_in_prev_segment);
422+
423+
// there may or may not be inherited type arguments
424+
if (!inherit_type_args.is_error ())
425+
{
426+
// need to apply the inherited type arguments to the
427+
// function
428+
lookup = fn->handle_substitions (inherit_type_args);
429+
}
430+
}
431+
}
432+
}
433+
434+
// handle generics
435+
if (!receiver_is_type_param)
436+
{
437+
if (lookup->needs_generic_substitutions ())
438+
{
439+
lookup = SubstMapper::InferSubst (lookup, expr.get_locus ());
440+
}
441+
}
442+
443+
// type check the arguments if required
444+
TyTy::FnType *type = static_cast<TyTy::FnType *> (lookup);
445+
rust_assert (type->num_params () > 0);
446+
auto fnparam = type->param_at (0);
447+
fnparam.second->unify (adjusted_self); // typecheck the self
448+
if (rhs == nullptr)
449+
{
450+
rust_assert (type->num_params () == 1);
451+
}
452+
else
453+
{
454+
rust_assert (type->num_params () == 2);
455+
auto fnparam = type->param_at (1);
456+
fnparam.second->unify (rhs); // typecheck the rhs
457+
}
458+
459+
// get the return type
460+
TyTy::BaseType *function_ret_tyty = type->get_return_type ()->clone ();
461+
462+
// store the expected fntype
463+
context->insert_operator_overload (expr.get_mappings ().get_hirid (), type);
464+
465+
// set up the resolved name on the path
466+
resolver->insert_resolved_name (expr.get_mappings ().get_nodeid (),
467+
resolved_node_id);
468+
469+
// return the result of the function back
470+
infered = function_ret_tyty;
471+
472+
return true;
473+
}
474+
289475
} // namespace Resolver
290476
} // namespace Rust

gcc/rust/typecheck/rust-hir-type-check-expr.h

Lines changed: 1 addition & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,191 +1240,7 @@ class TypeCheckExpr : public TypeCheckBase
12401240
bool
12411241
resolve_operator_overload (Analysis::RustLangItem::ItemType lang_item_type,
12421242
HIR::OperatorExpr &expr, TyTy::BaseType *lhs,
1243-
TyTy::BaseType *rhs)
1244-
{
1245-
// look up lang item for arithmetic type
1246-
std::string associated_item_name
1247-
= Analysis::RustLangItem::ToString (lang_item_type);
1248-
DefId respective_lang_item_id = UNKNOWN_DEFID;
1249-
bool lang_item_defined
1250-
= mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id);
1251-
1252-
// probe for the lang-item
1253-
if (!lang_item_defined)
1254-
return false;
1255-
1256-
auto segment = HIR::PathIdentSegment (associated_item_name);
1257-
auto candidate
1258-
= MethodResolver::Probe (lhs,
1259-
HIR::PathIdentSegment (associated_item_name));
1260-
1261-
bool have_implementation_for_lang_item = !candidate.is_error ();
1262-
if (!have_implementation_for_lang_item)
1263-
return false;
1264-
1265-
// Get the adjusted self
1266-
Adjuster adj (lhs);
1267-
TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);
1268-
1269-
// is this the case we are recursive
1270-
// handle the case where we are within the impl block for this lang_item
1271-
// otherwise we end up with a recursive operator overload such as the i32
1272-
// operator overload trait
1273-
TypeCheckContextItem &fn_context = context->peek_context ();
1274-
if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM)
1275-
{
1276-
auto &impl_item = fn_context.get_impl_item ();
1277-
HIR::ImplBlock *parent = impl_item.first;
1278-
HIR::Function *fn = impl_item.second;
1279-
1280-
if (parent->has_trait_ref ()
1281-
&& fn->get_function_name ().compare (associated_item_name) == 0)
1282-
{
1283-
TraitReference *trait_reference
1284-
= TraitResolver::Lookup (*parent->get_trait_ref ().get ());
1285-
if (!trait_reference->is_error ())
1286-
{
1287-
TyTy::BaseType *lookup = nullptr;
1288-
bool ok
1289-
= context->lookup_type (fn->get_mappings ().get_hirid (),
1290-
&lookup);
1291-
rust_assert (ok);
1292-
rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF);
1293-
1294-
TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup);
1295-
rust_assert (fntype->is_method ());
1296-
1297-
bool is_lang_item_impl
1298-
= trait_reference->get_mappings ().get_defid ()
1299-
== respective_lang_item_id;
1300-
bool self_is_lang_item_self
1301-
= fntype->get_self_type ()->is_equal (*adjusted_self);
1302-
bool recursive_operator_overload
1303-
= is_lang_item_impl && self_is_lang_item_self;
1304-
1305-
if (recursive_operator_overload)
1306-
return false;
1307-
}
1308-
}
1309-
}
1310-
1311-
// store the adjustments for code-generation to know what to do
1312-
context->insert_autoderef_mappings (expr.get_mappings ().get_hirid (),
1313-
std::move (candidate.adjustments));
1314-
1315-
// now its just like a method-call-expr
1316-
context->insert_receiver (expr.get_mappings ().get_hirid (), lhs);
1317-
1318-
PathProbeCandidate &resolved_candidate = candidate.candidate;
1319-
TyTy::BaseType *lookup_tyty = candidate.candidate.ty;
1320-
NodeId resolved_node_id
1321-
= resolved_candidate.is_impl_candidate ()
1322-
? resolved_candidate.item.impl.impl_item->get_impl_mappings ()
1323-
.get_nodeid ()
1324-
: resolved_candidate.item.trait.item_ref->get_mappings ()
1325-
.get_nodeid ();
1326-
1327-
rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF);
1328-
TyTy::BaseType *lookup = lookup_tyty;
1329-
TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
1330-
rust_assert (fn->is_method ());
1331-
1332-
auto root = lhs->get_root ();
1333-
bool receiver_is_type_param = root->get_kind () == TyTy::TypeKind::PARAM;
1334-
if (root->get_kind () == TyTy::TypeKind::ADT)
1335-
{
1336-
const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (root);
1337-
if (adt->has_substitutions () && fn->needs_substitution ())
1338-
{
1339-
// consider the case where we have:
1340-
//
1341-
// struct Foo<X,Y>(X,Y);
1342-
//
1343-
// impl<T> Foo<T, i32> {
1344-
// fn test<X>(self, a:X) -> (T,X) { (self.0, a) }
1345-
// }
1346-
//
1347-
// In this case we end up with an fn type of:
1348-
//
1349-
// fn <T,X> test(self:Foo<T,i32>, a:X) -> (T,X)
1350-
//
1351-
// This means the instance or self we are calling this method for
1352-
// will be substituted such that we can get the inherited type
1353-
// arguments but then need to use the turbo fish if available or
1354-
// infer the remaining arguments. Luckily rust does not allow for
1355-
// default types GenericParams on impl blocks since these must
1356-
// always be at the end of the list
1357-
1358-
auto s = fn->get_self_type ()->get_root ();
1359-
rust_assert (s->can_eq (adt, false));
1360-
rust_assert (s->get_kind () == TyTy::TypeKind::ADT);
1361-
const TyTy::ADTType *self_adt
1362-
= static_cast<const TyTy::ADTType *> (s);
1363-
1364-
// we need to grab the Self substitutions as the inherit type
1365-
// parameters for this
1366-
if (self_adt->needs_substitution ())
1367-
{
1368-
rust_assert (adt->was_substituted ());
1369-
1370-
TyTy::SubstitutionArgumentMappings used_args_in_prev_segment
1371-
= GetUsedSubstArgs::From (adt);
1372-
1373-
TyTy::SubstitutionArgumentMappings inherit_type_args
1374-
= self_adt->solve_mappings_from_receiver_for_self (
1375-
used_args_in_prev_segment);
1376-
1377-
// there may or may not be inherited type arguments
1378-
if (!inherit_type_args.is_error ())
1379-
{
1380-
// need to apply the inherited type arguments to the
1381-
// function
1382-
lookup = fn->handle_substitions (inherit_type_args);
1383-
}
1384-
}
1385-
}
1386-
}
1387-
1388-
// handle generics
1389-
if (!receiver_is_type_param)
1390-
{
1391-
if (lookup->needs_generic_substitutions ())
1392-
{
1393-
lookup = SubstMapper::InferSubst (lookup, expr.get_locus ());
1394-
}
1395-
}
1396-
1397-
// type check the arguments if required
1398-
TyTy::FnType *type = static_cast<TyTy::FnType *> (lookup);
1399-
rust_assert (type->num_params () > 0);
1400-
auto fnparam = type->param_at (0);
1401-
fnparam.second->unify (adjusted_self); // typecheck the self
1402-
if (rhs == nullptr)
1403-
{
1404-
rust_assert (type->num_params () == 1);
1405-
}
1406-
else
1407-
{
1408-
rust_assert (type->num_params () == 2);
1409-
auto fnparam = type->param_at (1);
1410-
fnparam.second->unify (rhs); // typecheck the rhs
1411-
}
1412-
1413-
// get the return type
1414-
TyTy::BaseType *function_ret_tyty = type->get_return_type ()->clone ();
1415-
1416-
// store the expected fntype
1417-
context->insert_operator_overload (expr.get_mappings ().get_hirid (), type);
1418-
1419-
// set up the resolved name on the path
1420-
resolver->insert_resolved_name (expr.get_mappings ().get_nodeid (),
1421-
resolved_node_id);
1422-
1423-
// return the result of the function back
1424-
infered = function_ret_tyty;
1425-
1426-
return true;
1427-
}
1243+
TyTy::BaseType *rhs);
14281244

14291245
private:
14301246
TypeCheckExpr (bool inside_loop)

0 commit comments

Comments
 (0)