diff --git a/apps/serdes/Deserializer.cpp b/apps/serdes/Deserializer.cpp
index b297f038d7f8..48694b0dbc27 100644
--- a/apps/serdes/Deserializer.cpp
+++ b/apps/serdes/Deserializer.cpp
@@ -158,6 +158,42 @@ Halide::NameMangling Deserializer::deserialize_name_mangling(const Halide::Seria
}
}
+Halide::TailStrategy Deserializer::deserialize_tail_strategy(const Halide::Serialize::TailStrategy tail_strategy) {
+ switch (tail_strategy) {
+ case Halide::Serialize::TailStrategy::TailStrategy_RoundUp:
+ return Halide::TailStrategy::RoundUp;
+ case Halide::Serialize::TailStrategy::TailStrategy_GuardWithIf:
+ return Halide::TailStrategy::GuardWithIf;
+ case Halide::Serialize::TailStrategy::TailStrategy_PredicateLoads:
+ return Halide::TailStrategy::PredicateLoads;
+ case Halide::Serialize::TailStrategy::TailStrategy_PredicateStores:
+ return Halide::TailStrategy::PredicateStores;
+ case Halide::Serialize::TailStrategy::TailStrategy_ShiftInwards:
+ return Halide::TailStrategy::ShiftInwards;
+ case Halide::Serialize::TailStrategy::TailStrategy_Auto:
+ return Halide::TailStrategy::Auto;
+ default:
+ std::cerr << "unknown tail strategy " << tail_strategy << "\n";
+ exit(1);
+ }
+}
+
+Halide::Internal::Split::SplitType Deserializer::deserialize_split_type(const Halide::Serialize::SplitType split_type) {
+ switch (split_type) {
+ case Halide::Serialize::SplitType::SplitType_SplitVar:
+ return Halide::Internal::Split::SplitType::SplitVar;
+ case Halide::Serialize::SplitType::SplitType_RenameVar:
+ return Halide::Internal::Split::SplitType::RenameVar;
+ case Halide::Serialize::SplitType::SplitType_FuseVars:
+ return Halide::Internal::Split::SplitType::FuseVars;
+ case Halide::Serialize::SplitType::SplitType_PurifyRVar:
+ return Halide::Internal::Split::SplitType::PurifyRVar;
+ default:
+ std::cerr << "unknown split type " << split_type << "\n";
+ exit(1);
+ }
+}
+
Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) {
using Halide::Serialize::TypeCode;
int bits = type->bits();
@@ -743,6 +779,23 @@ Halide::Internal::PrefetchDirective Deserializer::deserialize_prefetch_directive
return hl_prefetch_directive;
}
+Halide::Internal::Split Deserializer::deserialize_split(const Halide::Serialize::Split *split) {
+ auto old_var = deserialize_string(split->old_var());
+ auto outer = deserialize_string(split->outer());
+ auto inner = deserialize_string(split->inner());
+ auto factor = deserialize_expr(split->factor_type(), split->factor());
+ auto tail = deserialize_tail_strategy(split->tail());
+ auto split_type = deserialize_split_type(split->split_type());
+ auto hl_split = Halide::Internal::Split();
+ hl_split.old_var = old_var;
+ hl_split.outer = outer;
+ hl_split.inner = inner;
+ hl_split.factor = factor;
+ hl_split.tail = tail;
+ hl_split.split_type = split_type;
+ return hl_split;
+}
+
// TODO: will need to serialize a reverse table of map
to
// later reconstruct a map of find out which function ptrs to use here
// std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) {
diff --git a/apps/serdes/Deserializer.h b/apps/serdes/Deserializer.h
index 2d37d7ba15c7..2c509b8295a7 100644
--- a/apps/serdes/Deserializer.h
+++ b/apps/serdes/Deserializer.h
@@ -32,6 +32,10 @@ class Deserializer {
Halide::NameMangling deserialize_name_mangling(const Halide::Serialize::NameMangling name_mangling);
+ Halide::TailStrategy deserialize_tail_strategy(const Halide::Serialize::TailStrategy tail_strategy);
+
+ Halide::Internal::Split::SplitType deserialize_split_type(const Halide::Serialize::SplitType split_type);
+
std::string deserialize_string(const flatbuffers::String *str);
Halide::Type deserialize_type(const Halide::Serialize::Type *type);
@@ -66,6 +70,7 @@ class Deserializer {
Halide::Internal::PrefetchDirective deserialize_prefetch_directive(const Halide::Serialize::PrefetchDirective *prefetch_directive);
+ Halide::Internal::Split deserialize_split(const Halide::Serialize::Split *split);
// std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs);
// std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings);
diff --git a/apps/serdes/Serializer.cpp b/apps/serdes/Serializer.cpp
index d501233f2c58..a3f6262caddd 100644
--- a/apps/serdes/Serializer.cpp
+++ b/apps/serdes/Serializer.cpp
@@ -157,6 +157,42 @@ Halide::Serialize::NameMangling Serializer::serialize_name_mangling(const Halide
}
}
+Halide::Serialize::TailStrategy Serializer::serialize_tail_strategy(const Halide::TailStrategy &tail_strategy) {
+ switch (tail_strategy) {
+ case Halide::TailStrategy::RoundUp:
+ return Halide::Serialize::TailStrategy::TailStrategy_RoundUp;
+ case Halide::TailStrategy::GuardWithIf:
+ return Halide::Serialize::TailStrategy::TailStrategy_GuardWithIf;
+ case Halide::TailStrategy::PredicateLoads:
+ return Halide::Serialize::TailStrategy::TailStrategy_PredicateLoads;
+ case Halide::TailStrategy::PredicateStores:
+ return Halide::Serialize::TailStrategy::TailStrategy_PredicateStores;
+ case Halide::TailStrategy::ShiftInwards:
+ return Halide::Serialize::TailStrategy::TailStrategy_ShiftInwards;
+ case Halide::TailStrategy::Auto:
+ return Halide::Serialize::TailStrategy::TailStrategy_Auto;
+ default:
+ std::cerr << "Unsupported tail strategy\n";
+ exit(1);
+ }
+}
+
+Halide::Serialize::SplitType Serializer::serialize_split_type(const Halide::Internal::Split::SplitType &split_type) {
+ switch (split_type) {
+ case Halide::Internal::Split::SplitType::SplitVar:
+ return Halide::Serialize::SplitType::SplitType_SplitVar;
+ case Halide::Internal::Split::SplitType::RenameVar:
+ return Halide::Serialize::SplitType::SplitType_RenameVar;
+ case Halide::Internal::Split::SplitType::FuseVars:
+ return Halide::Serialize::SplitType::SplitType_FuseVars;
+ case Halide::Internal::Split::SplitType::PurifyRVar:
+ return Halide::Serialize::SplitType::SplitType_PurifyRVar;
+ default:
+ std::cerr << "Unsupported split type\n";
+ exit(1);
+ }
+}
+
flatbuffers::Offset Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) {
return builder.CreateString(str);
}
@@ -616,11 +652,11 @@ flatbuffers::Offset Serializer::serialize_function(flat
bool trace_realizations = function.is_tracing_realizations();
std::vector> trace_tags_serialized;
trace_tags_serialized.reserve(function.get_trace_tags().size());
- for (const auto& tag: function.get_trace_tags()) {
+ for (const auto &tag : function.get_trace_tags()) {
trace_tags_serialized.push_back(serialize_string(builder, tag));
}
bool frozen = function.frozen();
- auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim,
+ auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim,
args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized), debug_file_serialized,
extern_function_name_serialized, extern_mangling_serialized, extern_function_device_api_serialized, extern_proxy_expr_serialized.first,
extern_proxy_expr_serialized.second, trace_loads, trace_stores, trace_realizations, builder.CreateVector(trace_tags_serialized), frozen);
@@ -770,6 +806,16 @@ flatbuffers::Offset Serializer::serialize_
return Halide::Serialize::CreatePrefetchDirective(builder, name_serialized, at_serialized, from_serialized, offset_serialized.first, offset_serialized.second, strategy_serialized);
}
+flatbuffers::Offset Serializer::serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split) {
+ auto old_var_serialized = serialize_string(builder, split.old_var);
+ auto outer_serialized = serialize_string(builder, split.outer);
+ auto inner_serialized = serialize_string(builder, split.inner);
+ auto factor_serialized = serialize_expr(builder, split.factor);
+ auto tail_serialized = serialize_tail_strategy(split.tail);
+ auto inner_to_outer_serialized = serialize_split_type(split.split_type);
+ return Halide::Serialize::CreateSplit(builder, old_var_serialized, outer_serialized, inner_serialized, factor_serialized.first, factor_serialized.second, tail_serialized, inner_to_outer_serialized);
+}
+
// std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) {
// // instead of storing the function pointer or raw function address,
// // we store a pre-computed function index as the serialized format for WrapperRef
diff --git a/apps/serdes/Serializer.h b/apps/serdes/Serializer.h
index 496b895f6120..238e9920b556 100644
--- a/apps/serdes/Serializer.h
+++ b/apps/serdes/Serializer.h
@@ -33,6 +33,10 @@ class Serializer {
Halide::Serialize::NameMangling serialize_name_mangling(const Halide::NameMangling &name_mangling);
+ Halide::Serialize::TailStrategy serialize_tail_strategy(const Halide::TailStrategy &tail_strategy);
+
+ Halide::Serialize::SplitType serialize_split_type(const Halide::Internal::Split::SplitType &split_type);
+
flatbuffers::Offset serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str);
flatbuffers::Offset serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type);
@@ -67,6 +71,8 @@ class Serializer {
flatbuffers::Offset serialize_prefetch_directive(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::PrefetchDirective &prefetch_directive);
+ flatbuffers::Offset serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split);
+
// std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers);
// std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings);
diff --git a/apps/serdes/halide_ir.fbs b/apps/serdes/halide_ir.fbs
index 4b298a4bd566..3044b6722d3e 100644
--- a/apps/serdes/halide_ir.fbs
+++ b/apps/serdes/halide_ir.fbs
@@ -498,6 +498,72 @@ table Specialization {
failure_message: string;
}
+enum TailStrategy: ubyte {
+ RoundUp,
+ GuardWithIf,
+ PredicateLoads,
+ PredicateStores,
+ ShiftInwards,
+ Auto,
+}
+
+enum SplitType: ubyte {
+ SplitVar,
+ RenameVar,
+ FuseVars,
+ PurifyRVar,
+}
+
+table Split {
+ old_var: string;
+ outer: string;
+ inner: string;
+ factor: Expr;
+ tail: TailStrategy;
+ split_type: SplitType;
+}
+
+// table Dim {
+// var: string;
+// for_type: ForType;
+// device_api: DeviceAPI;
+// dim_type: DimType;
+// }
+
+// enum LoopAlignStrategy: ubyte {
+// AlignStart,
+// AlignEnd,
+// NoAlign,
+// Auto,
+// }
+
+// table FuseLoopLevel {
+// fuse_level: LoopLevel;
+// align_dimension_names: [string];
+// align_strategies: [LoopAlignStrategy];
+// }
+
+// table FusedPair {
+// func_1: string;
+// func_2: string;
+// stage_1: int32;
+// stage_2: int32;
+// var_name: string;
+// }
+
+// table StageSchedule {
+// rvars: [ReductionVariable];
+// splits: [Split];
+// dims: [Dim];
+// prefetches: [PrefetchDirective];
+// fuse_level: FuseLoopLevel;
+// fused_pairs: [FusedPair];
+// touched: bool = false;
+// allow_race_conditions: bool = false;
+// atomic: bool = false;
+// override_atomic_associativity_test: bool = false;
+// }
+
table Definition {
is_init: bool;
predicate: Expr;
diff --git a/src/Function.cpp b/src/Function.cpp
index 15dbb8fa2ce6..225189e4c46a 100644
--- a/src/Function.cpp
+++ b/src/Function.cpp
@@ -332,11 +332,11 @@ Function::Function(const std::vector &required_types, int required_dims, c
}
Function::Function(const std::string &name, const std::string &origin_name, const std::vector &output_types,
- const std::vector &required_types, int required_dims, const std::vector &args,
- const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates,
- const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling,
- const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations,
- const std::vector &trace_tags, bool frozen) {
+ const std::vector &required_types, int required_dims, const std::vector &args,
+ const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates,
+ const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling,
+ const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations,
+ const std::vector &trace_tags, bool frozen) {
contents.strong = new FunctionGroup;
contents.strong->members.resize(1);
contents->name = name;