Skip to content

Commit

Permalink
sync commit
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 19, 2023
1 parent 478970e commit 4140ae1
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 5 deletions.
22 changes: 20 additions & 2 deletions apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include <fstream>
#include <iostream>

using Halide::Serialize::TypeCode;

std::string Deserializer::deserialize_string(const flatbuffers::String *str) {
return str->str();
Expand Down Expand Up @@ -35,7 +34,7 @@ Halide::MemoryType Deserializer::deserialize_memory_type(const Halide::Serialize
}

Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) {
// bits
using Halide::Serialize::TypeCode;
int bits = type->bits();
int lanes = type->lanes();
TypeCode code_deserialized = type->code();
Expand Down Expand Up @@ -622,6 +621,22 @@ Halide::Internal::Definition Deserializer::deserialize_definition(const Halide::
return Halide::Internal::Definition(is_init, predicate, args, values, Halide::Internal::StageSchedule(), specializations, source_location);
}

// TODO: will need to serialize a reverse table of map<address, func_name> to
// later reconstruct a map of <name, func_ptr> find out which function ptrs to use here
std::map<std::string, Halide::Internal::FunctionPtr> Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs) {
return std::map<std::string, Halide::Internal::FunctionPtr>();
}

std::map<std::string, int32_t> Deserializer::deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings) {
std::map<std::string, int32_t> result;
for (const auto &func_mapping : *func_mappings) {
auto name = deserialize_string(func_mapping->name());
auto index = func_mapping->index();
result[name] = index;
}
return result;
}

Halide::Pipeline Deserializer::deserialize(const std::string &filename) {
// unpack binary file
std::ifstream in(filename, std::ios::binary | std::ios::in);
Expand All @@ -637,6 +652,9 @@ Halide::Pipeline Deserializer::deserialize(const std::string &filename) {
in.read(data.data(), size);
in.close();

this->func_mappings_str2idx = deserialize_func_mappings(Halide::Serialize::GetPipeline(data.data())->func_mappings());
this->func_mappings_idx2ptr = reconstruct_func_ptr_mappings();

const auto *pipeline_obj = Halide::Serialize::GetPipeline(data.data());
const auto *func_objs = pipeline_obj->outputs();
std::vector<Halide::Func> funcs;
Expand Down
7 changes: 7 additions & 0 deletions apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class Deserializer {
Halide::Pipeline deserialize(const std::string &filename);

private:
std::map<std::string, int32_t> func_mappings_str2idx;
std::map<int32_t, Halide::Internal::FunctionPtr> func_mappings_idx2ptr;

// helper functions to deserialize each type of object
Halide::MemoryType deserialize_memory_type(const Halide::Serialize::MemoryType memory_type);

Expand Down Expand Up @@ -42,6 +45,10 @@ class Deserializer {
Halide::Internal::Specialization deserialize_specialization(const Halide::Serialize::Specialization *specialization);

Halide::Internal::Definition deserialize_definition(const Halide::Serialize::Definition *definition);

std::map<std::string, Halide::Internal::FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs);

std::map<std::string, int32_t> deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings);
};

#endif
40 changes: 38 additions & 2 deletions apps/serdes/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ flatbuffers::Offset<Halide::Serialize::FuncSchedule> Serializer::serialize_func_
for (const auto &estimate : func_schedule.estimates()) {
estimates_serialized.push_back(serialize_bound(builder, estimate));
}
auto wrappers_serialized = serialize_wrapper_refs(builder, func_schedule.wrappers());
// TODO: make this a func
Halide::Serialize::MemoryType memory_type = Halide::Serialize::MemoryType::MemoryType_Auto;
switch (func_schedule.memory_type()) {
Expand Down Expand Up @@ -822,7 +823,7 @@ flatbuffers::Offset<Halide::Serialize::FuncSchedule> Serializer::serialize_func_
auto memoize_eviction_key = func_schedule.memoize_eviction_key();
auto memoize_eviction_key_serialized = serialize_expr(builder, memoize_eviction_key);
return Halide::Serialize::CreateFuncSchedule(builder, store_level_serialized, compute_level_serialized, builder.CreateVector(storage_dims_serialized), builder.CreateVector(bounds_serialized),
builder.CreateVector(estimates_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
builder.CreateVector(estimates_serialized), builder.CreateVector(wrappers_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
}

flatbuffers::Offset<Halide::Serialize::Specialization> Serializer::serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization) {
Expand Down Expand Up @@ -862,10 +863,34 @@ flatbuffers::Offset<Halide::Serialize::Definition> Serializer::serialize_definit
builder.CreateVector(values_serialized), builder.CreateVector(args_types), builder.CreateVector(args_serialized), builder.CreateVector(specializations_serialized), source_location_serialized);
}

std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers) {
std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> wrapper_refs_serialized;
for (const auto& it : wrappers) {
std::string name = it.first;
const Halide::Internal::FunctionPtr& func_ptr = it.second;
uint64_t func_address = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(func_ptr.get()));
auto name_serialized = serialize_string(builder, name);
wrapper_refs_serialized.push_back(Halide::Serialize::CreateWrapperRef(builder, name_serialized, func_address));
}
return wrapper_refs_serialized;
}

std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> Serializer::serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings) {
std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> func_mappings_serialized;
for (const auto& it : func_mappings) {
std::string name = it.first;
int32_t index = it.second;
auto name_serialized = serialize_string(builder, name);
func_mappings_serialized.push_back(Halide::Serialize::CreateFuncMapping(builder, name_serialized, index));
}
return func_mappings_serialized;
}

void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &filename) {
std::cout << "Serializing a pipeline into " << filename << "\n";
flatbuffers::FlatBufferBuilder builder(1024);
std::map<std::string, Halide::Internal::Function> env;
std::map<std::string, int32_t> func_mappings;

// extract the DAG, unwarp function from Funcs
for (const Halide::Func &func : pipeline.outputs()) {
Expand All @@ -874,6 +899,15 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &
env.insert(more_funcs.begin(), more_funcs.end());
}

// construct the internal func mapping that will be used
// through serialization/deserialization to reassamble the DAG
{
int32_t i = 0;
for (const auto& it: env) {
func_mappings[it.first] = i++;
}
}

// serialize each func
// TODO: this should be the correct way to serialize the whole DAG
// a vector of all funcs + an extra map to map from name to index
Expand Down Expand Up @@ -905,7 +939,9 @@ void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &
auto requirements_vector = builder.CreateVector(requirements_serialized);
auto requirements_types_vector = builder.CreateVector(requirements_types);

auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector);
auto func_mappings_serialized = serialize_func_mappings(builder, func_mappings);

auto pipeline_obj = Halide::Serialize::CreatePipeline(builder, funcs, requirements_types_vector, requirements_vector, builder.CreateVector(func_mappings_serialized));
builder.Finish(pipeline_obj);

// write the binary file
Expand Down
5 changes: 5 additions & 0 deletions apps/serdes/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <string>
#include <utility>
#include <vector>
#include <map>

#include "Halide.h"
#include "halide_ir_generated.h"
Expand Down Expand Up @@ -41,6 +42,10 @@ class Serializer {
flatbuffers::Offset<Halide::Serialize::Specialization> serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization);

flatbuffers::Offset<Halide::Serialize::Definition> serialize_definition(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Definition &definition);

std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers);

std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings);
};

#endif
16 changes: 15 additions & 1 deletion apps/serdes/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -422,13 +422,18 @@ table LoopLevel {
locked: bool;
}

table WrapperRef {
name: string;
func_address: uint64;
}

table FuncSchedule {
store_level: LoopLevel;
compute_level: LoopLevel;
storage_dims: [StorageDim];
bounds: [Bound];
estimates: [Bound];
// wrappers: [WrapperRef]; TODO: no WrapperRef yet
wrappers: [WrapperRef];
memory_type: MemoryType = Auto;
memoized: bool;
async: bool;
Expand All @@ -451,6 +456,11 @@ table Definition {
source_location: string;
}

table FuncMapping {
name: string;
index: int32;
}

// Halide::internal::Function
table Func {
name: string;
Expand All @@ -466,12 +476,16 @@ table Func {
// output_buffers: [Parameter];
// extern_arguments: [ExternFuncArgument];
// extern_function_name: string;
// extern_mangling: NameMangling;
// extern_device_api: DeviceAPI;
// extern_proxy_expr: Expr;
}

table Pipeline {
outputs: [Func];
requirements: [Stmt];
// trace_pipeline: bool;
func_mappings: [FuncMapping];
}

root_type Pipeline;

0 comments on commit 4140ae1

Please sign in to comment.