diff --git a/apps/serdes/Deserializer.cpp b/apps/serdes/Deserializer.cpp index 075215516e0c..b297f038d7f8 100644 --- a/apps/serdes/Deserializer.cpp +++ b/apps/serdes/Deserializer.cpp @@ -144,6 +144,20 @@ Halide::PrefetchBoundStrategy Deserializer::deserialize_prefetch_bound_strategy( } } +Halide::NameMangling Deserializer::deserialize_name_mangling(const Halide::Serialize::NameMangling name_mangling) { + switch (name_mangling) { + case Halide::Serialize::NameMangling::NameMangling_Default: + return Halide::NameMangling::Default; + case Halide::Serialize::NameMangling::NameMangling_C: + return Halide::NameMangling::C; + case Halide::Serialize::NameMangling::NameMangling_CPlusPlus: + return Halide::NameMangling::CPlusPlus; + default: + std::cerr << "unknown name mangling " << name_mangling << "\n"; + exit(1); + } +} + Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) { using Halide::Serialize::TypeCode; int bits = type->bits(); @@ -200,8 +214,24 @@ Halide::Internal::Function Deserializer::deserialize_function(const Halide::Seri updates.push_back(deserialize_definition(update)); } + std::string debug_file = deserialize_string(function->debug_file()); + std::string extern_function_name = deserialize_string(function->extern_function_name()); + auto name_mangling = deserialize_name_mangling(function->extern_mangling()); + auto extern_function_device_api = deserialize_device_api(function->extern_function_device_api()); + auto extern_proxy_expr = deserialize_expr(function->extern_proxy_expr_type(), function->extern_proxy_expr()); + bool trace_loads = function->trace_loads(), trace_stores = function->trace_stores(), trace_realizations = function->trace_realizations(); + std::vector trace_tags; + trace_tags.reserve(function->trace_tags()->size()); + for (const auto &tag : *function->trace_tags()) { + trace_tags.push_back(deserialize_string(tag)); + } + bool frozen = function->frozen(); + return Halide::Internal::Function(name, origin_name, output_types, required_types, - required_dim, args, func_schedule, init_def, updates); + required_dim, args, func_schedule, init_def, updates, + debug_file, extern_function_name, name_mangling, + extern_function_device_api, extern_proxy_expr, + trace_loads, trace_stores, trace_realizations, trace_tags, frozen); } Halide::Internal::Stmt Deserializer::deserialize_stmt(uint8_t type_code, const void *stmt) { diff --git a/apps/serdes/Deserializer.h b/apps/serdes/Deserializer.h index 27c0a9ce28f5..2d37d7ba15c7 100644 --- a/apps/serdes/Deserializer.h +++ b/apps/serdes/Deserializer.h @@ -30,6 +30,8 @@ class Deserializer { Halide::PrefetchBoundStrategy deserialize_prefetch_bound_strategy(const Halide::Serialize::PrefetchBoundStrategy prefetch_bound_strategy); + Halide::NameMangling deserialize_name_mangling(const Halide::Serialize::NameMangling name_mangling); + std::string deserialize_string(const flatbuffers::String *str); Halide::Type deserialize_type(const Halide::Serialize::Type *type); diff --git a/apps/serdes/Serializer.cpp b/apps/serdes/Serializer.cpp index ccccdc5c7e4e..d501233f2c58 100644 --- a/apps/serdes/Serializer.cpp +++ b/apps/serdes/Serializer.cpp @@ -143,6 +143,20 @@ Halide::Serialize::PrefetchBoundStrategy Serializer::serialize_prefetch_bound_st } } +Halide::Serialize::NameMangling Serializer::serialize_name_mangling(const Halide::NameMangling &name_mangling) { + switch (name_mangling) { + case Halide::NameMangling::Default: + return Halide::Serialize::NameMangling::NameMangling_Default; + case Halide::NameMangling::C: + return Halide::Serialize::NameMangling::NameMangling_C; + case Halide::NameMangling::CPlusPlus: + return Halide::Serialize::NameMangling::NameMangling_CPlusPlus; + default: + std::cerr << "Unsupported name mangling\n"; + exit(1); + } +} + flatbuffers::Offset Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) { return builder.CreateString(str); } @@ -591,7 +605,25 @@ flatbuffers::Offset Serializer::serialize_function(flat for (const auto &update : function.updates()) { updates_serialized.push_back(serialize_definition(builder, update)); } - auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim, args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized)); + + auto debug_file_serialized = serialize_string(builder, function.debug_file()); + auto extern_function_name_serialized = serialize_string(builder, function.extern_function_name()); + auto extern_mangling_serialized = serialize_name_mangling(function.extern_definition_name_mangling()); + auto extern_function_device_api_serialized = serialize_device_api(function.extern_function_device_api()); + auto extern_proxy_expr_serialized = serialize_expr(builder, function.extern_definition_proxy_expr()); + bool trace_loads = function.is_tracing_loads(); + bool trace_stores = function.is_tracing_stores(); + bool trace_realizations = function.is_tracing_realizations(); + std::vector> trace_tags_serialized; + trace_tags_serialized.reserve(function.get_trace_tags().size()); + for (const auto& tag: function.get_trace_tags()) { + trace_tags_serialized.push_back(serialize_string(builder, tag)); + } + bool frozen = function.frozen(); + auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim, + args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized), debug_file_serialized, + extern_function_name_serialized, extern_mangling_serialized, extern_function_device_api_serialized, extern_proxy_expr_serialized.first, + extern_proxy_expr_serialized.second, trace_loads, trace_stores, trace_realizations, builder.CreateVector(trace_tags_serialized), frozen); return func; } diff --git a/apps/serdes/Serializer.h b/apps/serdes/Serializer.h index 290075a05515..496b895f6120 100644 --- a/apps/serdes/Serializer.h +++ b/apps/serdes/Serializer.h @@ -31,6 +31,8 @@ class Serializer { Halide::Serialize::PrefetchBoundStrategy serialize_prefetch_bound_strategy(const Halide::PrefetchBoundStrategy &prefetch_bound_strategy); + Halide::Serialize::NameMangling serialize_name_mangling(const Halide::NameMangling &name_mangling); + flatbuffers::Offset serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str); flatbuffers::Offset serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type); diff --git a/apps/serdes/halide_ir.fbs b/apps/serdes/halide_ir.fbs index 0d2836ab369d..4b298a4bd566 100644 --- a/apps/serdes/halide_ir.fbs +++ b/apps/serdes/halide_ir.fbs @@ -513,6 +513,12 @@ table FuncMapping { index: int32; } +enum NameMangling: ubyte { + Default, + C, + CPlusPlus, +} + // Halide::internal::Function table Func { name: string; @@ -524,13 +530,18 @@ table Func { func_schedule: FuncSchedule; init_def: Definition; updates: [Definition]; - // debug_file: string; + debug_file: string; // output_buffers: [Parameter]; // extern_arguments: [ExternFuncArgument]; - // extern_function_name: string; - // extern_mangling: NameMangling; - // extern_device_api: DeviceAPI; - // extern_proxy_expr: Expr; + extern_function_name: string; + extern_mangling: NameMangling; + extern_function_device_api: DeviceAPI; + extern_proxy_expr: Expr; + trace_loads: bool = false; + trace_stores: bool = false; + trace_realizations: bool = false; + trace_tags: [string]; + frozen: bool = false; } table Pipeline { diff --git a/src/Function.cpp b/src/Function.cpp index 5d4eecc76130..15dbb8fa2ce6 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -332,8 +332,11 @@ Function::Function(const std::vector &required_types, int required_dims, c } Function::Function(const std::string &name, const std::string &origin_name, const std::vector &output_types, - const std::vector &required_types, int required_dims, const std::vector &args, - const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates) { + const std::vector &required_types, int required_dims, const std::vector &args, + const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates, + const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling, + const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations, + const std::vector &trace_tags, bool frozen) { contents.strong = new FunctionGroup; contents.strong->members.resize(1); contents->name = name; @@ -345,6 +348,16 @@ Function::Function(const std::string &name, const std::string &origin_name, cons contents->func_schedule = func_schedule; contents->init_def = init_def; contents->updates = updates; + contents->debug_file = debug_file; + contents->extern_function_name = extern_function_name; + contents->extern_mangling = name_mangling; + contents->extern_function_device_api = device_api; + contents->extern_proxy_expr = extern_proxy_expr; + contents->trace_loads = trace_loads; + contents->trace_stores = trace_stores; + contents->trace_realizations = trace_realizations; + contents->trace_tags = trace_tags; + contents->frozen = frozen; } namespace { diff --git a/src/Function.h b/src/Function.h index 11e249e90d2c..b016fea0249d 100644 --- a/src/Function.h +++ b/src/Function.h @@ -71,7 +71,10 @@ class Function { /** Construct a function from deserializing */ explicit Function(const std::string &name, const std::string &origin_name, const std::vector &output_types, const std::vector &required_types, int required_dims, const std::vector &args, - const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates); + const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates, + const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling, + const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations, + const std::vector &trace_tags, bool frozen); /** Get a handle on the halide function contents that this Function * represents. */