Skip to content

Commit

Permalink
prefetch directive
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 20, 2023
1 parent 63ac57a commit ca4ac94
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 5 deletions.
34 changes: 31 additions & 3 deletions apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ Halide::Internal::VectorReduce::Operator Deserializer::deserialize_vector_reduce
}
}

Halide::PrefetchBoundStrategy Deserializer::deserialize_prefetch_bound_strategy(const Halide::Serialize::PrefetchBoundStrategy prefetch_bound_strategy) {
switch (prefetch_bound_strategy) {
case Halide::Serialize::PrefetchBoundStrategy::PrefetchBoundStrategy_Clamp:
return Halide::PrefetchBoundStrategy::Clamp;
case Halide::Serialize::PrefetchBoundStrategy::PrefetchBoundStrategy_GuardWithIf:
return Halide::PrefetchBoundStrategy::GuardWithIf;
case Halide::Serialize::PrefetchBoundStrategy::PrefetchBoundStrategy_NonFaulting:
return Halide::PrefetchBoundStrategy::NonFaulting;
default:
std::cerr << "unknown prefetch bound strategy " << prefetch_bound_strategy << "\n";
exit(1);
}
}

Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) {
using Halide::Serialize::TypeCode;
int bits = type->bits();
Expand Down Expand Up @@ -307,11 +321,10 @@ Halide::Internal::Stmt Deserializer::deserialize_stmt(uint8_t type_code, const v
for (const auto &bound : *prefetch_stmt->bounds()) {
bounds.push_back(deserialize_range(bound));
}
auto prefetch = deserialize_prefetch_directive(prefetch_stmt->prefetch());
auto condition = deserialize_expr(prefetch_stmt->condition_type(), prefetch_stmt->condition());
auto body = deserialize_stmt(prefetch_stmt->body_type(), prefetch_stmt->body());
return Halide::Internal::Prefetch::make(name, types, bounds,
Halide::Internal::PrefetchDirective(),
condition, body);
return Halide::Internal::Prefetch::make(name, types, bounds, prefetch, condition, body);
}
case Halide::Serialize::Stmt_Acquire: {
const Halide::Serialize::Acquire *acquire_stmt = (const Halide::Serialize::Acquire *)stmt;
Expand Down Expand Up @@ -685,6 +698,21 @@ Halide::Internal::ModulusRemainder Deserializer::deserialize_modulus_remainder(c
return Halide::Internal::ModulusRemainder(modulus_remainder->modulus(), modulus_remainder->remainder());
}

Halide::Internal::PrefetchDirective Deserializer::deserialize_prefetch_directive(const Halide::Serialize::PrefetchDirective *prefetch_directive) {
auto name = deserialize_string(prefetch_directive->name());
auto at = deserialize_string(prefetch_directive->at());
auto from = deserialize_string(prefetch_directive->from());
auto offset = deserialize_expr(prefetch_directive->offset_type(), prefetch_directive->offset());
auto strategy = deserialize_prefetch_bound_strategy(prefetch_directive->strategy());
auto hl_prefetch_directive = Halide::Internal::PrefetchDirective();
hl_prefetch_directive.name = name;
hl_prefetch_directive.at = at;
hl_prefetch_directive.from = from;
hl_prefetch_directive.offset = offset;
hl_prefetch_directive.strategy = strategy;
return hl_prefetch_directive;
}

// TODO: will need to serialize a reverse table of map<address, func_name> to
// later reconstruct a map of <name, func_ptr> find out which function ptrs to use here
// std::map<std::string, Halide::Internal::FunctionPtr> Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs) {
Expand Down
4 changes: 4 additions & 0 deletions apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class Deserializer {

Halide::Internal::VectorReduce::Operator deserialize_vector_reduce_op(const Halide::Serialize::VectorReduceOp vector_reduce_op);

Halide::PrefetchBoundStrategy deserialize_prefetch_bound_strategy(const Halide::Serialize::PrefetchBoundStrategy prefetch_bound_strategy);

std::string deserialize_string(const flatbuffers::String *str);

Halide::Type deserialize_type(const Halide::Serialize::Type *type);
Expand Down Expand Up @@ -60,6 +62,8 @@ class Deserializer {

Halide::Internal::ModulusRemainder deserialize_modulus_remainder(const Halide::Serialize::ModulusRemainder *modulus_remainder);

Halide::Internal::PrefetchDirective deserialize_prefetch_directive(const Halide::Serialize::PrefetchDirective *prefetch_directive);

// std::map<std::string, Halide::Internal::FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs);

// std::map<std::string, int32_t> deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings);
Expand Down
26 changes: 25 additions & 1 deletion apps/serdes/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ Halide::Serialize::VectorReduceOp Serializer::serialize_vector_reduce_op(const H
}
}

Halide::Serialize::PrefetchBoundStrategy Serializer::serialize_prefetch_bound_strategy(const Halide::PrefetchBoundStrategy &prefetch_bound_strategy) {
switch (prefetch_bound_strategy) {
case Halide::PrefetchBoundStrategy::Clamp:
return Halide::Serialize::PrefetchBoundStrategy::PrefetchBoundStrategy_Clamp;
case Halide::PrefetchBoundStrategy::GuardWithIf:
return Halide::Serialize::PrefetchBoundStrategy::PrefetchBoundStrategy_GuardWithIf;
case Halide::PrefetchBoundStrategy::NonFaulting:
return Halide::Serialize::PrefetchBoundStrategy::PrefetchBoundStrategy_NonFaulting;
default:
std::cerr << "Unsupported prefetch bound strategy\n";
exit(1);
}
}

flatbuffers::Offset<flatbuffers::String> Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) {
return builder.CreateString(str);
}
Expand Down Expand Up @@ -293,9 +307,10 @@ std::pair<Halide::Serialize::Stmt, flatbuffers::Offset<void>> Serializer::serial
for (const auto &bound : bounds) {
bounds_serialized.push_back(serialize_range(builder, bound));
}
auto prefetch_serialized = serialize_prefetch_directive(builder, prefetch_stmt->prefetch);
auto condition_serialized = serialize_expr(builder, prefetch_stmt->condition);
auto body_serialized = serialize_stmt(builder, prefetch_stmt->body);
return std::make_pair(Halide::Serialize::Stmt::Stmt_Prefetch, Halide::Serialize::CreatePrefetch(builder, name_serialized, types_vector, builder.CreateVector(bounds_serialized), condition_serialized.first, condition_serialized.second, body_serialized.first, body_serialized.second).Union());
return std::make_pair(Halide::Serialize::Stmt::Stmt_Prefetch, Halide::Serialize::CreatePrefetch(builder, name_serialized, types_vector, builder.CreateVector(bounds_serialized), prefetch_serialized, condition_serialized.first, condition_serialized.second, body_serialized.first, body_serialized.second).Union());
}
case Halide::Internal::IRNodeType::Acquire: {
auto acquire_stmt = stmt.as<Halide::Internal::Acquire>();
Expand Down Expand Up @@ -714,6 +729,15 @@ flatbuffers::Offset<Halide::Serialize::ModulusRemainder> Serializer::serialize_m
return Halide::Serialize::CreateModulusRemainder(builder, modulus_remainder.modulus, modulus_remainder.remainder);
}

flatbuffers::Offset<Halide::Serialize::PrefetchDirective> Serializer::serialize_prefetch_directive(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::PrefetchDirective &prefetch_directive) {
auto name_serialized = serialize_string(builder, prefetch_directive.name);
auto at_serialized = serialize_string(builder, prefetch_directive.at);
auto from_serialized = serialize_string(builder, prefetch_directive.from);
auto offset_serialized = serialize_expr(builder, prefetch_directive.offset);
auto strategy_serialized = serialize_prefetch_bound_strategy(prefetch_directive.strategy);
return Halide::Serialize::CreatePrefetchDirective(builder, name_serialized, at_serialized, from_serialized, offset_serialized.first, offset_serialized.second, strategy_serialized);
}

// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers) {
// // instead of storing the function pointer or raw function address,
// // we store a pre-computed function index as the serialized format for WrapperRef
Expand Down
4 changes: 4 additions & 0 deletions apps/serdes/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Serializer {

Halide::Serialize::VectorReduceOp serialize_vector_reduce_op(const Halide::Internal::VectorReduce::Operator &vector_reduce_op);

Halide::Serialize::PrefetchBoundStrategy serialize_prefetch_bound_strategy(const Halide::PrefetchBoundStrategy &prefetch_bound_strategy);

flatbuffers::Offset<flatbuffers::String> serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str);

flatbuffers::Offset<Halide::Serialize::Type> serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type);
Expand Down Expand Up @@ -61,6 +63,8 @@ class Serializer {

flatbuffers::Offset<Halide::Serialize::ModulusRemainder> serialize_modulus_remainder(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::ModulusRemainder &modulus_remainder);

flatbuffers::Offset<Halide::Serialize::PrefetchDirective> serialize_prefetch_directive(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::PrefetchDirective &prefetch_directive);

// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers);

// std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings);
Expand Down
17 changes: 16 additions & 1 deletion apps/serdes/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,26 @@ table Evaluate {
value: Expr;
}

enum PrefetchBoundStrategy: ubyte {
Clamp,
GuardWithIf,
NonFaulting,
}

table PrefetchDirective {
name: string;
at: string;
from: string;
offset: Expr;
strategy: PrefetchBoundStrategy;
// param: Parameter;
}

table Prefetch {
name: string;
types: [Type];
bounds: [Range];
// prefetch: PrefetchDirective;
prefetch: PrefetchDirective;
condition: Expr;
body: Stmt;
}
Expand Down

0 comments on commit ca4ac94

Please sign in to comment.