From df3620bb310b3797966215ee00645ae746f57c38 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 20 Jun 2023 09:28:15 -0700 Subject: [PATCH] temporarily comment out func mapping stuff to remove blockers --- apps/serdes/Deserializer.cpp | 40 ++++++++++++-------- apps/serdes/Deserializer.h | 6 ++- apps/serdes/Serializer.cpp | 73 ++++++++++++++++++++---------------- apps/serdes/Serializer.h | 6 ++- apps/serdes/halide_ir.fbs | 14 +++---- 5 files changed, 81 insertions(+), 58 deletions(-) diff --git a/apps/serdes/Deserializer.cpp b/apps/serdes/Deserializer.cpp index 4228316b6311..7853e2d43e8f 100644 --- a/apps/serdes/Deserializer.cpp +++ b/apps/serdes/Deserializer.cpp @@ -623,19 +623,30 @@ Halide::Internal::Definition Deserializer::deserialize_definition(const Halide:: // TODO: will need to serialize a reverse table of map to // later reconstruct a map of find out which function ptrs to use here -std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) { - return std::map(); -} +// std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) { +// return std::map(); +// } -std::map Deserializer::deserialize_func_mappings(const flatbuffers::Vector> *func_mappings) { - std::map result; - for (const auto &func_mapping : *func_mappings) { - auto name = deserialize_string(func_mapping->name()); - auto index = func_mapping->index(); - result[name] = index; - } - return result; -} +// std::map Deserializer::deserialize_func_mappings(const flatbuffers::Vector> *func_mappings) { +// std::map result; +// for (const auto &func_mapping : *func_mappings) { +// auto name = deserialize_string(func_mapping->name()); +// auto index = func_mapping->index(); +// result[name] = index; +// } +// return result; +// } + +// std::map Deserializer::reconstruct_func_ptr_mappings() { +// std::map result; +// for (const auto &mapping : this->func_mappings_str2idx) { +// auto name = mapping.first; +// auto index = mapping.second; +// auto func_ptr = this->func_mappings_idx2ptr[index]; +// result[index] = func_ptr; +// } +// return result; +// } Halide::Pipeline Deserializer::deserialize(const std::string &filename) { // unpack binary file @@ -652,10 +663,9 @@ Halide::Pipeline Deserializer::deserialize(const std::string &filename) { in.read(data.data(), size); in.close(); - this->func_mappings_str2idx = deserialize_func_mappings(Halide::Serialize::GetPipeline(data.data())->func_mappings()); - this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings(); - const auto *pipeline_obj = Halide::Serialize::GetPipeline(data.data()); + // this->func_mappings_str2idx = deserialize_func_mappings(pipeline_obj->func_mappings()); + // this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings(); const auto *func_objs = pipeline_obj->outputs(); std::vector funcs; funcs.reserve(func_objs->size()); diff --git a/apps/serdes/Deserializer.h b/apps/serdes/Deserializer.h index e1a6ae55410f..af0676b106b9 100644 --- a/apps/serdes/Deserializer.h +++ b/apps/serdes/Deserializer.h @@ -46,9 +46,11 @@ class Deserializer { Halide::Internal::Definition deserialize_definition(const Halide::Serialize::Definition *definition); - std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs); + // std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs); - std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings); + // std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings); + + // std::map reconstruct_func_ptr_mappings(); }; #endif diff --git a/apps/serdes/Serializer.cpp b/apps/serdes/Serializer.cpp index b177476dd07a..05ccfa40a8e0 100644 --- a/apps/serdes/Serializer.cpp +++ b/apps/serdes/Serializer.cpp @@ -777,7 +777,7 @@ flatbuffers::Offset Serializer::serialize_func_ for (const auto &estimate : func_schedule.estimates()) { estimates_serialized.push_back(serialize_bound(builder, estimate)); } - auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers()); + // auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers()); // TODO: make this a func Halide::Serialize::MemoryType memory_type = Halide::Serialize::MemoryType::MemoryType_Auto; switch (func_schedule.memory_type()) { @@ -823,7 +823,7 @@ flatbuffers::Offset Serializer::serialize_func_ auto memoize_eviction_key = func_schedule.memoize_eviction_key(); auto memoize_eviction_key_serialized = serialize_expr(builder, memoize_eviction_key); return Halide::Serialize::CreateFuncSchedule(builder, store_level_serialized, compute_level_serialized, builder.CreateVector(storage_dims_serialized), builder.CreateVector(bounds_serialized), - builder.CreateVector(estimates_serialized), builder.CreateVector(wrappers_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second); + builder.CreateVector(estimates_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second); } flatbuffers::Offset Serializer::serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization) { @@ -863,28 +863,36 @@ flatbuffers::Offset Serializer::serialize_definit builder.CreateVector(values_serialized), builder.CreateVector(args_types), builder.CreateVector(args_serialized), builder.CreateVector(specializations_serialized), source_location_serialized); } -std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) { - std::vector> wrapper_refs_serialized; - for (const auto& it : wrappers) { - std::string name = it.first; - const Halide::Internal::FunctionPtr& func_ptr = it.second; - uint64_t func_address = static_cast(reinterpret_cast(func_ptr.get())); - auto name_serialized = serialize_string(builder, name); - wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_address)); - } - return wrapper_refs_serialized; -} - -std::vector> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings) { - std::vector> func_mappings_serialized; - for (const auto& it : func_mappings) { - std::string name = it.first; - int32_t index = it.second; - auto name_serialized = serialize_string(builder, name); - func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index)); - } - return func_mappings_serialized; -} +// std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) { +// // instead of storing the function pointer or raw function address, +// // we store a pre-computed function index as the serialized format for WrapperRef +// std::vector> wrapper_refs_serialized; +// for (const auto& it : wrappers) { +// std::string name = it.first; +// const Halide::Internal::FunctionPtr& func_ptr = it.second; +// // TODO: is `name` and `Function(it.second).name()` the same thing? +// if (auto fm_it = this->func_mappings.find(Halide::Internal::Function(it.second).name()); it != this->func_mappings.end()) { +// int32_t func_idx = fm_it->second; +// auto name_serialized = serialize_string(builder, name); +// wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_idx)); +// } else { +// std::cerr << "func " << name << " not found in func_mappings\n"; +// exit(1); +// } +// } +// return wrapper_refs_serialized; +// } + +// std::vector> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings) { +// std::vector> func_mappings_serialized; +// for (const auto& it : func_mappings) { +// std::string name = it.first; +// int32_t index = it.second; +// auto name_serialized = serialize_string(builder, name); +// func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index)); +// } +// return func_mappings_serialized; +// } void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &filename) { std::cout << "Serializing a pipeline into " << filename << "\n"; @@ -901,12 +909,13 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string & // construct the internal func mapping that will be used // through serialization/deserialization to reassamble the DAG - { - int32_t i = 0; - for (const auto& it: env) { - func_mappings[it.first] = i++; - } - } + // { + // int32_t i = 0; + // for (const auto& it: env) { + // func_mappings[it.first] = i++; + // } + // this->func_mappings = func_mappings; + // } // serialize each func // TODO: this should be the correct way to serialize the whole DAG @@ -939,9 +948,9 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string & auto requirements_vector = builder.CreateVector(requirements_serialized); auto requirements_types_vector = builder.CreateVector(requirements_types); - auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings); + // auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings); - auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector, builder.CreateVector(func_mappings_serialized)); + auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector); builder.Finish(pipeline_obj); // write the binary file diff --git a/apps/serdes/Serializer.h b/apps/serdes/Serializer.h index 6cc0867488aa..65b40c646f9a 100644 --- a/apps/serdes/Serializer.h +++ b/apps/serdes/Serializer.h @@ -16,6 +16,8 @@ class Serializer { void serialize(const Halide::Pipeline &pipeline, const std::string &filename); private: + // std::map func_mappings; + // helper functions to serialize each type of object flatbuffers::Offset serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str); @@ -43,9 +45,9 @@ class Serializer { flatbuffers::Offset serialize_definition(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Definition &definition); - std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers); + // std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers); - std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings); + // std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings); }; #endif diff --git a/apps/serdes/halide_ir.fbs b/apps/serdes/halide_ir.fbs index f58826c743c1..ea38c5a99a49 100644 --- a/apps/serdes/halide_ir.fbs +++ b/apps/serdes/halide_ir.fbs @@ -192,7 +192,7 @@ table Prefetch { name: string; types: [Type]; bounds: [Range]; - // prefetch: PrefetchDirective; TODO: no PrefetchDirective yet + // prefetch: PrefetchDirective; condition: Expr; body: Stmt; } @@ -422,10 +422,10 @@ table LoopLevel { locked: bool; } -table WrapperRef { - name: string; - func_address: uint64; -} +// table WrapperRef { +// name: string; +// func_idx: int32; +// } table FuncSchedule { store_level: LoopLevel; @@ -433,7 +433,7 @@ table FuncSchedule { storage_dims: [StorageDim]; bounds: [Bound]; estimates: [Bound]; - wrappers: [WrapperRef]; + // wrappers: [WrapperRef]; memory_type: MemoryType = Auto; memoized: bool; async: bool; @@ -485,7 +485,7 @@ table Pipeline { outputs: [Func]; requirements: [Stmt]; // trace_pipeline: bool; - func_mappings: [FuncMapping]; + // func_mappings: [FuncMapping]; } root_type Pipeline;