Skip to content

Commit

Permalink
split
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 20, 2023
1 parent 097ec2e commit 88f4a5d
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 7 deletions.
53 changes: 53 additions & 0 deletions apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,42 @@ Halide::NameMangling Deserializer::deserialize_name_mangling(const Halide::Seria
}
}

Halide::TailStrategy Deserializer::deserialize_tail_strategy(const Halide::Serialize::TailStrategy tail_strategy) {
switch (tail_strategy) {
case Halide::Serialize::TailStrategy::TailStrategy_RoundUp:
return Halide::TailStrategy::RoundUp;
case Halide::Serialize::TailStrategy::TailStrategy_GuardWithIf:
return Halide::TailStrategy::GuardWithIf;
case Halide::Serialize::TailStrategy::TailStrategy_PredicateLoads:
return Halide::TailStrategy::PredicateLoads;
case Halide::Serialize::TailStrategy::TailStrategy_PredicateStores:
return Halide::TailStrategy::PredicateStores;
case Halide::Serialize::TailStrategy::TailStrategy_ShiftInwards:
return Halide::TailStrategy::ShiftInwards;
case Halide::Serialize::TailStrategy::TailStrategy_Auto:
return Halide::TailStrategy::Auto;
default:
std::cerr << "unknown tail strategy " << tail_strategy << "\n";
exit(1);
}
}

Halide::Internal::Split::SplitType Deserializer::deserialize_split_type(const Halide::Serialize::SplitType split_type) {
switch (split_type) {
case Halide::Serialize::SplitType::SplitType_SplitVar:
return Halide::Internal::Split::SplitType::SplitVar;
case Halide::Serialize::SplitType::SplitType_RenameVar:
return Halide::Internal::Split::SplitType::RenameVar;
case Halide::Serialize::SplitType::SplitType_FuseVars:
return Halide::Internal::Split::SplitType::FuseVars;
case Halide::Serialize::SplitType::SplitType_PurifyRVar:
return Halide::Internal::Split::SplitType::PurifyRVar;
default:
std::cerr << "unknown split type " << split_type << "\n";
exit(1);
}
}

Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) {
using Halide::Serialize::TypeCode;
int bits = type->bits();
Expand Down Expand Up @@ -743,6 +779,23 @@ Halide::Internal::PrefetchDirective Deserializer::deserialize_prefetch_directive
return hl_prefetch_directive;
}

Halide::Internal::Split Deserializer::deserialize_split(const Halide::Serialize::Split *split) {
auto old_var = deserialize_string(split->old_var());
auto outer = deserialize_string(split->outer());
auto inner = deserialize_string(split->inner());
auto factor = deserialize_expr(split->factor_type(), split->factor());
auto tail = deserialize_tail_strategy(split->tail());
auto split_type = deserialize_split_type(split->split_type());
auto hl_split = Halide::Internal::Split();
hl_split.old_var = old_var;
hl_split.outer = outer;
hl_split.inner = inner;
hl_split.factor = factor;
hl_split.tail = tail;
hl_split.split_type = split_type;
return hl_split;
}

// TODO: will need to serialize a reverse table of map<address, func_name> to
// later reconstruct a map of <name, func_ptr> find out which function ptrs to use here
// std::map<std::string, Halide::Internal::FunctionPtr> Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs) {
Expand Down
5 changes: 5 additions & 0 deletions apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class Deserializer {

Halide::NameMangling deserialize_name_mangling(const Halide::Serialize::NameMangling name_mangling);

Halide::TailStrategy deserialize_tail_strategy(const Halide::Serialize::TailStrategy tail_strategy);

Halide::Internal::Split::SplitType deserialize_split_type(const Halide::Serialize::SplitType split_type);

std::string deserialize_string(const flatbuffers::String *str);

Halide::Type deserialize_type(const Halide::Serialize::Type *type);
Expand Down Expand Up @@ -66,6 +70,7 @@ class Deserializer {

Halide::Internal::PrefetchDirective deserialize_prefetch_directive(const Halide::Serialize::PrefetchDirective *prefetch_directive);

Halide::Internal::Split deserialize_split(const Halide::Serialize::Split *split);
// std::map<std::string, Halide::Internal::FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs);

// std::map<std::string, int32_t> deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings);
Expand Down
50 changes: 48 additions & 2 deletions apps/serdes/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,42 @@ Halide::Serialize::NameMangling Serializer::serialize_name_mangling(const Halide
}
}

Halide::Serialize::TailStrategy Serializer::serialize_tail_strategy(const Halide::TailStrategy &tail_strategy) {
switch (tail_strategy) {
case Halide::TailStrategy::RoundUp:
return Halide::Serialize::TailStrategy::TailStrategy_RoundUp;
case Halide::TailStrategy::GuardWithIf:
return Halide::Serialize::TailStrategy::TailStrategy_GuardWithIf;
case Halide::TailStrategy::PredicateLoads:
return Halide::Serialize::TailStrategy::TailStrategy_PredicateLoads;
case Halide::TailStrategy::PredicateStores:
return Halide::Serialize::TailStrategy::TailStrategy_PredicateStores;
case Halide::TailStrategy::ShiftInwards:
return Halide::Serialize::TailStrategy::TailStrategy_ShiftInwards;
case Halide::TailStrategy::Auto:
return Halide::Serialize::TailStrategy::TailStrategy_Auto;
default:
std::cerr << "Unsupported tail strategy\n";
exit(1);
}
}

Halide::Serialize::SplitType Serializer::serialize_split_type(const Halide::Internal::Split::SplitType &split_type) {
switch (split_type) {
case Halide::Internal::Split::SplitType::SplitVar:
return Halide::Serialize::SplitType::SplitType_SplitVar;
case Halide::Internal::Split::SplitType::RenameVar:
return Halide::Serialize::SplitType::SplitType_RenameVar;
case Halide::Internal::Split::SplitType::FuseVars:
return Halide::Serialize::SplitType::SplitType_FuseVars;
case Halide::Internal::Split::SplitType::PurifyRVar:
return Halide::Serialize::SplitType::SplitType_PurifyRVar;
default:
std::cerr << "Unsupported split type\n";
exit(1);
}
}

flatbuffers::Offset<flatbuffers::String> Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) {
return builder.CreateString(str);
}
Expand Down Expand Up @@ -616,11 +652,11 @@ flatbuffers::Offset<Halide::Serialize::Func> Serializer::serialize_function(flat
bool trace_realizations = function.is_tracing_realizations();
std::vector<flatbuffers::Offset<flatbuffers::String>> trace_tags_serialized;
trace_tags_serialized.reserve(function.get_trace_tags().size());
for (const auto& tag: function.get_trace_tags()) {
for (const auto &tag : function.get_trace_tags()) {
trace_tags_serialized.push_back(serialize_string(builder, tag));
}
bool frozen = function.frozen();
auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim,
auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim,
args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized), debug_file_serialized,
extern_function_name_serialized, extern_mangling_serialized, extern_function_device_api_serialized, extern_proxy_expr_serialized.first,
extern_proxy_expr_serialized.second, trace_loads, trace_stores, trace_realizations, builder.CreateVector(trace_tags_serialized), frozen);
Expand Down Expand Up @@ -770,6 +806,16 @@ flatbuffers::Offset<Halide::Serialize::PrefetchDirective> Serializer::serialize_
return Halide::Serialize::CreatePrefetchDirective(builder, name_serialized, at_serialized, from_serialized, offset_serialized.first, offset_serialized.second, strategy_serialized);
}

flatbuffers::Offset<Halide::Serialize::Split> Serializer::serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split) {
auto old_var_serialized = serialize_string(builder, split.old_var);
auto outer_serialized = serialize_string(builder, split.outer);
auto inner_serialized = serialize_string(builder, split.inner);
auto factor_serialized = serialize_expr(builder, split.factor);
auto tail_serialized = serialize_tail_strategy(split.tail);
auto inner_to_outer_serialized = serialize_split_type(split.split_type);
return Halide::Serialize::CreateSplit(builder, old_var_serialized, outer_serialized, inner_serialized, factor_serialized.first, factor_serialized.second, tail_serialized, inner_to_outer_serialized);
}

// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers) {
// // instead of storing the function pointer or raw function address,
// // we store a pre-computed function index as the serialized format for WrapperRef
Expand Down
6 changes: 6 additions & 0 deletions apps/serdes/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class Serializer {

Halide::Serialize::NameMangling serialize_name_mangling(const Halide::NameMangling &name_mangling);

Halide::Serialize::TailStrategy serialize_tail_strategy(const Halide::TailStrategy &tail_strategy);

Halide::Serialize::SplitType serialize_split_type(const Halide::Internal::Split::SplitType &split_type);

flatbuffers::Offset<flatbuffers::String> serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str);

flatbuffers::Offset<Halide::Serialize::Type> serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type);
Expand Down Expand Up @@ -67,6 +71,8 @@ class Serializer {

flatbuffers::Offset<Halide::Serialize::PrefetchDirective> serialize_prefetch_directive(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::PrefetchDirective &prefetch_directive);

flatbuffers::Offset<Halide::Serialize::Split> serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split);

// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers);

// std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings);
Expand Down
66 changes: 66 additions & 0 deletions apps/serdes/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,72 @@ table Specialization {
failure_message: string;
}

enum TailStrategy: ubyte {
RoundUp,
GuardWithIf,
PredicateLoads,
PredicateStores,
ShiftInwards,
Auto,
}

enum SplitType: ubyte {
SplitVar,
RenameVar,
FuseVars,
PurifyRVar,
}

table Split {
old_var: string;
outer: string;
inner: string;
factor: Expr;
tail: TailStrategy;
split_type: SplitType;
}

// table Dim {
// var: string;
// for_type: ForType;
// device_api: DeviceAPI;
// dim_type: DimType;
// }

// enum LoopAlignStrategy: ubyte {
// AlignStart,
// AlignEnd,
// NoAlign,
// Auto,
// }

// table FuseLoopLevel {
// fuse_level: LoopLevel;
// align_dimension_names: [string];
// align_strategies: [LoopAlignStrategy];
// }

// table FusedPair {
// func_1: string;
// func_2: string;
// stage_1: int32;
// stage_2: int32;
// var_name: string;
// }

// table StageSchedule {
// rvars: [ReductionVariable];
// splits: [Split];
// dims: [Dim];
// prefetches: [PrefetchDirective];
// fuse_level: FuseLoopLevel;
// fused_pairs: [FusedPair];
// touched: bool = false;
// allow_race_conditions: bool = false;
// atomic: bool = false;
// override_atomic_associativity_test: bool = false;
// }

table Definition {
is_init: bool;
predicate: Expr;
Expand Down
10 changes: 5 additions & 5 deletions src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,11 @@ Function::Function(const std::vector<Type> &required_types, int required_dims, c
}

Function::Function(const std::string &name, const std::string &origin_name, const std::vector<Halide::Type> &output_types,
const std::vector<Halide::Type> &required_types, int required_dims, const std::vector<std::string> &args,
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates,
const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling,
const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations,
const std::vector<std::string> &trace_tags, bool frozen) {
const std::vector<Halide::Type> &required_types, int required_dims, const std::vector<std::string> &args,
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates,
const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling,
const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations,
const std::vector<std::string> &trace_tags, bool frozen) {
contents.strong = new FunctionGroup;
contents.strong->members.resize(1);
contents->name = name;
Expand Down

0 comments on commit 88f4a5d

Please sign in to comment.