diff --git a/apps/serdes/Deserializer.cpp b/apps/serdes/Deserializer.cpp index 434b524c9387..4228316b6311 100644 --- a/apps/serdes/Deserializer.cpp +++ b/apps/serdes/Deserializer.cpp @@ -2,7 +2,6 @@ #include #include -using Halide::Serialize::TypeCode; std::string Deserializer::deserialize_string(const flatbuffers::String *str) { return str->str(); @@ -35,7 +34,7 @@ Halide::MemoryType Deserializer::deserialize_memory_type(const Halide::Serialize } Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) { - // bits + using Halide::Serialize::TypeCode; int bits = type->bits(); int lanes = type->lanes(); TypeCode code_deserialized = type->code(); @@ -622,6 +621,22 @@ Halide::Internal::Definition Deserializer::deserialize_definition(const Halide:: return Halide::Internal::Definition(is_init, predicate, args, values, Halide::Internal::StageSchedule(), specializations, source_location); } +// TODO: will need to serialize a reverse table of map to +// later reconstruct a map of find out which function ptrs to use here +std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) { + return std::map(); +} + +std::map Deserializer::deserialize_func_mappings(const flatbuffers::Vector> *func_mappings) { + std::map result; + for (const auto &func_mapping : *func_mappings) { + auto name = deserialize_string(func_mapping->name()); + auto index = func_mapping->index(); + result[name] = index; + } + return result; +} + Halide::Pipeline Deserializer::deserialize(const std::string &filename) { // unpack binary file std::ifstream in(filename, std::ios::binary | std::ios::in); @@ -637,6 +652,9 @@ Halide::Pipeline Deserializer::deserialize(const std::string &filename) { in.read(data.data(), size); in.close(); + this->func_mappings_str2idx = deserialize_func_mappings(Halide::Serialize::GetPipeline(data.data())->func_mappings()); + this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings(); + const auto *pipeline_obj = Halide::Serialize::GetPipeline(data.data()); const auto *func_objs = pipeline_obj->outputs(); std::vector funcs; diff --git a/apps/serdes/Deserializer.h b/apps/serdes/Deserializer.h index b44a75a962ea..e1a6ae55410f 100644 --- a/apps/serdes/Deserializer.h +++ b/apps/serdes/Deserializer.h @@ -14,6 +14,9 @@ class Deserializer { Halide::Pipeline deserialize(const std::string &filename); private: + std::map func_mappings_str2idx; + std::map func_mappings_idx2ptr; + // helper functions to deserialize each type of object Halide::MemoryType deserialize_memory_type(const Halide::Serialize::MemoryType memory_type); @@ -42,6 +45,10 @@ class Deserializer { Halide::Internal::Specialization deserialize_specialization(const Halide::Serialize::Specialization *specialization); Halide::Internal::Definition deserialize_definition(const Halide::Serialize::Definition *definition); + + std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs); + + std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings); }; #endif diff --git a/apps/serdes/Serializer.cpp b/apps/serdes/Serializer.cpp index 27fcdf641259..b177476dd07a 100644 --- a/apps/serdes/Serializer.cpp +++ b/apps/serdes/Serializer.cpp @@ -777,6 +777,7 @@ flatbuffers::Offset Serializer::serialize_func_ for (const auto &estimate : func_schedule.estimates()) { estimates_serialized.push_back(serialize_bound(builder, estimate)); } + auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers()); // TODO: make this a func Halide::Serialize::MemoryType memory_type = Halide::Serialize::MemoryType::MemoryType_Auto; switch (func_schedule.memory_type()) { @@ -822,7 +823,7 @@ flatbuffers::Offset Serializer::serialize_func_ auto memoize_eviction_key = func_schedule.memoize_eviction_key(); auto memoize_eviction_key_serialized = serialize_expr(builder, memoize_eviction_key); return Halide::Serialize::CreateFuncSchedule(builder, store_level_serialized, compute_level_serialized, builder.CreateVector(storage_dims_serialized), builder.CreateVector(bounds_serialized), - builder.CreateVector(estimates_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second); + builder.CreateVector(estimates_serialized), builder.CreateVector(wrappers_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second); } flatbuffers::Offset Serializer::serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization) { @@ -862,10 +863,34 @@ flatbuffers::Offset Serializer::serialize_definit builder.CreateVector(values_serialized), builder.CreateVector(args_types), builder.CreateVector(args_serialized), builder.CreateVector(specializations_serialized), source_location_serialized); } +std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) { + std::vector> wrapper_refs_serialized; + for (const auto& it : wrappers) { + std::string name = it.first; + const Halide::Internal::FunctionPtr& func_ptr = it.second; + uint64_t func_address = static_cast(reinterpret_cast(func_ptr.get())); + auto name_serialized = serialize_string(builder, name); + wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_address)); + } + return wrapper_refs_serialized; +} + +std::vector> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings) { + std::vector> func_mappings_serialized; + for (const auto& it : func_mappings) { + std::string name = it.first; + int32_t index = it.second; + auto name_serialized = serialize_string(builder, name); + func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index)); + } + return func_mappings_serialized; +} + void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &filename) { std::cout << "Serializing a pipeline into " << filename << "\n"; flatbuffers::FlatBufferBuilder builder(1024); std::map env; + std::map func_mappings; // extract the DAG, unwarp function from Funcs for (const Halide::Func &func : pipeline.outputs()) { @@ -874,6 +899,15 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string & env.insert(more_funcs.begin(), more_funcs.end()); } + // construct the internal func mapping that will be used + // through serialization/deserialization to reassamble the DAG + { + int32_t i = 0; + for (const auto& it: env) { + func_mappings[it.first] = i++; + } + } + // serialize each func // TODO: this should be the correct way to serialize the whole DAG // a vector of all funcs + an extra map to map from name to index @@ -905,7 +939,9 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string & auto requirements_vector = builder.CreateVector(requirements_serialized); auto requirements_types_vector = builder.CreateVector(requirements_types); - auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector); + auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings); + + auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector, builder.CreateVector(func_mappings_serialized)); builder.Finish(pipeline_obj); // write the binary file diff --git a/apps/serdes/Serializer.h b/apps/serdes/Serializer.h index 47b6cc18b74f..6cc0867488aa 100644 --- a/apps/serdes/Serializer.h +++ b/apps/serdes/Serializer.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "Halide.h" #include "halide_ir_generated.h" @@ -41,6 +42,10 @@ class Serializer { flatbuffers::Offset serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization); flatbuffers::Offset serialize_definition(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Definition &definition); + + std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers); + + std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings); }; #endif diff --git a/apps/serdes/halide_ir.fbs b/apps/serdes/halide_ir.fbs index 4cc3b7fd2d2f..f58826c743c1 100644 --- a/apps/serdes/halide_ir.fbs +++ b/apps/serdes/halide_ir.fbs @@ -422,13 +422,18 @@ table LoopLevel { locked: bool; } +table WrapperRef { + name: string; + func_address: uint64; +} + table FuncSchedule { store_level: LoopLevel; compute_level: LoopLevel; storage_dims: [StorageDim]; bounds: [Bound]; estimates: [Bound]; - // wrappers: [WrapperRef]; TODO: no WrapperRef yet + wrappers: [WrapperRef]; memory_type: MemoryType = Auto; memoized: bool; async: bool; @@ -451,6 +456,11 @@ table Definition { source_location: string; } +table FuncMapping { + name: string; + index: int32; +} + // Halide::internal::Function table Func { name: string; @@ -466,12 +476,16 @@ table Func { // output_buffers: [Parameter]; // extern_arguments: [ExternFuncArgument]; // extern_function_name: string; + // extern_mangling: NameMangling; + // extern_device_api: DeviceAPI; + // extern_proxy_expr: Expr; } table Pipeline { outputs: [Func]; requirements: [Stmt]; // trace_pipeline: bool; + func_mappings: [FuncMapping]; } root_type Pipeline;