From 88f4a5d35936965013a5f9491e7006894edcb218 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 20 Jun 2023 16:15:47 -0700 Subject: [PATCH] split --- apps/serdes/Deserializer.cpp | 53 +++++++++++++++++++++++++++++ apps/serdes/Deserializer.h | 5 +++ apps/serdes/Serializer.cpp | 50 +++++++++++++++++++++++++-- apps/serdes/Serializer.h | 6 ++++ apps/serdes/halide_ir.fbs | 66 ++++++++++++++++++++++++++++++++++++ src/Function.cpp | 10 +++--- 6 files changed, 183 insertions(+), 7 deletions(-) diff --git a/apps/serdes/Deserializer.cpp b/apps/serdes/Deserializer.cpp index b297f038d7f8..48694b0dbc27 100644 --- a/apps/serdes/Deserializer.cpp +++ b/apps/serdes/Deserializer.cpp @@ -158,6 +158,42 @@ Halide::NameMangling Deserializer::deserialize_name_mangling(const Halide::Seria } } +Halide::TailStrategy Deserializer::deserialize_tail_strategy(const Halide::Serialize::TailStrategy tail_strategy) { + switch (tail_strategy) { + case Halide::Serialize::TailStrategy::TailStrategy_RoundUp: + return Halide::TailStrategy::RoundUp; + case Halide::Serialize::TailStrategy::TailStrategy_GuardWithIf: + return Halide::TailStrategy::GuardWithIf; + case Halide::Serialize::TailStrategy::TailStrategy_PredicateLoads: + return Halide::TailStrategy::PredicateLoads; + case Halide::Serialize::TailStrategy::TailStrategy_PredicateStores: + return Halide::TailStrategy::PredicateStores; + case Halide::Serialize::TailStrategy::TailStrategy_ShiftInwards: + return Halide::TailStrategy::ShiftInwards; + case Halide::Serialize::TailStrategy::TailStrategy_Auto: + return Halide::TailStrategy::Auto; + default: + std::cerr << "unknown tail strategy " << tail_strategy << "\n"; + exit(1); + } +} + +Halide::Internal::Split::SplitType Deserializer::deserialize_split_type(const Halide::Serialize::SplitType split_type) { + switch (split_type) { + case Halide::Serialize::SplitType::SplitType_SplitVar: + return Halide::Internal::Split::SplitType::SplitVar; + case Halide::Serialize::SplitType::SplitType_RenameVar: + return Halide::Internal::Split::SplitType::RenameVar; + case Halide::Serialize::SplitType::SplitType_FuseVars: + return Halide::Internal::Split::SplitType::FuseVars; + case Halide::Serialize::SplitType::SplitType_PurifyRVar: + return Halide::Internal::Split::SplitType::PurifyRVar; + default: + std::cerr << "unknown split type " << split_type << "\n"; + exit(1); + } +} + Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) { using Halide::Serialize::TypeCode; int bits = type->bits(); @@ -743,6 +779,23 @@ Halide::Internal::PrefetchDirective Deserializer::deserialize_prefetch_directive return hl_prefetch_directive; } +Halide::Internal::Split Deserializer::deserialize_split(const Halide::Serialize::Split *split) { + auto old_var = deserialize_string(split->old_var()); + auto outer = deserialize_string(split->outer()); + auto inner = deserialize_string(split->inner()); + auto factor = deserialize_expr(split->factor_type(), split->factor()); + auto tail = deserialize_tail_strategy(split->tail()); + auto split_type = deserialize_split_type(split->split_type()); + auto hl_split = Halide::Internal::Split(); + hl_split.old_var = old_var; + hl_split.outer = outer; + hl_split.inner = inner; + hl_split.factor = factor; + hl_split.tail = tail; + hl_split.split_type = split_type; + return hl_split; +} + // TODO: will need to serialize a reverse table of map to // later reconstruct a map of find out which function ptrs to use here // std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) { diff --git a/apps/serdes/Deserializer.h b/apps/serdes/Deserializer.h index 2d37d7ba15c7..2c509b8295a7 100644 --- a/apps/serdes/Deserializer.h +++ b/apps/serdes/Deserializer.h @@ -32,6 +32,10 @@ class Deserializer { Halide::NameMangling deserialize_name_mangling(const Halide::Serialize::NameMangling name_mangling); + Halide::TailStrategy deserialize_tail_strategy(const Halide::Serialize::TailStrategy tail_strategy); + + Halide::Internal::Split::SplitType deserialize_split_type(const Halide::Serialize::SplitType split_type); + std::string deserialize_string(const flatbuffers::String *str); Halide::Type deserialize_type(const Halide::Serialize::Type *type); @@ -66,6 +70,7 @@ class Deserializer { Halide::Internal::PrefetchDirective deserialize_prefetch_directive(const Halide::Serialize::PrefetchDirective *prefetch_directive); + Halide::Internal::Split deserialize_split(const Halide::Serialize::Split *split); // std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs); // std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings); diff --git a/apps/serdes/Serializer.cpp b/apps/serdes/Serializer.cpp index d501233f2c58..a3f6262caddd 100644 --- a/apps/serdes/Serializer.cpp +++ b/apps/serdes/Serializer.cpp @@ -157,6 +157,42 @@ Halide::Serialize::NameMangling Serializer::serialize_name_mangling(const Halide } } +Halide::Serialize::TailStrategy Serializer::serialize_tail_strategy(const Halide::TailStrategy &tail_strategy) { + switch (tail_strategy) { + case Halide::TailStrategy::RoundUp: + return Halide::Serialize::TailStrategy::TailStrategy_RoundUp; + case Halide::TailStrategy::GuardWithIf: + return Halide::Serialize::TailStrategy::TailStrategy_GuardWithIf; + case Halide::TailStrategy::PredicateLoads: + return Halide::Serialize::TailStrategy::TailStrategy_PredicateLoads; + case Halide::TailStrategy::PredicateStores: + return Halide::Serialize::TailStrategy::TailStrategy_PredicateStores; + case Halide::TailStrategy::ShiftInwards: + return Halide::Serialize::TailStrategy::TailStrategy_ShiftInwards; + case Halide::TailStrategy::Auto: + return Halide::Serialize::TailStrategy::TailStrategy_Auto; + default: + std::cerr << "Unsupported tail strategy\n"; + exit(1); + } +} + +Halide::Serialize::SplitType Serializer::serialize_split_type(const Halide::Internal::Split::SplitType &split_type) { + switch (split_type) { + case Halide::Internal::Split::SplitType::SplitVar: + return Halide::Serialize::SplitType::SplitType_SplitVar; + case Halide::Internal::Split::SplitType::RenameVar: + return Halide::Serialize::SplitType::SplitType_RenameVar; + case Halide::Internal::Split::SplitType::FuseVars: + return Halide::Serialize::SplitType::SplitType_FuseVars; + case Halide::Internal::Split::SplitType::PurifyRVar: + return Halide::Serialize::SplitType::SplitType_PurifyRVar; + default: + std::cerr << "Unsupported split type\n"; + exit(1); + } +} + flatbuffers::Offset Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) { return builder.CreateString(str); } @@ -616,11 +652,11 @@ flatbuffers::Offset Serializer::serialize_function(flat bool trace_realizations = function.is_tracing_realizations(); std::vector> trace_tags_serialized; trace_tags_serialized.reserve(function.get_trace_tags().size()); - for (const auto& tag: function.get_trace_tags()) { + for (const auto &tag : function.get_trace_tags()) { trace_tags_serialized.push_back(serialize_string(builder, tag)); } bool frozen = function.frozen(); - auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim, + auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim, args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized), debug_file_serialized, extern_function_name_serialized, extern_mangling_serialized, extern_function_device_api_serialized, extern_proxy_expr_serialized.first, extern_proxy_expr_serialized.second, trace_loads, trace_stores, trace_realizations, builder.CreateVector(trace_tags_serialized), frozen); @@ -770,6 +806,16 @@ flatbuffers::Offset Serializer::serialize_ return Halide::Serialize::CreatePrefetchDirective(builder, name_serialized, at_serialized, from_serialized, offset_serialized.first, offset_serialized.second, strategy_serialized); } +flatbuffers::Offset Serializer::serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split) { + auto old_var_serialized = serialize_string(builder, split.old_var); + auto outer_serialized = serialize_string(builder, split.outer); + auto inner_serialized = serialize_string(builder, split.inner); + auto factor_serialized = serialize_expr(builder, split.factor); + auto tail_serialized = serialize_tail_strategy(split.tail); + auto inner_to_outer_serialized = serialize_split_type(split.split_type); + return Halide::Serialize::CreateSplit(builder, old_var_serialized, outer_serialized, inner_serialized, factor_serialized.first, factor_serialized.second, tail_serialized, inner_to_outer_serialized); +} + // std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) { // // instead of storing the function pointer or raw function address, // // we store a pre-computed function index as the serialized format for WrapperRef diff --git a/apps/serdes/Serializer.h b/apps/serdes/Serializer.h index 496b895f6120..238e9920b556 100644 --- a/apps/serdes/Serializer.h +++ b/apps/serdes/Serializer.h @@ -33,6 +33,10 @@ class Serializer { Halide::Serialize::NameMangling serialize_name_mangling(const Halide::NameMangling &name_mangling); + Halide::Serialize::TailStrategy serialize_tail_strategy(const Halide::TailStrategy &tail_strategy); + + Halide::Serialize::SplitType serialize_split_type(const Halide::Internal::Split::SplitType &split_type); + flatbuffers::Offset serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str); flatbuffers::Offset serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type); @@ -67,6 +71,8 @@ class Serializer { flatbuffers::Offset serialize_prefetch_directive(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::PrefetchDirective &prefetch_directive); + flatbuffers::Offset serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split); + // std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers); // std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings); diff --git a/apps/serdes/halide_ir.fbs b/apps/serdes/halide_ir.fbs index 4b298a4bd566..3044b6722d3e 100644 --- a/apps/serdes/halide_ir.fbs +++ b/apps/serdes/halide_ir.fbs @@ -498,6 +498,72 @@ table Specialization { failure_message: string; } +enum TailStrategy: ubyte { + RoundUp, + GuardWithIf, + PredicateLoads, + PredicateStores, + ShiftInwards, + Auto, +} + +enum SplitType: ubyte { + SplitVar, + RenameVar, + FuseVars, + PurifyRVar, +} + +table Split { + old_var: string; + outer: string; + inner: string; + factor: Expr; + tail: TailStrategy; + split_type: SplitType; +} + +// table Dim { +// var: string; +// for_type: ForType; +// device_api: DeviceAPI; +// dim_type: DimType; +// } + +// enum LoopAlignStrategy: ubyte { +// AlignStart, +// AlignEnd, +// NoAlign, +// Auto, +// } + +// table FuseLoopLevel { +// fuse_level: LoopLevel; +// align_dimension_names: [string]; +// align_strategies: [LoopAlignStrategy]; +// } + +// table FusedPair { +// func_1: string; +// func_2: string; +// stage_1: int32; +// stage_2: int32; +// var_name: string; +// } + +// table StageSchedule { +// rvars: [ReductionVariable]; +// splits: [Split]; +// dims: [Dim]; +// prefetches: [PrefetchDirective]; +// fuse_level: FuseLoopLevel; +// fused_pairs: [FusedPair]; +// touched: bool = false; +// allow_race_conditions: bool = false; +// atomic: bool = false; +// override_atomic_associativity_test: bool = false; +// } + table Definition { is_init: bool; predicate: Expr; diff --git a/src/Function.cpp b/src/Function.cpp index 15dbb8fa2ce6..225189e4c46a 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -332,11 +332,11 @@ Function::Function(const std::vector &required_types, int required_dims, c } Function::Function(const std::string &name, const std::string &origin_name, const std::vector &output_types, - const std::vector &required_types, int required_dims, const std::vector &args, - const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates, - const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling, - const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations, - const std::vector &trace_tags, bool frozen) { + const std::vector &required_types, int required_dims, const std::vector &args, + const FuncSchedule &func_schedule, const Definition &init_def, const std::vector &updates, + const std::string &debug_file, const std::string &extern_function_name, const NameMangling name_mangling, + const DeviceAPI device_api, const Expr &extern_proxy_expr, bool trace_loads, bool trace_stores, bool trace_realizations, + const std::vector &trace_tags, bool frozen) { contents.strong = new FunctionGroup; contents.strong->members.resize(1); contents->name = name;