Skip to content

enum refactoring #6158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/flatbuffers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ categories = ["encoding", "data-structures", "memory-management"]

[dependencies]
smallvec = "1.0"
thiserror = "1.0"
9 changes: 9 additions & 0 deletions rust/flatbuffers/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use std::fmt::{Debug, Display};

use thiserror::Error;

#[derive(Error, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum ConvertError<T: Debug + Display> {
#[error("unknown variant in buffer: {0}")]
UnknownVariant(T),
}
3 changes: 3 additions & 0 deletions rust/flatbuffers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
//!
//! At this time, to generate Rust code, you will need the latest `master` version of `flatc`, available from here: https://github.com/google/flatbuffers
//! (On OSX, you can install FlatBuffers from `HEAD` with the Homebrew package manager.)
extern crate thiserror;

mod builder;
mod endian_scalar;
mod error;
mod follow;
mod primitives;
mod push;
Expand All @@ -42,6 +44,7 @@ pub use builder::FlatBufferBuilder;
pub use endian_scalar::{
byte_swap_f32, byte_swap_f64, emplace_scalar, read_scalar, read_scalar_at, EndianScalar,
};
pub use error::ConvertError;
pub use follow::{Follow, FollowStart};
pub use primitives::*;
pub use push::Push;
Expand Down
225 changes: 191 additions & 34 deletions src/idl_gen_rust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,10 @@ class RustGenerator : public BaseGenerator {
return Name(enum_def) + "::" + Name(enum_val);
}

static bool IsBitFlags(const EnumDef &enum_def) {
return enum_def.attributes.Lookup("bit_flags") != nullptr;
}

// Generate an enum declaration,
// an enum string lookup table,
// an enum match function,
Expand Down Expand Up @@ -552,12 +556,19 @@ class RustGenerator : public BaseGenerator {
code_.SetValue("ENUM_MIN_BASE_VALUE", enum_def.ToString(*minv));
code_.SetValue("ENUM_MAX_BASE_VALUE", enum_def.ToString(*maxv));

// Generate enum constants, and impls for Follow, EndianScalar, and Push.
// Generate enum constants, and impls for TryFrom, Follow, EndianScalar, and Push.
code_ += "#[deprecated(since = \"1.13\", note = \"Use associated constants instead.\")]";
code_ += "pub const ENUM_MIN_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}} = \\";
code_ += "{{ENUM_MIN_BASE_VALUE}};";
code_ += "#[deprecated(since = \"1.13\", note = \"Use associated constants instead.\")]";
code_ += "pub const ENUM_MAX_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}} = \\";
code_ += "{{ENUM_MAX_BASE_VALUE}};";
code_ += "";
code_ += "impl {{ENUM_NAME}} {";
code_ += " pub const MIN: {{BASE_TYPE}} = {{ENUM_MIN_BASE_VALUE}};";
code_ += " pub const MAX: {{BASE_TYPE}} = {{ENUM_MAX_BASE_VALUE}};";
code_ += "}";
code_ += "";
code_ += "impl<'a> flatbuffers::Follow<'a> for {{ENUM_NAME}} {";
code_ += " type Inner = Self;";
code_ += " #[inline]";
Expand Down Expand Up @@ -592,9 +603,33 @@ class RustGenerator : public BaseGenerator {
code_ += "}";
code_ += "";

if (!IsBitFlags(enum_def)) {
code_ += "impl TryFrom<{{BASE_TYPE}}> for {{ENUM_NAME}} {";
code_ += " type Error = flatbuffers::ConvertError<{{BASE_TYPE}}>;";
code_ += "";
code_ += " #[inline]";
code_ += " fn try_from(value: {{BASE_TYPE}}) -> Result<Self, Self::Error> {";
code_ += " match value {";

for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
const auto &ev = **it;

code_.SetValue("KEY", Name(ev));
code_.SetValue("VALUE", enum_def.ToString(ev));
code_ += " {{VALUE}} => Ok({{ENUM_NAME}}::{{KEY}}),";
}

code_ += " _ => Err(Self::Error::UnknownVariant(value))",
code_ += " }";
code_ += " }";
code_ += "}";
code_ += "";
}

// Generate an array of all enumeration values.
auto num_fields = NumToString(enum_def.size());
code_ += "#[allow(non_camel_case_types)]";
code_ += "#[deprecated(since = \"1.13\", note = \"Use associated constants instead.\")]";
code_ += "pub const ENUM_VALUES_{{ENUM_NAME_CAPS}}: [{{ENUM_NAME}}; " +
num_fields + "] = [";
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
Expand All @@ -615,21 +650,21 @@ class RustGenerator : public BaseGenerator {
// "too sparse". Change at will.
static const uint64_t kMaxSparseness = 5;
if (range / static_cast<uint64_t>(enum_def.size()) < kMaxSparseness) {
code_ += "impl {{ENUM_NAME}} {";
code_ += " pub const NAMES: [&'static str; " + NumToString(range + 1) + "] = [";

EnumerateEnumNames(enum_def);

code_ += " ];";
code_ += "}";
code_ += "";
code_ += "#[allow(non_camel_case_types)]";
code_ += "#[deprecated(since = \"1.13\", note = \"Use associated constants instead.\")]";
code_ += "pub const ENUM_NAMES_{{ENUM_NAME_CAPS}}: [&str; " +
NumToString(range + 1) + "] = [";

auto val = enum_def.Vals().front();
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end();
++it) {
auto ev = *it;
for (auto k = enum_def.Distance(val, ev); k > 1; --k) {
code_ += " \"\",";
}
val = ev;
auto suffix = *it != enum_def.Vals().back() ? "," : "";
code_ += " \"" + Name(*ev) + "\"" + suffix;
}
EnumerateEnumNames(enum_def);

code_ += "];";
code_ += "";

Expand All @@ -644,7 +679,7 @@ class RustGenerator : public BaseGenerator {
}
code_ += ";";

code_ += " ENUM_NAMES_{{ENUM_NAME_CAPS}}[index as usize]";
code_ += " {{ENUM_NAME}}::NAMES[index as usize]";
code_ += "}";
code_ += "";
}
Expand All @@ -657,6 +692,20 @@ class RustGenerator : public BaseGenerator {
}
}

void EnumerateEnumNames(const EnumDef &enum_def) {
auto val = enum_def.Vals().front();
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end();
++it) {
auto ev = *it;
for (auto k = enum_def.Distance(val, ev); k > 1; --k) {
code_ += " \"\",";
}
val = ev;
auto suffix = *it != enum_def.Vals().back() ? "," : "";
code_ += " \"" + Name(*ev) + "\"" + suffix;
}
}

std::string GetFieldOffsetName(const FieldDef &field) {
return "VT_" + MakeUpper(Name(field));
}
Expand Down Expand Up @@ -907,7 +956,8 @@ class RustGenerator : public BaseGenerator {
}

std::string GenTableAccessorFuncReturnType(const FieldDef &field,
const std::string &lifetime) {
const std::string &lifetime,
bool unchecked) {
const Type &type = field.value.type;

switch (GetFullType(field.value.type)) {
Expand All @@ -927,11 +977,19 @@ class RustGenerator : public BaseGenerator {
return WrapInOptionIfNotRequired(typname + "<" + lifetime + ">",
field.required);
}
case ftEnumKey:
case ftUnionKey: {
const auto typname = WrapInNameSpace(*type.enum_def);
return field.optional ? "Option<" + typname + ">" : typname;
}
case ftEnumKey: {
auto typname = WrapInNameSpace(*type.enum_def);
if (!unchecked) {
const auto underlying_typname = GetEnumTypeForDecl(type.enum_def->underlying_type);
typname = "Result<" + typname + ", flatbuffers::ConvertError<" + underlying_typname + ">>";
}

return field.optional ? "Option<" + typname + ">" : typname;
}

case ftUnionValue: {
return WrapInOptionIfNotRequired("flatbuffers::Table<" + lifetime + ">",
Expand Down Expand Up @@ -987,14 +1045,49 @@ class RustGenerator : public BaseGenerator {
return "INVALID_CODE_GENERATION"; // for return analysis
}

std::string GenTableUncheckedAccessor(const FieldDef &field,
const std::string &offset_prefix) {
const std::string offset_name =
offset_prefix + "::" + GetFieldOffsetName(field);
const Type &type = field.value.type;

if (GetFullType(type) != ftEnumKey) {
FLATBUFFERS_ASSERT(false && "unchecked access is supported for enums only");
return "INVALID_CODE_GENERATION"; // for return analysis
}

const auto underlying_typename = GetEnumTypeForDecl(type.enum_def->underlying_type);
if (field.optional) {
return "self._tab.get::<" + underlying_typename + ">(" + offset_name + ", None)" +
".map(|value| std::mem::transmute(value))";
} else {
return "self._tab.get::<" + underlying_typename + ">(" + offset_name + ", Some(" +
field.value.constant + ")).map(|value| std::mem::transmute(value)).unwrap()";
}
}

std::string GenRawEnumAccessor(const std::string &underlying_type,
const FieldDef &field,
const std::string &offset_prefix) {
const auto offset_name =
offset_prefix + "::" + GetFieldOffsetName(field);

if (field.optional) {
return "self._tab.get::<" + underlying_type + ">(" + offset_name + ", None)";
}

return "self._tab.get::<" + underlying_type + ">(" + offset_name + ", Some(" +
field.value.constant + ")).unwrap()";
}

std::string GenTableAccessorFuncBody(const FieldDef &field,
const std::string &lifetime,
const std::string &offset_prefix) {
const std::string offset_name =
offset_prefix + "::" + GetFieldOffsetName(field);
const Type &type = field.value.type;

switch (GetFullType(field.value.type)) {
switch (GetFullType(type)) {
case ftInteger:
case ftFloat:
case ftBool: {
Expand Down Expand Up @@ -1027,16 +1120,20 @@ class RustGenerator : public BaseGenerator {
lifetime + ">>>(" + offset_name + ", None)",
field.required);
}
case ftUnionKey:
case ftUnionKey: {
return GenerateEnumOrUnionAccess(field, offset_name, type);
}
case ftEnumKey: {
const auto underlying_typname = GetTypeBasic(type); //<- never used
const auto typname = WrapInNameSpace(*type.enum_def);
const auto default_value = GetDefaultScalarValue(field);
if (IsBitFlags(*type.enum_def)) {
return GenerateEnumOrUnionAccess(field, offset_name, type);
}

const auto underlying_typname = GetEnumTypeForDecl(type.enum_def->underlying_type);
if (field.optional) {
return "self._tab.get::<" + typname + ">(" + offset_name + ", None)";
return "self._tab.get::<" + underlying_typname + ">(" + offset_name + ", None).map(|value| value.try_into())";
} else {
return "self._tab.get::<" + typname + ">(" + offset_name + ", Some(" +
default_value + ")).unwrap()";
return "self._tab.get::<" + underlying_typname + ">(" + offset_name + ", Some(" +
field.value.constant + ")).map(|value| value.try_into()).unwrap()";
}
}
case ftString: {
Expand Down Expand Up @@ -1100,6 +1197,20 @@ class RustGenerator : public BaseGenerator {
return "INVALID_CODE_GENERATION"; // for return analysis
}

std::string GenerateEnumOrUnionAccess(const FieldDef &field,
const std::string &offset_name,
const Type &type) {
const auto underlying_typname = GetTypeBasic(type); //<- never used
const auto typname = WrapInNameSpace(*type.enum_def);
const auto default_value = GetDefaultScalarValue(field);
if (field.optional) {
return "self._tab.get::<" + typname + ">(" + offset_name + ", None)";
} else {
return "self._tab.get::<" + typname + ">(" + offset_name + ", Some(" +
default_value + ")).unwrap()";
}
}

bool TableFieldReturnsOption(const FieldDef &field) {
if (field.optional) return true;
switch (GetFullType(field.value.type)) {
Expand Down Expand Up @@ -1241,17 +1352,57 @@ class RustGenerator : public BaseGenerator {
continue;
}

code_.SetValue("FIELD_NAME", Name(field));
code_.SetValue("RETURN_TYPE",
GenTableAccessorFuncReturnType(field, "'a"));
code_.SetValue("FUNC_BODY",
GenTableAccessorFuncBody(field, "'a", offset_prefix));
if (GetFullType(field.value.type) == ftEnumKey && !IsBitFlags(*field.value.type.enum_def)) {
code_.SetValue("FIELD_NAME", Name(field));
code_.SetValue("RETURN_TYPE",
GenTableAccessorFuncReturnType(field, "'a", false));
code_.SetValue("FUNC_BODY",
GenTableAccessorFuncBody(field, "'a", offset_prefix));

GenComment(field.doc_comment, " ");
code_ += " #[inline]";
code_ += " pub fn {{FIELD_NAME}}(&self) -> {{RETURN_TYPE}} {";
code_ += " {{FUNC_BODY}}";
code_ += " }";
GenComment(field.doc_comment, " ");
code_ += " #[inline]";
code_ += " pub fn {{FIELD_NAME}}(&self) -> {{RETURN_TYPE}} {";
code_ += " {{FUNC_BODY}}";
code_ += " }";

code_.SetValue("RETURN_TYPE",
GenTableAccessorFuncReturnType(field, "'a", true));
code_.SetValue("FUNC_BODY",
GenTableUncheckedAccessor(field, offset_prefix));

GenComment(field.doc_comment, " ");
code_ += " #[inline]";
code_ += " pub unsafe fn {{FIELD_NAME}}_unchecked(&self) -> {{RETURN_TYPE}} {";
code_ += " {{FUNC_BODY}}";
code_ += " }";

const auto underlying_type = GetEnumTypeForDecl(field.value.type.enum_def->underlying_type);

code_.SetValue("RETURN_TYPE", WrapInOptionIfNotRequired(underlying_type, !field.optional));
code_.SetValue("FUNC_BODY", GenRawEnumAccessor(underlying_type, field, offset_prefix));

GenComment(field.doc_comment, " ");
code_ += " #[inline]";
code_ += " pub fn {{FIELD_NAME}}_raw(&self) -> {{RETURN_TYPE}} {";
code_ += " {{FUNC_BODY}}";
code_ += " }";
} else {
const auto bit_flags = GetFullType(field.value.type) == ftEnumKey && IsBitFlags(*field.value.type.enum_def);

code_.SetValue("FIELD_NAME", Name(field));

// treat bitflags as unchecked for now
code_.SetValue("RETURN_TYPE",
GenTableAccessorFuncReturnType(field, "'a", bit_flags));
code_.SetValue("FUNC_BODY",
GenTableAccessorFuncBody(field, "'a", offset_prefix));

GenComment(field.doc_comment, " ");
code_ += " #[inline]";
code_ += " pub fn {{FIELD_NAME}}(&self) -> {{RETURN_TYPE}} {";
code_ += " {{FUNC_BODY}}";
code_ += " }";
}

// Generate a comparison function for this field if it is a key.
if (field.key) { GenKeyFieldMethods(field); }
Expand Down Expand Up @@ -1480,7 +1631,7 @@ class RustGenerator : public BaseGenerator {
void GenKeyFieldMethods(const FieldDef &field) {
FLATBUFFERS_ASSERT(field.key);

code_.SetValue("KEY_TYPE", GenTableAccessorFuncReturnType(field, ""));
code_.SetValue("KEY_TYPE", GenTableAccessorFuncReturnType(field, "", false));

code_ += " #[inline]";
code_ +=
Expand Down Expand Up @@ -1786,6 +1937,12 @@ class RustGenerator : public BaseGenerator {

code_ += indent + "use std::mem;";
code_ += indent + "use std::cmp::Ordering;";

if (!parser_.enums_.vec.empty()) {
code_ += indent + "use std::convert::TryFrom;";
code_ += indent + "use std::convert::TryInto;";
}

code_ += "";
code_ += indent + "extern crate flatbuffers;";
code_ += indent + "use self::flatbuffers::EndianScalar;";
Expand Down
2 changes: 2 additions & 0 deletions tests/include_test/include_test1_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
use crate::include_test2_generated::*;
use std::mem;
use std::cmp::Ordering;
use std::convert::TryFrom;
use std::convert::TryInto;

extern crate flatbuffers;
use self::flatbuffers::EndianScalar;
Expand Down
Loading