Skip to content

Commit

Permalink
Specialization, Definition
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 7, 2023
1 parent fe79618 commit 478970e
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 32 deletions.
77 changes: 55 additions & 22 deletions apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@ std::string Deserializer::deserialize_string(const flatbuffers::String *str) {

Halide::MemoryType Deserializer::deserialize_memory_type(const Halide::Serialize::MemoryType memory_type) {
switch (memory_type) {
case Halide::Serialize::MemoryType::MemoryType_Auto:
return Halide::MemoryType::Auto;
case Halide::Serialize::MemoryType::MemoryType_Heap:
return Halide::MemoryType::Heap;
case Halide::Serialize::MemoryType::MemoryType_Stack:
return Halide::MemoryType::Stack;
case Halide::Serialize::MemoryType::MemoryType_Register:
return Halide::MemoryType::Register;
case Halide::Serialize::MemoryType::MemoryType_GPUShared:
return Halide::MemoryType::GPUShared;
case Halide::Serialize::MemoryType::MemoryType_GPUTexture:
return Halide::MemoryType::GPUTexture;
case Halide::Serialize::MemoryType::MemoryType_LockedCache:
return Halide::MemoryType::LockedCache;
case Halide::Serialize::MemoryType::MemoryType_VTCM:
return Halide::MemoryType::VTCM;
case Halide::Serialize::MemoryType::MemoryType_AMXTile:
return Halide::MemoryType::AMXTile;
default:
std::cerr << "unknown memory type " << memory_type << "\n";
return Halide::MemoryType::Auto;
case Halide::Serialize::MemoryType::MemoryType_Auto:
return Halide::MemoryType::Auto;
case Halide::Serialize::MemoryType::MemoryType_Heap:
return Halide::MemoryType::Heap;
case Halide::Serialize::MemoryType::MemoryType_Stack:
return Halide::MemoryType::Stack;
case Halide::Serialize::MemoryType::MemoryType_Register:
return Halide::MemoryType::Register;
case Halide::Serialize::MemoryType::MemoryType_GPUShared:
return Halide::MemoryType::GPUShared;
case Halide::Serialize::MemoryType::MemoryType_GPUTexture:
return Halide::MemoryType::GPUTexture;
case Halide::Serialize::MemoryType::MemoryType_LockedCache:
return Halide::MemoryType::LockedCache;
case Halide::Serialize::MemoryType::MemoryType_VTCM:
return Halide::MemoryType::VTCM;
case Halide::Serialize::MemoryType::MemoryType_AMXTile:
return Halide::MemoryType::AMXTile;
default:
std::cerr << "unknown memory type " << memory_type << "\n";
return Halide::MemoryType::Auto;
}
}

Expand Down Expand Up @@ -83,7 +83,15 @@ Halide::Internal::Function Deserializer::deserialize_function(const Halide::Seri

auto func_schedule = deserialize_func_schedule(function->func_schedule());

return Halide::Internal::Function(name, origin_name, output_types, required_types, required_dim, args, func_schedule);
auto init_def = deserialize_definition(function->init_def());

std::vector<Halide::Internal::Definition> updates;
for (const auto &update : *function->updates()) {
updates.push_back(deserialize_definition(update));
}

return Halide::Internal::Function(name, origin_name, output_types, required_types,
required_dim, args, func_schedule, init_def, updates);
}

Halide::Internal::Stmt Deserializer::deserialize_stmt(uint8_t type_code, const void *stmt) {
Expand Down Expand Up @@ -589,6 +597,31 @@ Halide::Internal::FuncSchedule Deserializer::deserialize_func_schedule(const Hal
return hl_func_schedule;
}

Halide::Internal::Specialization Deserializer::deserialize_specialization(const Halide::Serialize::Specialization *specialization) {
auto condition = deserialize_expr(specialization->condition_type(), specialization->condition());
auto defintion = deserialize_definition(specialization->definition());
auto failure_message = deserialize_string(specialization->failure_message());
Halide::Internal::Specialization hl_specialization;
hl_specialization.condition = condition;
hl_specialization.definition = defintion;
hl_specialization.failure_message = failure_message;
return hl_specialization;
}

Halide::Internal::Definition Deserializer::deserialize_definition(const Halide::Serialize::Definition *definition) {
auto is_init = definition->is_init();
auto predicate = deserialize_expr(definition->predicate_type(), definition->predicate());
auto args = deserialize_expr_vector(definition->args_type(), definition->args());
auto values = deserialize_expr_vector(definition->values_type(), definition->values());

std::vector<Halide::Internal::Specialization> specializations;
for (const auto &specialization : *definition->specializations()) {
specializations.push_back(deserialize_specialization(specialization));
}
auto source_location = deserialize_string(definition->source_location());
return Halide::Internal::Definition(is_init, predicate, args, values, Halide::Internal::StageSchedule(), specializations, source_location);
}

Halide::Pipeline Deserializer::deserialize(const std::string &filename) {
// unpack binary file
std::ifstream in(filename, std::ios::binary | std::ios::in);
Expand Down
6 changes: 5 additions & 1 deletion apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class Deserializer {

private:
// helper functions to deserialize each type of object
Halide::MemoryType deserialize_memory_type(const Halide::Serialize::MemoryType memory_type);

std::string deserialize_string(const flatbuffers::String *str);

Halide::Type deserialize_type(const Halide::Serialize::Type *type);
Expand All @@ -37,7 +39,9 @@ class Deserializer {

Halide::Internal::FuncSchedule deserialize_func_schedule(const Halide::Serialize::FuncSchedule *func_schedule);

Halide::MemoryType deserialize_memory_type(const Halide::Serialize::MemoryType memory_type);
Halide::Internal::Specialization deserialize_specialization(const Halide::Serialize::Specialization *specialization);

Halide::Internal::Definition deserialize_definition(const Halide::Serialize::Definition *definition);
};

#endif
51 changes: 46 additions & 5 deletions apps/serdes/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,10 +700,13 @@ flatbuffers::Offset<Halide::Serialize::Func> Serializer::serialize_function(flat
}
auto args_vector = builder.CreateVector(args_serialized);

auto func_schedule = function.schedule();
auto func_schedule_serialized = serialize_func_schedule(builder, func_schedule);

auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim, args_vector, func_schedule_serialized);
auto func_schedule_serialized = serialize_func_schedule(builder, function.schedule());
auto init_def_serialized = serialize_definition(builder, function.definition());
std::vector<flatbuffers::Offset<Halide::Serialize::Definition>> updates_serialized;
for (const auto &update : function.updates()) {
updates_serialized.push_back(serialize_definition(builder, update));
}
auto func = Halide::Serialize::CreateFunc(builder, name_serialized, origin_name_serialized, output_types_vector, required_types_vector, required_dim, args_vector, func_schedule_serialized, init_def_serialized, builder.CreateVector(updates_serialized));
return func;
}

Expand Down Expand Up @@ -818,7 +821,45 @@ flatbuffers::Offset<Halide::Serialize::FuncSchedule> Serializer::serialize_func_
auto async = func_schedule.async();
auto memoize_eviction_key = func_schedule.memoize_eviction_key();
auto memoize_eviction_key_serialized = serialize_expr(builder, memoize_eviction_key);
return Halide::Serialize::CreateFuncSchedule(builder, store_level_serialized, compute_level_serialized, builder.CreateVector(storage_dims_serialized), builder.CreateVector(bounds_serialized), builder.CreateVector(estimates_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
return Halide::Serialize::CreateFuncSchedule(builder, store_level_serialized, compute_level_serialized, builder.CreateVector(storage_dims_serialized), builder.CreateVector(bounds_serialized),
builder.CreateVector(estimates_serialized), memory_type, memoized, async, memoize_eviction_key_serialized.first, memoize_eviction_key_serialized.second);
}

flatbuffers::Offset<Halide::Serialize::Specialization> Serializer::serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization) {
auto condition_serialized = serialize_expr(builder, specialization.condition);
auto definition_serialized = serialize_definition(builder, specialization.definition);
auto failure_message_serialized = serialize_string(builder, specialization.failure_message);
return Halide::Serialize::CreateSpecialization(builder, condition_serialized.first, condition_serialized.second, definition_serialized, failure_message_serialized);
}

flatbuffers::Offset<Halide::Serialize::Definition> Serializer::serialize_definition(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Definition &definition) {
auto is_init = definition.is_init();
auto predicate_serialized = serialize_expr(builder, definition.predicate());
std::vector<uint8_t> values_types;
values_types.reserve(definition.values().size());
std::vector<flatbuffers::Offset<void>> values_serialized;
values_serialized.reserve(definition.values().size());
for (const auto &value : definition.values()) {
auto value_serialized = serialize_expr(builder, value);
values_types.push_back(value_serialized.first);
values_serialized.push_back(value_serialized.second);
}
std::vector<uint8_t> args_types;
args_types.reserve(definition.args().size());
std::vector<flatbuffers::Offset<void>> args_serialized;
args_serialized.reserve(definition.args().size());
for (const auto &arg : definition.args()) {
auto arg_serialized = serialize_expr(builder, arg);
args_types.push_back(arg_serialized.first);
args_serialized.push_back(arg_serialized.second);
}
std::vector<flatbuffers::Offset<Halide::Serialize::Specialization>> specializations_serialized;
for (const auto &specialization : definition.specializations()) {
specializations_serialized.push_back(serialize_specialization(builder, specialization));
}
auto source_location_serialized = serialize_string(builder, definition.source_location());
return Halide::Serialize::CreateDefinition(builder, is_init, predicate_serialized.first, predicate_serialized.second, builder.CreateVector(values_types),
builder.CreateVector(values_serialized), builder.CreateVector(args_types), builder.CreateVector(args_serialized), builder.CreateVector(specializations_serialized), source_location_serialized);
}

void Serializer::serialize(const Halide::Pipeline &pipeline, const std::string &filename) {
Expand Down
4 changes: 4 additions & 0 deletions apps/serdes/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class Serializer {
flatbuffers::Offset<Halide::Serialize::LoopLevel> serialize_loop_level(flatbuffers::FlatBufferBuilder &builder, const Halide::LoopLevel &loop_level);

flatbuffers::Offset<Halide::Serialize::FuncSchedule> serialize_func_schedule(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::FuncSchedule &func_schedule);

flatbuffers::Offset<Halide::Serialize::Specialization> serialize_specialization(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Specialization &specialization);

flatbuffers::Offset<Halide::Serialize::Definition> serialize_definition(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Definition &definition);
};

#endif
20 changes: 18 additions & 2 deletions apps/serdes/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,22 @@ table FuncSchedule {
memoize_eviction_key: Expr;
}

table Specialization {
condition: Expr;
definition: Definition;
failure_message: string;
}

table Definition {
is_init: bool;
predicate: Expr;
values: [Expr];
args: [Expr];
// stage_schedule: StageSchedule;
specializations: [Specialization];
source_location: string;
}

// Halide::internal::Function
table Func {
name: string;
Expand All @@ -444,8 +460,8 @@ table Func {
required_dims: int32;
args: [string];
func_schedule: FuncSchedule;
// init_def: Definition;
// updates: [Definition];
init_def: Definition;
updates: [Definition];
// debug_file: string;
// output_buffers: [Parameter];
// extern_arguments: [ExternFuncArgument];
Expand Down
12 changes: 12 additions & 0 deletions src/Definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ Definition::Definition(const std::vector<Expr> &args, const std::vector<Expr> &v
}
}

Definition::Definition(bool is_init, const Expr &predicate, const std::vector<Expr> &args, const std::vector<Expr> &values,
const StageSchedule &schedule, const std::vector<Specialization> &specializations, const std::string &source_location)
: contents(new DefinitionContents) {
contents->is_init = is_init;
contents->values = values;
contents->args = args;
contents->predicate = predicate;
contents->stage_schedule = schedule;
contents->specializations = specializations;
contents->source_location = source_location;
}

Definition Definition::get_copy() const {
internal_assert(contents.defined()) << "Cannot copy undefined Definition\n";

Expand Down
4 changes: 4 additions & 0 deletions src/Definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class Definition {
Definition(const std::vector<Expr> &args, const std::vector<Expr> &values,
const ReductionDomain &rdom, bool is_init);

/** Construct a Definition with deserialized data. */
Definition(bool is_init, const Expr &predicate, const std::vector<Expr> &args, const std::vector<Expr> &values,
const StageSchedule &schedule, const std::vector<Specialization> &specializations, const std::string &source_location);

/** Construct an undefined Definition object. */
Definition();

Expand Down
4 changes: 3 additions & 1 deletion src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ Function::Function(const std::vector<Type> &required_types, int required_dims, c

Function::Function(const std::string &name, const std::string &origin_name, const std::vector<Halide::Type> &output_types,
const std::vector<Halide::Type> &required_types, int required_dims, const std::vector<std::string> &args,
const FuncSchedule &func_schedule) {
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates) {
contents.strong = new FunctionGroup;
contents.strong->members.resize(1);
contents->name = name;
Expand All @@ -343,6 +343,8 @@ Function::Function(const std::string &name, const std::string &origin_name, cons
contents->required_dims = required_dims;
contents->args = args;
contents->func_schedule = func_schedule;
contents->init_def = init_def;
contents->updates = updates;
}

namespace {
Expand Down
2 changes: 1 addition & 1 deletion src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class Function {
/** Construct a function from deserializing */
explicit Function(const std::string &name, const std::string &origin_name, const std::vector<Halide::Type> &output_types,
const std::vector<Halide::Type> &required_types, int required_dims, const std::vector<std::string> &args,
const FuncSchedule &func_schedule);
const FuncSchedule &func_schedule, const Definition &init_def, const std::vector<Definition> &updates);

/** Get a handle on the halide function contents that this Function
* represents. */
Expand Down

0 comments on commit 478970e

Please sign in to comment.