Skip to content

Commit

Permalink
temporarily comment out func mapping stuff to remove blockers
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 20, 2023
1 parent 4140ae1 commit df3620b
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 58 deletions.
40 changes: 25 additions & 15 deletions apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,19 +623,30 @@ Halide::Internal::Definition Deserializer::deserialize_definition(const Halide::

// TODO: will need to serialize a reverse table of map<address, func_name> to
// later reconstruct a map of <name, func_ptr> find out which function ptrs to use here
std::map<std::string, Halide::Internal::FunctionPtr> Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs) {
return std::map<std::string, Halide::Internal::FunctionPtr>();
}
// std::map<std::string, Halide::Internal::FunctionPtr> Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs) {
// return std::map<std::string, Halide::Internal::FunctionPtr>();
// }

std::map<std::string, int32_t> Deserializer::deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings) {
std::map<std::string, int32_t> result;
for (const auto &func_mapping : *func_mappings) {
auto name = deserialize_string(func_mapping->name());
auto index = func_mapping->index();
result[name] = index;
}
return result;
}
// std::map<std::string, int32_t> Deserializer::deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings) {
// std::map<std::string, int32_t> result;
// for (const auto &func_mapping : *func_mappings) {
// auto name = deserialize_string(func_mapping->name());
// auto index = func_mapping->index();
// result[name] = index;
// }
// return result;
// }

// std::map<int32_t, Halide::Internal::FunctionPtr> Deserializer::reconstruct_func_ptr_mappings() {
// std::map<int32_t, Halide::Internal::FunctionPtr> result;
// for (const auto &mapping : this->func_mappings_str2idx) {
// auto name = mapping.first;
// auto index = mapping.second;
// auto func_ptr = this->func_mappings_idx2ptr[index];
// result[index] = func_ptr;
// }
// return result;
// }

Halide::Pipeline Deserializer::deserialize(const std::string &filename) {
// unpack binary file
Expand All @@ -652,10 +663,9 @@ Halide::Pipeline Deserializer::deserialize(const std::string &filename) {
in.read(data.data(), size);
in.close();

this->func_mappings_str2idx = deserialize_func_mappings(Halide::Serialize::GetPipeline(data.data())->func_mappings());
this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings();

const auto *pipeline_obj = Halide::Serialize::GetPipeline(data.data());
// this->func_mappings_str2idx = deserialize_func_mappings(pipeline_obj->func_mappings());
// this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings();
const auto *func_objs = pipeline_obj->outputs();
std::vector<Halide::Func> funcs;
funcs.reserve(func_objs->size());
Expand Down
6 changes: 4 additions & 2 deletions apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ class Deserializer {

Halide::Internal::Definition deserialize_definition(const Halide::Serialize::Definition *definition);

std::map<std::string, Halide::Internal::FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs);
// std::map<std::string, Halide::Internal::FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs);

std::map<std::string, int32_t> deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings);
// std::map<std::string, int32_t> deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings);

// std::map<int32_t, Halide::Internal::FunctionPtr> reconstruct_func_ptr_mappings();
};

#endif
73 changes: 41 additions & 32 deletions apps/serdes/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ flatbuffers::Offset<Halide::Serialize::FuncSchedule> Serializer::serialize_func_
for (const auto &estimate : func_schedule.estimates()) {
estimates_serialized.push_back(serialize_bound(builder, estimate));
}
auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers());
// auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers());
// TODO: make this a func
Halide::Serialize::MemoryType memory_type = Halide::Serialize::MemoryType::MemoryType_Auto;
switch (func_schedule.memory_type()) {
Expand Down Expand Up @@ -823,7 +823,7 @@ flatbuffers::Offset<Halide::Serialize::FuncSchedule> Serializer::serialize_func_
auto memoize_eviction_key = func_schedule.memoize_eviction_key();
auto memoize_eviction_key_serialized = serialize_expr(builder, memoize_eviction_key);
return Halide::Serialize::CreateFuncSchedule(builder, store_level_serialized, compute_level_serialized, builder.CreateVector(storage_dims_serialized), builder.CreateVector(bounds_serialized),
builder.CreateVector(estimates_serialized), builder.CreateVector(wrappers_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
builder.CreateVector(estimates_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
}

flatbuffers::Offset<Halide::Serialize::Specialization> Serializer::serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization) {
Expand Down Expand Up @@ -863,28 +863,36 @@ flatbuffers::Offset<Halide::Serialize::Definition> Serializer::serialize_definit
builder.CreateVector(values_serialized), builder.CreateVector(args_types), builder.CreateVector(args_serialized), builder.CreateVector(specializations_serialized), source_location_serialized);
}

std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers) {
std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> wrapper_refs_serialized;
for (const auto& it : wrappers) {
std::string name = it.first;
const Halide::Internal::FunctionPtr& func_ptr = it.second;
uint64_t func_address = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(func_ptr.get()));
auto name_serialized = serialize_string(builder, name);
wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_address));
}
return wrapper_refs_serialized;
}

std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings) {
std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> func_mappings_serialized;
for (const auto& it : func_mappings) {
std::string name = it.first;
int32_t index = it.second;
auto name_serialized = serialize_string(builder, name);
func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index));
}
return func_mappings_serialized;
}
// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers) {
// // instead of storing the function pointer or raw function address,
// // we store a pre-computed function index as the serialized format for WrapperRef
// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> wrapper_refs_serialized;
// for (const auto& it : wrappers) {
// std::string name = it.first;
// const Halide::Internal::FunctionPtr& func_ptr = it.second;
// // TODO: is `name` and `Function(it.second).name()` the same thing?
// if (auto fm_it = this->func_mappings.find(Halide::Internal::Function(it.second).name()); it != this->func_mappings.end()) {
// int32_t func_idx = fm_it->second;
// auto name_serialized = serialize_string(builder, name);
// wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_idx));
// } else {
// std::cerr << "func " << name << " not found in func_mappings\n";
// exit(1);
// }
// }
// return wrapper_refs_serialized;
// }

// std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings) {
// std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> func_mappings_serialized;
// for (const auto& it : func_mappings) {
// std::string name = it.first;
// int32_t index = it.second;
// auto name_serialized = serialize_string(builder, name);
// func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index));
// }
// return func_mappings_serialized;
// }

void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &filename) {
std::cout << "Serializing a pipeline into " << filename << "\n";
Expand All @@ -901,12 +909,13 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &

// construct the internal func mapping that will be used
// through serialization/deserialization to reassamble the DAG
{
int32_t i = 0;
for (const auto& it: env) {
func_mappings[it.first] = i++;
}
}
// {
// int32_t i = 0;
// for (const auto& it: env) {
// func_mappings[it.first] = i++;
// }
// this->func_mappings = func_mappings;
// }

// serialize each func
// TODO: this should be the correct way to serialize the whole DAG
Expand Down Expand Up @@ -939,9 +948,9 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &
auto requirements_vector = builder.CreateVector(requirements_serialized);
auto requirements_types_vector = builder.CreateVector(requirements_types);

auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings);
// auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings);

auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector, builder.CreateVector(func_mappings_serialized));
auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector);
builder.Finish(pipeline_obj);

// write the binary file
Expand Down
6 changes: 4 additions & 2 deletions apps/serdes/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class Serializer {
void serialize(const Halide::Pipeline &pipeline, const std::string &filename);

private:
// std::map<std::string, int32_t> func_mappings;

// helper functions to serialize each type of object
flatbuffers::Offset<flatbuffers::String> serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str);

Expand Down Expand Up @@ -43,9 +45,9 @@ class Serializer {

flatbuffers::Offset<Halide::Serialize::Definition> serialize_definition(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Definition &definition);

std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers);
// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers);

std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings);
// std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings);
};

#endif
14 changes: 7 additions & 7 deletions apps/serdes/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ table Prefetch {
name: string;
types: [Type];
bounds: [Range];
// prefetch: PrefetchDirective; TODO: no PrefetchDirective yet
// prefetch: PrefetchDirective;
condition: Expr;
body: Stmt;
}
Expand Down Expand Up @@ -422,18 +422,18 @@ table LoopLevel {
locked: bool;
}

table WrapperRef {
name: string;
func_address: uint64;
}
// table WrapperRef {
// name: string;
// func_idx: int32;
// }

table FuncSchedule {
store_level: LoopLevel;
compute_level: LoopLevel;
storage_dims: [StorageDim];
bounds: [Bound];
estimates: [Bound];
wrappers: [WrapperRef];
// wrappers: [WrapperRef];
memory_type: MemoryType = Auto;
memoized: bool;
async: bool;
Expand Down Expand Up @@ -485,7 +485,7 @@ table Pipeline {
outputs: [Func];
requirements: [Stmt];
// trace_pipeline: bool;
func_mappings: [FuncMapping];
// func_mappings: [FuncMapping];
}

root_type Pipeline;

0 comments on commit df3620b

Please sign in to comment.