diff --git a/apps/serdes/Deserializer.cpp b/apps/serdes/Deserializer.cpp
index 4228316b6311..7853e2d43e8f 100644
--- a/apps/serdes/Deserializer.cpp
+++ b/apps/serdes/Deserializer.cpp
@@ -623,19 +623,30 @@ Halide::Internal::Definition Deserializer::deserialize_definition(const Halide::
// TODO: will need to serialize a reverse table of map
to
// later reconstruct a map of find out which function ptrs to use here
-std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) {
- return std::map();
-}
+// std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) {
+// return std::map();
+// }
-std::map Deserializer::deserialize_func_mappings(const flatbuffers::Vector> *func_mappings) {
- std::map result;
- for (const auto &func_mapping : *func_mappings) {
- auto name = deserialize_string(func_mapping->name());
- auto index = func_mapping->index();
- result[name] = index;
- }
- return result;
-}
+// std::map Deserializer::deserialize_func_mappings(const flatbuffers::Vector> *func_mappings) {
+// std::map result;
+// for (const auto &func_mapping : *func_mappings) {
+// auto name = deserialize_string(func_mapping->name());
+// auto index = func_mapping->index();
+// result[name] = index;
+// }
+// return result;
+// }
+
+// std::map Deserializer::reconstruct_func_ptr_mappings() {
+// std::map result;
+// for (const auto &mapping : this->func_mappings_str2idx) {
+// auto name = mapping.first;
+// auto index = mapping.second;
+// auto func_ptr = this->func_mappings_idx2ptr[index];
+// result[index] = func_ptr;
+// }
+// return result;
+// }
Halide::Pipeline Deserializer::deserialize(const std::string &filename) {
// unpack binary file
@@ -652,10 +663,9 @@ Halide::Pipeline Deserializer::deserialize(const std::string &filename) {
in.read(data.data(), size);
in.close();
- this->func_mappings_str2idx = deserialize_func_mappings(Halide::Serialize::GetPipeline(data.data())->func_mappings());
- this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings();
-
const auto *pipeline_obj = Halide::Serialize::GetPipeline(data.data());
+ // this->func_mappings_str2idx = deserialize_func_mappings(pipeline_obj->func_mappings());
+ // this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings();
const auto *func_objs = pipeline_obj->outputs();
std::vector funcs;
funcs.reserve(func_objs->size());
diff --git a/apps/serdes/Deserializer.h b/apps/serdes/Deserializer.h
index e1a6ae55410f..af0676b106b9 100644
--- a/apps/serdes/Deserializer.h
+++ b/apps/serdes/Deserializer.h
@@ -46,9 +46,11 @@ class Deserializer {
Halide::Internal::Definition deserialize_definition(const Halide::Serialize::Definition *definition);
- std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs);
+ // std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs);
- std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings);
+ // std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings);
+
+ // std::map reconstruct_func_ptr_mappings();
};
#endif
diff --git a/apps/serdes/Serializer.cpp b/apps/serdes/Serializer.cpp
index b177476dd07a..05ccfa40a8e0 100644
--- a/apps/serdes/Serializer.cpp
+++ b/apps/serdes/Serializer.cpp
@@ -777,7 +777,7 @@ flatbuffers::Offset Serializer::serialize_func_
for (const auto &estimate : func_schedule.estimates()) {
estimates_serialized.push_back(serialize_bound(builder, estimate));
}
- auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers());
+ // auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers());
// TODO: make this a func
Halide::Serialize::MemoryType memory_type = Halide::Serialize::MemoryType::MemoryType_Auto;
switch (func_schedule.memory_type()) {
@@ -823,7 +823,7 @@ flatbuffers::Offset Serializer::serialize_func_
auto memoize_eviction_key = func_schedule.memoize_eviction_key();
auto memoize_eviction_key_serialized = serialize_expr(builder, memoize_eviction_key);
return Halide::Serialize::CreateFuncSchedule(builder, store_level_serialized, compute_level_serialized, builder.CreateVector(storage_dims_serialized), builder.CreateVector(bounds_serialized),
- builder.CreateVector(estimates_serialized), builder.CreateVector(wrappers_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
+ builder.CreateVector(estimates_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
}
flatbuffers::Offset Serializer::serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization) {
@@ -863,28 +863,36 @@ flatbuffers::Offset Serializer::serialize_definit
builder.CreateVector(values_serialized), builder.CreateVector(args_types), builder.CreateVector(args_serialized), builder.CreateVector(specializations_serialized), source_location_serialized);
}
-std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) {
- std::vector> wrapper_refs_serialized;
- for (const auto& it : wrappers) {
- std::string name = it.first;
- const Halide::Internal::FunctionPtr& func_ptr = it.second;
- uint64_t func_address = static_cast(reinterpret_cast(func_ptr.get()));
- auto name_serialized = serialize_string(builder, name);
- wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_address));
- }
- return wrapper_refs_serialized;
-}
-
-std::vector> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings) {
- std::vector> func_mappings_serialized;
- for (const auto& it : func_mappings) {
- std::string name = it.first;
- int32_t index = it.second;
- auto name_serialized = serialize_string(builder, name);
- func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index));
- }
- return func_mappings_serialized;
-}
+// std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) {
+// // instead of storing the function pointer or raw function address,
+// // we store a pre-computed function index as the serialized format for WrapperRef
+// std::vector> wrapper_refs_serialized;
+// for (const auto& it : wrappers) {
+// std::string name = it.first;
+// const Halide::Internal::FunctionPtr& func_ptr = it.second;
+// // TODO: is `name` and `Function(it.second).name()` the same thing?
+// if (auto fm_it = this->func_mappings.find(Halide::Internal::Function(it.second).name()); it != this->func_mappings.end()) {
+// int32_t func_idx = fm_it->second;
+// auto name_serialized = serialize_string(builder, name);
+// wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_idx));
+// } else {
+// std::cerr << "func " << name << " not found in func_mappings\n";
+// exit(1);
+// }
+// }
+// return wrapper_refs_serialized;
+// }
+
+// std::vector> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings) {
+// std::vector> func_mappings_serialized;
+// for (const auto& it : func_mappings) {
+// std::string name = it.first;
+// int32_t index = it.second;
+// auto name_serialized = serialize_string(builder, name);
+// func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index));
+// }
+// return func_mappings_serialized;
+// }
void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &filename) {
std::cout << "Serializing a pipeline into " << filename << "\n";
@@ -901,12 +909,13 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &
// construct the internal func mapping that will be used
// through serialization/deserialization to reassamble the DAG
- {
- int32_t i = 0;
- for (const auto& it: env) {
- func_mappings[it.first] = i++;
- }
- }
+ // {
+ // int32_t i = 0;
+ // for (const auto& it: env) {
+ // func_mappings[it.first] = i++;
+ // }
+ // this->func_mappings = func_mappings;
+ // }
// serialize each func
// TODO: this should be the correct way to serialize the whole DAG
@@ -939,9 +948,9 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &
auto requirements_vector = builder.CreateVector(requirements_serialized);
auto requirements_types_vector = builder.CreateVector(requirements_types);
- auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings);
+ // auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings);
- auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector, builder.CreateVector(func_mappings_serialized));
+ auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector);
builder.Finish(pipeline_obj);
// write the binary file
diff --git a/apps/serdes/Serializer.h b/apps/serdes/Serializer.h
index 6cc0867488aa..65b40c646f9a 100644
--- a/apps/serdes/Serializer.h
+++ b/apps/serdes/Serializer.h
@@ -16,6 +16,8 @@ class Serializer {
void serialize(const Halide::Pipeline &pipeline, const std::string &filename);
private:
+ // std::map func_mappings;
+
// helper functions to serialize each type of object
flatbuffers::Offset serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str);
@@ -43,9 +45,9 @@ class Serializer {
flatbuffers::Offset serialize_definition(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Definition &definition);
- std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers);
+ // std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers);
- std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings);
+ // std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings);
};
#endif
diff --git a/apps/serdes/halide_ir.fbs b/apps/serdes/halide_ir.fbs
index f58826c743c1..ea38c5a99a49 100644
--- a/apps/serdes/halide_ir.fbs
+++ b/apps/serdes/halide_ir.fbs
@@ -192,7 +192,7 @@ table Prefetch {
name: string;
types: [Type];
bounds: [Range];
- // prefetch: PrefetchDirective; TODO: no PrefetchDirective yet
+ // prefetch: PrefetchDirective;
condition: Expr;
body: Stmt;
}
@@ -422,10 +422,10 @@ table LoopLevel {
locked: bool;
}
-table WrapperRef {
- name: string;
- func_address: uint64;
-}
+// table WrapperRef {
+// name: string;
+// func_idx: int32;
+// }
table FuncSchedule {
store_level: LoopLevel;
@@ -433,7 +433,7 @@ table FuncSchedule {
storage_dims: [StorageDim];
bounds: [Bound];
estimates: [Bound];
- wrappers: [WrapperRef];
+ // wrappers: [WrapperRef];
memory_type: MemoryType = Auto;
memoized: bool;
async: bool;
@@ -485,7 +485,7 @@ table Pipeline {
outputs: [Func];
requirements: [Stmt];
// trace_pipeline: bool;
- func_mappings: [FuncMapping];
+ // func_mappings: [FuncMapping];
}
root_type Pipeline;