Skip to content

Commit

Permalink
name mangling and closing on function's odds and ends
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 20, 2023
1 parent ca4ac94 commit 097ec2e
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 10 deletions.
32 changes: 31 additions & 1 deletion apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,20 @@ Halide::PrefetchBoundStrategy Deserializer::deserialize_prefetch_bound_strategy(
}
}

Halide::NameMangling Deserializer::deserialize_name_mangling(const Halide::Serialize::NameMangling name_mangling) {
switch (name_mangling) {
case Halide::Serialize::NameMangling::NameMangling_Default:
return Halide::NameMangling::Default;
case Halide::Serialize::NameMangling::NameMangling_C:
return Halide::NameMangling::C;
case Halide::Serialize::NameMangling::NameMangling_CPlusPlus:
return Halide::NameMangling::CPlusPlus;
default:
std::cerr << "unknown name mangling " << name_mangling << "\n";
exit(1);
}
}

Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) {
using Halide::Serialize::TypeCode;
int bits = type->bits();
Expand Down Expand Up @@ -200,8 +214,24 @@ Halide::Internal::Function Deserializer::deserialize_function(const Halide::Seri
updates.push_back(deserialize_definition(update));
}

std::string debug_file = deserialize_string(function->debug_file());
std::string extern_function_name = deserialize_string(function->extern_function_name());
auto name_mangling = deserialize_name_mangling(function->extern_mangling());
auto extern_function_device_api = deserialize_device_api(function->extern_function_device_api());
auto extern_proxy_expr = deserialize_expr(function->extern_proxy_expr_type(), function->extern_proxy_expr());
bool trace_loads = function->trace_loads(), trace_stores = function->trace_stores(), trace_realizations = function->trace_realizations();
std::vector<std::string> trace_tags;
trace_tags.reserve(function->trace_tags()->size());
for (const auto &tag : *function->trace_tags()) {
trace_tags.push_back(deserialize_string(tag));
}
bool frozen = function->frozen();

return Halide::Internal::Function(name, origin_name, output_types, required_types,
required_dim, args, func_schedule, init_def, updates);
required_dim, args, func_schedule, init_def, updates,
debug_file, extern_function_name, name_mangling,
extern_function_device_api, extern_proxy_expr,
trace_loads, trace_stores, trace_realizations, trace_tags, frozen);
}

Halide::Internal::Stmt Deserializer::deserialize_stmt(uint8_t type_code, const void *stmt) {
Expand Down
2 changes: 2 additions & 0 deletions apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class Deserializer {

Halide::PrefetchBoundStrategy deserialize_prefetch_bound_strategy(const Halide::Serialize::PrefetchBoundStrategy prefetch_bound_strategy);

Halide::NameMangling deserialize_name_mangling(const Halide::Serialize::NameMangling name_mangling);

std::string deserialize_string(const flatbuffers::String *str);

Halide::Type deserialize_type(const Halide::Serialize::Type *type);
Expand Down
34 changes: 33 additions & 1 deletion apps/serdes/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,20 @@ Halide::Serialize::PrefetchBoundStrategy Serializer::serialize_prefetch_bound_st
}
}

Halide::Serialize::NameMangling Serializer::serialize_name_mangling(const Halide::NameMangling &name_mangling) {
switch (name_mangling) {
case Halide::NameMangling::Default:
return Halide::Serialize::NameMangling::NameMangling_Default;
case Halide::NameMangling::C:
return Halide::Serialize::NameMangling::NameMangling_C;
case Halide::NameMangling::CPlusPlus:
return Halide::Serialize::NameMangling::NameMangling_CPlusPlus;
default:
std::cerr << "Unsupported name mangling\n";
exit(1);
}
}

flatbuffers::Offset<flatbuffers::String> Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) {
return builder.CreateString(str);
}
Expand Down Expand Up @@ -591,7 +605,25 @@ flatbuffers::Offset<Halide::Serialize::Func> Serializer::serialize_function(flat
for (const auto &update : function.updates()) {
updates_serialized.push_back(serialize_definition(builder, update));
}
auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim, args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized));

auto debug_file_serialized = serialize_string(builder, function.debug_file());
auto extern_function_name_serialized = serialize_string(builder, function.extern_function_name());
auto extern_mangling_serialized = serialize_name_mangling(function.extern_definition_name_mangling());
auto extern_function_device_api_serialized = serialize_device_api(function.extern_function_device_api());
auto extern_proxy_expr_serialized = serialize_expr(builder, function.extern_definition_proxy_expr());
bool trace_loads = function.is_tracing_loads();
bool trace_stores = function.is_tracing_stores();
bool trace_realizations = function.is_tracing_realizations();
std::vector<flatbuffers::Offset<flatbuffers::String>> trace_tags_serialized;
trace_tags_serialized.reserve(function.get_trace_tags().size());
for (const auto& tag: function.get_trace_tags()) {
trace_tags_serialized.push_back(serialize_string(builder, tag));
}
bool frozen = function.frozen();
auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim,
args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized), debug_file_serialized,
extern_function_name_serialized, extern_mangling_serialized, extern_function_device_api_serialized, extern_proxy_expr_serialized.first,
extern_proxy_expr_serialized.second, trace_loads, trace_stores, trace_realizations, builder.CreateVector(trace_tags_serialized), frozen);
return func;
}

Expand Down
2 changes: 2 additions & 0 deletions apps/serdes/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class Serializer {

Halide::Serialize::PrefetchBoundStrategy serialize_prefetch_bound_strategy(const Halide::PrefetchBoundStrategy &prefetch_bound_strategy);

Halide::Serialize::NameMangling serialize_name_mangling(const Halide::NameMangling &name_mangling);

flatbuffers::Offset<flatbuffers::String> serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str);

flatbuffers::Offset<Halide::Serialize::Type> serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type);
Expand Down
21 changes: 16 additions & 5 deletions apps/serdes/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,12 @@ table FuncMapping {
index: int32;
}

enum NameMangling: ubyte {
Default,
C,
CPlusPlus,
}

// Halide::internal::Function
table Func {
name: string;
Expand All @@ -524,13 +530,18 @@ table Func {
func_schedule: FuncSchedule;
init_def: Definition;
updates: [Definition];
// debug_file: string;
debug_file: string;
// output_buffers: [Parameter];
// extern_arguments: [ExternFuncArgument];
// extern_function_name: string;
// extern_mangling: NameMangling;
// extern_device_api: DeviceAPI;
// extern_proxy_expr: Expr;
extern_function_name: string;
extern_mangling: NameMangling;
extern_function_device_api: DeviceAPI;
extern_proxy_expr: Expr;
trace_loads: bool = false;
trace_stores: bool = false;
trace_realizations: bool = false;
trace_tags: [string];
frozen: bool = false;
}

table Pipeline {
Expand Down
17 changes: 15 additions & 2 deletions src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,11 @@ Function::Function(const std::vector<Type> &required_types, int required_dims, c
}

Function::Function(const std::string &name, const std::string &origin_name, const std::vector<Halide::Type> &output_types,
const std::vector<Halide::Type> &required_types, int required_dims, const std::vector<std::string> &args,
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates) {
const std::vector<Halide::Type> &required_types, int required_dims, const std::vector<std::string> &args,
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates,
const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling,
const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations,
const std::vector<std::string> &trace_tags, bool frozen) {
contents.strong = new FunctionGroup;
contents.strong->members.resize(1);
contents->name = name;
Expand All @@ -345,6 +348,16 @@ Function::Function(const std::string &name, const std::string &origin_name, cons
contents->func_schedule = func_schedule;
contents->init_def = init_def;
contents->updates = updates;
contents->debug_file = debug_file;
contents->extern_function_name = extern_function_name;
contents->extern_mangling = name_mangling;
contents->extern_function_device_api = device_api;
contents->extern_proxy_expr = extern_proxy_expr;
contents->trace_loads = trace_loads;
contents->trace_stores = trace_stores;
contents->trace_realizations = trace_realizations;
contents->trace_tags = trace_tags;
contents->frozen = frozen;
}

namespace {
Expand Down
5 changes: 4 additions & 1 deletion src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ class Function {
/** Construct a function from deserializing */
explicit Function(const std::string &name, const std::string &origin_name, const std::vector<Halide::Type> &output_types,
const std::vector<Halide::Type> &required_types, int required_dims, const std::vector<std::string> &args,
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates);
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates,
const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling,
const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations,
const std::vector<std::string> &trace_tags, bool frozen);

/** Get a handle on the halide function contents that this Function
* represents. */
Expand Down

0 comments on commit 097ec2e

Please sign in to comment.