Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gccrs: cleanup our enum type layout to be closer to rustc #3357

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 45 additions & 29 deletions gcc/rust/backend/rust-compile-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,24 +538,32 @@ CompileExpr::visit (HIR::StructExprStructFields &struct_expr)
}
}

// the constructor depends on whether this is actually an enum or not if
// its an enum we need to setup the discriminator
std::vector<tree> ctor_arguments;
if (adt->is_enum ())
if (!adt->is_enum ())
{
HIR::Expr &discrim_expr = variant->get_discriminant ();
tree discrim_expr_node = CompileExpr::Compile (discrim_expr, ctx);
tree folded_discrim_expr = fold_expr (discrim_expr_node);
tree qualifier = folded_discrim_expr;

ctor_arguments.push_back (qualifier);
translated
= Backend::constructor_expression (compiled_adt_type, adt->is_enum (),
arguments, union_disriminator,
struct_expr.get_locus ());
return;
}
for (auto &arg : arguments)
ctor_arguments.push_back (arg);

HIR::Expr &discrim_expr = variant->get_discriminant ();
tree discrim_expr_node = CompileExpr::Compile (discrim_expr, ctx);
tree folded_discrim_expr = fold_expr (discrim_expr_node);
tree qualifier = folded_discrim_expr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why we need that variable ? Could we use folded_discrim_expr directly or rename it ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good find will fix


tree enum_root_files = TYPE_FIELDS (compiled_adt_type);
tree payload_root = DECL_CHAIN (enum_root_files);

tree payload = Backend::constructor_expression (TREE_TYPE (payload_root),
adt->is_enum (), arguments,
union_disriminator,
struct_expr.get_locus ());

std::vector<tree> ctor_arguments = {qualifier, payload};

translated
= Backend::constructor_expression (compiled_adt_type, adt->is_enum (),
ctor_arguments, union_disriminator,
= Backend::constructor_expression (compiled_adt_type, 0, ctor_arguments, -1,
struct_expr.get_locus ());
}

Expand Down Expand Up @@ -1227,26 +1235,34 @@ CompileExpr::visit (HIR::CallExpr &expr)
arguments.push_back (rvalue);
}

// the constructor depends on whether this is actually an enum or not if
// its an enum we need to setup the discriminator
std::vector<tree> ctor_arguments;
if (adt->is_enum ())
if (!adt->is_enum ())
{
HIR::Expr &discrim_expr = variant->get_discriminant ();
tree discrim_expr_node = CompileExpr::Compile (discrim_expr, ctx);
tree folded_discrim_expr = fold_expr (discrim_expr_node);
tree qualifier = folded_discrim_expr;

ctor_arguments.push_back (qualifier);
translated
= Backend::constructor_expression (compiled_adt_type,
adt->is_enum (), arguments,
union_disriminator,
expr.get_locus ());
return;
}
for (auto &arg : arguments)
ctor_arguments.push_back (arg);

translated
= Backend::constructor_expression (compiled_adt_type, adt->is_enum (),
ctor_arguments, union_disriminator,
HIR::Expr &discrim_expr = variant->get_discriminant ();
tree discrim_expr_node = CompileExpr::Compile (discrim_expr, ctx);
tree folded_discrim_expr = fold_expr (discrim_expr_node);
tree qualifier = folded_discrim_expr;

tree enum_root_files = TYPE_FIELDS (compiled_adt_type);
tree payload_root = DECL_CHAIN (enum_root_files);

tree payload
= Backend::constructor_expression (TREE_TYPE (payload_root), true,
{arguments}, union_disriminator,
expr.get_locus ());

std::vector<tree> ctor_arguments = {qualifier, payload};
translated = Backend::constructor_expression (compiled_adt_type, false,
ctor_arguments, -1,
expr.get_locus ());

return;
}

Expand Down
64 changes: 31 additions & 33 deletions gcc/rust/backend/rust-compile-pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "rust-compile-resolve-path.h"
#include "rust-constexpr.h"
#include "rust-compile-type.h"
#include "print-tree.h"

namespace Rust {
namespace Compile {
Expand Down Expand Up @@ -57,11 +58,8 @@ CompilePatternCheckExpr::visit (HIR::PathInExpression &pattern)
rust_assert (ok);

// find discriminant field of scrutinee
tree scrutinee_record_expr
= Backend::struct_field_expression (match_scrutinee_expr, 0,
pattern.get_locus ());
tree scrutinee_expr_qualifier_expr
= Backend::struct_field_expression (scrutinee_record_expr, 0,
= Backend::struct_field_expression (match_scrutinee_expr, 0,
pattern.get_locus ());

// must be enum
Expand Down Expand Up @@ -227,11 +225,8 @@ CompilePatternCheckExpr::visit (HIR::StructPattern &pattern)
tree discrim_expr_node = CompileExpr::Compile (discrim_expr, ctx);

// find discriminant field of scrutinee
tree scrutinee_record_expr
= Backend::struct_field_expression (match_scrutinee_expr, variant_index,
pattern.get_path ().get_locus ());
tree scrutinee_expr_qualifier_expr
= Backend::struct_field_expression (scrutinee_record_expr, 0,
= Backend::struct_field_expression (match_scrutinee_expr, 0,
pattern.get_path ().get_locus ());

check_expr
Expand All @@ -240,7 +235,7 @@ CompilePatternCheckExpr::visit (HIR::StructPattern &pattern)
discrim_expr_node,
pattern.get_path ().get_locus ());

match_scrutinee_expr = scrutinee_record_expr;
match_scrutinee_expr = scrutinee_expr_qualifier_expr;
}
else
{
Expand Down Expand Up @@ -295,8 +290,6 @@ CompilePatternCheckExpr::visit (HIR::StructPattern &pattern)
void
CompilePatternCheckExpr::visit (HIR::TupleStructPattern &pattern)
{
size_t tuple_field_index;

// lookup the type
TyTy::BaseType *lookup = nullptr;
bool ok = ctx->get_tyctx ()->lookup_type (
Expand All @@ -307,6 +300,7 @@ CompilePatternCheckExpr::visit (HIR::TupleStructPattern &pattern)
rust_assert (lookup->get_kind () == TyTy::TypeKind::ADT);
TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (lookup);

int variant_index = 0;
rust_assert (adt->number_of_variants () > 0);
TyTy::VariantDef *variant = nullptr;
if (adt->is_enum ())
Expand All @@ -317,7 +311,6 @@ CompilePatternCheckExpr::visit (HIR::TupleStructPattern &pattern)
pattern.get_path ().get_mappings ().get_hirid (), &variant_id);
rust_assert (ok);

int variant_index = 0;
ok = adt->lookup_variant_by_id (variant_id, &variant, &variant_index);
rust_assert (ok);

Expand All @@ -326,29 +319,20 @@ CompilePatternCheckExpr::visit (HIR::TupleStructPattern &pattern)
tree discrim_expr_node = CompileExpr::Compile (discrim_expr, ctx);

// find discriminant field of scrutinee
tree scrutinee_record_expr
= Backend::struct_field_expression (match_scrutinee_expr, variant_index,
pattern.get_path ().get_locus ());
tree scrutinee_expr_qualifier_expr
= Backend::struct_field_expression (scrutinee_record_expr, 0,
= Backend::struct_field_expression (match_scrutinee_expr, 0,
pattern.get_path ().get_locus ());

check_expr
= Backend::comparison_expression (ComparisonOperator::EQUAL,
scrutinee_expr_qualifier_expr,
discrim_expr_node,
pattern.get_path ().get_locus ());

match_scrutinee_expr = scrutinee_record_expr;
// we are offsetting by + 1 here since the first field in the record
// is always the discriminator
tuple_field_index = 1;
}
else
{
variant = adt->get_variants ().at (0);
check_expr = boolean_true_node;
tuple_field_index = 0;
}

HIR::TupleStructItems &items = pattern.get_items ();
Expand All @@ -367,10 +351,20 @@ CompilePatternCheckExpr::visit (HIR::TupleStructPattern &pattern)
rust_assert (items_no_range.get_patterns ().size ()
== variant->num_fields ());

size_t tuple_field_index = 0;
for (auto &pattern : items_no_range.get_patterns ())
{
// find payload union field of scrutinee
tree payload_ref
= Backend::struct_field_expression (match_scrutinee_expr, 1,
pattern->get_locus ());

tree variant_ref
= Backend::struct_field_expression (payload_ref, variant_index,
pattern->get_locus ());

tree field_expr
= Backend::struct_field_expression (match_scrutinee_expr,
= Backend::struct_field_expression (variant_ref,
tuple_field_index++,
pattern->get_locus ());

Expand Down Expand Up @@ -470,13 +464,15 @@ CompilePatternBindings::visit (HIR::TupleStructPattern &pattern)

if (adt->is_enum ())
{
// we are offsetting by + 1 here since the first field in the record
// is always the discriminator
size_t tuple_field_index = 1;
size_t tuple_field_index = 0;
for (auto &pattern : items_no_range.get_patterns ())
{
tree payload_accessor_union
= Backend::struct_field_expression (match_scrutinee_expr, 1,
pattern->get_locus ());

tree variant_accessor
= Backend::struct_field_expression (match_scrutinee_expr,
= Backend::struct_field_expression (payload_accessor_union,
variant_index,
pattern->get_locus ());

Expand Down Expand Up @@ -569,16 +565,18 @@ CompilePatternBindings::visit (HIR::StructPattern &pattern)
tree binding = error_mark_node;
if (adt->is_enum ())
{
tree payload_accessor_union
= Backend::struct_field_expression (match_scrutinee_expr, 1,
ident.get_locus ());

tree variant_accessor
= Backend::struct_field_expression (match_scrutinee_expr,
= Backend::struct_field_expression (payload_accessor_union,
variant_index,
ident.get_locus ());

// we are offsetting by + 1 here since the first field in the
// record is always the discriminator
binding = Backend::struct_field_expression (variant_accessor,
offs + 1,
ident.get_locus ());
binding
= Backend::struct_field_expression (variant_accessor, offs,
ident.get_locus ());
}
else
{
Expand Down
5 changes: 3 additions & 2 deletions gcc/rust/backend/rust-compile-resolve-path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ ResolvePathRef::attempt_constructor_expression_lookup (
tree folded_discrim_expr = fold_expr (discrim_expr_node);
tree qualifier = folded_discrim_expr;

return Backend::constructor_expression (compiled_adt_type, true, {qualifier},
union_disriminator, expr_locus);
// false for is enum but this is an enum but we have a new layout
return Backend::constructor_expression (compiled_adt_type, false, {qualifier},
-1, expr_locus);
}

tree
Expand Down
62 changes: 48 additions & 14 deletions gcc/rust/backend/rust-compile-type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,21 +298,39 @@ TyTyResolveCompile::visit (const TyTy::ADTType &type)
// Ada, qual_union_types might still work for this but I am not 100% sure.
// I ran into some issues lets reuse our normal union and ask Ada people
// about it.
//
// I think the above is actually wrong and it should actually be this
//
// struct {
// int RUST$ENUM$DISR; // take into account the repr for this TODO
// union {
// // Variant A
// struct {
// // No additional fields
// } A;

// // Variant B
// struct {
// // No additional fields
// } B;

// // Variant C
// struct {
// char c;
// } C;

// // Variant D
// struct {
// int64_t x;
// int64_t y;
// } D;
// } payload; // The union of all variant data
// };

std::vector<tree> variant_records;
for (auto &variant : type.get_variants ())
{
std::vector<Backend::typed_identifier> fields;

// add in the qualifier field for the variant
tree enumeral_type
= TyTyResolveCompile::get_implicit_enumeral_node_type ();
Backend::typed_identifier f (RUST_ENUM_DISR_FIELD_NAME, enumeral_type,
ctx->get_mappings ().lookup_location (
variant->get_id ()));
fields.push_back (std::move (f));

// compile the rest of the fields
for (size_t i = 0; i < variant->num_fields (); i++)
{
const TyTy::StructFieldType *field
Expand All @@ -336,9 +354,6 @@ TyTyResolveCompile::visit (const TyTy::ADTType &type)
= Backend::named_type (variant->get_ident ().path.get (),
variant_record, variant->get_ident ().locus);

// set the qualifier to be a builtin
DECL_ARTIFICIAL (TYPE_FIELDS (variant_record)) = 1;

// add them to the list
variant_records.push_back (named_variant_record);
}
Expand All @@ -359,8 +374,27 @@ TyTyResolveCompile::visit (const TyTy::ADTType &type)
enum_fields.push_back (std::move (f));
}

//
location_t locus = ctx->get_mappings ().lookup_location (type.get_ref ());

// finally make the union or the enum
type_record = Backend::union_type (enum_fields, false);
tree variants_union = Backend::union_type (enum_fields, false);
layout_type (variants_union);
tree named_union_record
= Backend::named_type ("payload", variants_union, locus);

// create the overall struct
tree enumeral_type
= TyTyResolveCompile::get_implicit_enumeral_node_type ();
Backend::typed_identifier discrim (RUST_ENUM_DISR_FIELD_NAME,
enumeral_type, locus);
Backend::typed_identifier variants_union_field ("payload",
named_union_record,
locus);

std::vector<Backend::typed_identifier> fields
= {discrim, variants_union_field};
type_record = Backend::struct_type (fields, false);
}

// Handle repr options
Expand Down
2 changes: 1 addition & 1 deletion gcc/rust/rust-gcc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ constructor_expression (tree type_tree, bool is_variant,
if (!TREE_CONSTANT (elt->value))
is_constant = false;
}
gcc_assert (field == NULL_TREE);
// gcc_assert (field == NULL_TREE);
}
}

Expand Down
Loading