Skip to content

Commit

Permalink
[Python] Render enums as Python IntEnum (#8145)
Browse files Browse the repository at this point in the history
This allows enums to be type check with mypy.
They will still behave like ints ->
> IntEnum is the same as Enum,
> but its members are also integers and can be used anywhere
> that an integer can be used.
> If any integer operation is performed with an IntEnum member,
> the resulting value loses its enumeration status.
https://docs.python.org/3/library/enum.html#enum.IntEnum

Only if the --python-typing flag is set.
  • Loading branch information
fliiiix authored Jun 3, 2024
1 parent 6ede1cc commit dafd2f1
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 22 deletions.
13 changes: 11 additions & 2 deletions src/idl_gen_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,12 +591,21 @@ class PythonStubGenerator {
void GenerateEnumStub(std::stringstream &stub, const EnumDef *enum_def,
Imports *imports) const {
stub << "class " << namer_.Type(*enum_def);
if (version_.major != 3) stub << "(object)";

if (version_.major == 3){
imports->Import("enum", "IntEnum");
stub << "(IntEnum)";
}
else {
stub << "(object)";
}

stub << ":\n";
for (const EnumVal *val : enum_def->Vals()) {
stub << " " << namer_.Variant(*val) << ": "
<< ScalarType(enum_def->underlying_type.base_type) << "\n";
}

if (parser_.opts.generate_object_based_api & enum_def->is_union) {
imports->Import("flatbuffers", "table");
stub << "def " << namer_.Function(*enum_def)
Expand Down Expand Up @@ -2432,7 +2441,7 @@ class PythonGenerator : public BaseGenerator {
auto field_type = namer_.ObjectType(*ev.union_type.struct_def);

code +=
GenIndents(1) + "if unionType == " + union_type + "()." + variant + ":";
GenIndents(1) + "if unionType == " + union_type + "." + variant + ":";
if (parser_.opts.include_dependence_headers) {
auto package_reference = GenPackageReference(ev.union_type);
code += GenIndents(2) + "import " + package_reference;
Expand Down
6 changes: 3 additions & 3 deletions tests/MyGame/Example/Any.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ def AnyCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == Any().Monster:
if unionType == Any.Monster:
import MyGame.Example.Monster
return MyGame.Example.Monster.MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == Any().TestSimpleTableWithEnum:
if unionType == Any.TestSimpleTableWithEnum:
import MyGame.Example.TestSimpleTableWithEnum
return MyGame.Example.TestSimpleTableWithEnum.TestSimpleTableWithEnumT.InitFromBuf(table.Bytes, table.Pos)
if unionType == Any().MyGame_Example2_Monster:
if unionType == Any.MyGame_Example2_Monster:
import MyGame.Example2.Monster
return MyGame.Example2.Monster.MonsterT.InitFromBuf(table.Bytes, table.Pos)
return None
6 changes: 3 additions & 3 deletions tests/MyGame/Example/AnyAmbiguousAliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ def AnyAmbiguousAliasesCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == AnyAmbiguousAliases().M1:
if unionType == AnyAmbiguousAliases.M1:
import MyGame.Example.Monster
return MyGame.Example.Monster.MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyAmbiguousAliases().M2:
if unionType == AnyAmbiguousAliases.M2:
import MyGame.Example.Monster
return MyGame.Example.Monster.MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyAmbiguousAliases().M3:
if unionType == AnyAmbiguousAliases.M3:
import MyGame.Example.Monster
return MyGame.Example.Monster.MonsterT.InitFromBuf(table.Bytes, table.Pos)
return None
6 changes: 3 additions & 3 deletions tests/MyGame/Example/AnyUniqueAliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ def AnyUniqueAliasesCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == AnyUniqueAliases().M:
if unionType == AnyUniqueAliases.M:
import MyGame.Example.Monster
return MyGame.Example.Monster.MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyUniqueAliases().TS:
if unionType == AnyUniqueAliases.TS:
import MyGame.Example.TestSimpleTableWithEnum
return MyGame.Example.TestSimpleTableWithEnum.TestSimpleTableWithEnumT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyUniqueAliases().M2:
if unionType == AnyUniqueAliases.M2:
import MyGame.Example2.Monster
return MyGame.Example2.Monster.MonsterT.InitFromBuf(table.Bytes, table.Pos)
return None
4 changes: 2 additions & 2 deletions tests/MyGame/Example/NestedUnion/Any.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def AnyCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == Any().Vec3:
if unionType == Any.Vec3:
import MyGame.Example.NestedUnion.Vec3
return MyGame.Example.NestedUnion.Vec3.Vec3T.InitFromBuf(table.Bytes, table.Pos)
if unionType == Any().TestSimpleTableWithEnum:
if unionType == Any.TestSimpleTableWithEnum:
import MyGame.Example.NestedUnion.TestSimpleTableWithEnum
return MyGame.Example.NestedUnion.TestSimpleTableWithEnum.TestSimpleTableWithEnumT.InitFromBuf(table.Bytes, table.Pos)
return None
18 changes: 9 additions & 9 deletions tests/monster_test_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def AnyCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == Any().Monster:
if unionType == Any.Monster:
return MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == Any().TestSimpleTableWithEnum:
if unionType == Any.TestSimpleTableWithEnum:
return TestSimpleTableWithEnumT.InitFromBuf(table.Bytes, table.Pos)
if unionType == Any().MyGame_Example2_Monster:
if unionType == Any.MyGame_Example2_Monster:
return MonsterT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand All @@ -58,11 +58,11 @@ def AnyUniqueAliasesCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == AnyUniqueAliases().M:
if unionType == AnyUniqueAliases.M:
return MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyUniqueAliases().TS:
if unionType == AnyUniqueAliases.TS:
return TestSimpleTableWithEnumT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyUniqueAliases().M2:
if unionType == AnyUniqueAliases.M2:
return MonsterT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand All @@ -77,11 +77,11 @@ def AnyAmbiguousAliasesCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == AnyAmbiguousAliases().M1:
if unionType == AnyAmbiguousAliases.M1:
return MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyAmbiguousAliases().M2:
if unionType == AnyAmbiguousAliases.M2:
return MonsterT.InitFromBuf(table.Bytes, table.Pos)
if unionType == AnyAmbiguousAliases().M3:
if unionType == AnyAmbiguousAliases.M3:
return MonsterT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand Down

0 comments on commit dafd2f1

Please sign in to comment.