Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set up graph_query testing infrastructure for task graph #325

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 18 additions & 25 deletions include/recorders.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,58 +50,51 @@ using reduction_list = std::vector<reduction_record>;

template <typename IdType>
struct dependency_record {
const IdType node;
const dependency_kind kind;
const dependency_origin origin;
IdType predecessor;
IdType successor;
dependency_kind kind;
dependency_origin origin;

dependency_record(const IdType predecessor, const IdType successor, const dependency_kind kind, const dependency_origin origin)
: predecessor(predecessor), successor(successor), kind(kind), origin(origin) {}
};

// Task recording

using task_dependency_list = std::vector<dependency_record<task_id>>;
using task_dependency_record = dependency_record<task_id>;

// TODO: Switch to hierarchy like for CDAG/IDAG
struct task_record {
task_record(const task& tsk, const buffer_name_map& get_buffer_debug_name);

task_id tid;
task_id id;
std::string debug_name;
collective_group_id cgid;
task_type type;
task_geometry geometry;
reduction_list reductions;
access_list accesses;
detail::side_effect_map side_effect_map;
task_dependency_list dependencies;
};

class task_recorder {
public:
void record(task_record&& record) { m_recorded_tasks.push_back(std::move(record)); }
void record(std::unique_ptr<task_record> record) { m_recorded_tasks.push_back(std::move(record)); }

const std::vector<task_record>& get_tasks() const { return m_recorded_tasks; }
void record_dependency(const task_dependency_record& dependency) { m_recorded_dependencies.push_back(dependency); }

const task_record& get_task(const task_id tid) const {
const auto it = std::find_if(m_recorded_tasks.begin(), m_recorded_tasks.end(), [tid](const task_record& rec) { return rec.tid == tid; });
assert(it != m_recorded_tasks.end());
return *it;
}
const std::vector<std::unique_ptr<task_record>>& get_graph_nodes() const { return m_recorded_tasks; }

const std::vector<task_dependency_record>& get_dependencies() const { return m_recorded_dependencies; }

private:
std::vector<task_record> m_recorded_tasks;
std::vector<std::unique_ptr<task_record>> m_recorded_tasks;
std::vector<task_dependency_record> m_recorded_dependencies;
};

// Command recording

using command_dependency_list = std::vector<dependency_record<command_id>>;

struct command_dependency_record {
command_id predecessor;
command_id successor;
dependency_kind kind;
dependency_origin origin;

command_dependency_record(const command_id predecessor, const command_id successor, const dependency_kind kind, const dependency_origin origin)
: predecessor(predecessor), successor(successor), kind(kind), origin(origin) {}
};
using command_dependency_record = dependency_record<command_id>;

struct command_record : matchbox::acceptor<struct push_command_record, struct await_push_command_record, struct reduction_command_record,
struct epoch_command_record, struct horizon_command_record, struct execution_command_record, struct fence_command_record> {
Expand Down
78 changes: 42 additions & 36 deletions src/print_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,42 @@ void format_requirements(std::string& label, const reduction_list& reductions, c
}
}

template <typename IdType>
void print_dependencies(
const std::vector<dependency_record<IdType>>& dependencies, std::string& dot,
const std::function<std::string(IdType)> id_transform = [](IdType id) { return std::to_string(id); }) {
// Sort and deduplicate edges
struct dependency_edge {
IdType predecessor;
IdType successor;
};
struct dependency_edge_order {
bool operator()(const dependency_edge& lhs, const dependency_edge& rhs) const {
if(lhs.predecessor < rhs.predecessor) return true;
if(lhs.predecessor > rhs.predecessor) return false;
return lhs.successor < rhs.successor;
}
};
struct dependency_kind_order {
bool operator()(const std::pair<dependency_kind, dependency_origin>& lhs, const std::pair<dependency_kind, dependency_origin>& rhs) const {
return (lhs.first == dependency_kind::true_dep && rhs.first != dependency_kind::true_dep);
}
};
std::map<dependency_edge, std::set<std::pair<dependency_kind, dependency_origin>, dependency_kind_order>, dependency_edge_order>
dependencies_by_edge; // ordered and unique
for(const auto& dep : dependencies) {
dependencies_by_edge[{dep.predecessor, dep.successor}].insert(std::pair{dep.kind, dep.origin});
}
for(const auto& [edge, meta] : dependencies_by_edge) {
// If there's at most two edges, take the first one (likely a true dependency followed by an anti-dependency). If there's more, bail (don't style).
const auto style = meta.size() <= 2 ? dependency_style(meta.begin()->first, meta.begin()->second) : std::string{};
fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", id_transform(edge.predecessor), id_transform(edge.successor), style);
}
}

std::string get_task_label(const task_record& tsk) {
std::string label;
fmt::format_to(std::back_inserter(label), "T{}", tsk.tid);
fmt::format_to(std::back_inserter(label), "T{}", tsk.id);
if(!tsk.debug_name.empty()) { fmt::format_to(std::back_inserter(label), " \"{}\"", utils::escape_for_dot_label(tsk.debug_name)); }

fmt::format_to(std::back_inserter(label), "<br/><b>{}</b>", task_type_string(tsk.type));
Expand All @@ -107,16 +140,15 @@ std::string make_graph_preamble(const std::string& title) { return fmt::format("
std::string print_task_graph(const task_recorder& recorder, const std::string& title) {
std::string dot = make_graph_preamble(title);

CELERITY_DEBUG("print_task_graph, {} entries", recorder.get_tasks().size());
CELERITY_DEBUG("print_task_graph, {} entries", recorder.get_graph_nodes().size());

for(const auto& tsk : recorder.get_tasks()) {
const char* shape = tsk.type == task_type::epoch || tsk.type == task_type::horizon ? "ellipse" : "box style=rounded";
fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk.tid, shape, get_task_label(tsk));
for(auto d : tsk.dependencies) {
fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", d.node, tsk.tid, dependency_style(d.kind, d.origin));
}
for(const auto& tsk : recorder.get_graph_nodes()) {
const char* shape = tsk->type == task_type::epoch || tsk->type == task_type::horizon ? "ellipse" : "box style=rounded";

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ misc-const-correctness ⚠️
variable shape of type const char * can be declared const

fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->id, shape, get_task_label(*tsk));
}

print_dependencies(recorder.get_dependencies(), dot);

dot += "}";
return dot;
}
Expand All @@ -135,7 +167,7 @@ std::string print_command_graph(const node_id local_nid, const command_recorder&
std::string main_dot;
std::map<task_id, std::string> task_subgraph_dot; // this map must be ordered!

const auto local_to_global_id = [local_nid](uint64_t id) {
const auto local_to_global_id = [local_nid](auto id) -> std::string {
// IDs in the DOT language may not start with a digit (unless the whole thing is a numeral)
return fmt::format("id_{}_{}", local_nid, id);
};
Expand Down Expand Up @@ -241,33 +273,7 @@ std::string print_command_graph(const node_id local_nid, const command_recorder&
});
};

// Sort and deduplicate edges
struct dependency_edge {
command_id predecessor;
command_id successor;
};
struct dependency_edge_order {
bool operator()(const dependency_edge& lhs, const dependency_edge& rhs) const {
if(lhs.predecessor < rhs.predecessor) return true;
if(lhs.predecessor > rhs.predecessor) return false;
return lhs.successor < rhs.successor;
}
};
struct dependency_kind_order {
bool operator()(const std::pair<dependency_kind, dependency_origin>& lhs, const std::pair<dependency_kind, dependency_origin>& rhs) const {
return (lhs.first == dependency_kind::true_dep && rhs.first != dependency_kind::true_dep);
}
};
std::map<dependency_edge, std::set<std::pair<dependency_kind, dependency_origin>, dependency_kind_order>, dependency_edge_order>
dependencies_by_edge; // ordered and unique
for(const auto& dep : recorder.get_dependencies()) {
dependencies_by_edge[{dep.predecessor, dep.successor}].insert(std::pair{dep.kind, dep.origin});
}
for(const auto& [edge, meta] : dependencies_by_edge) {
// If there's at most two edges, take the first one (likely a true dependency followed by an anti-dependency). If there's more, bail (don't style).
const auto style = meta.size() <= 2 ? dependency_style(meta.begin()->first, meta.begin()->second) : std::string{};
fmt::format_to(std::back_inserter(main_dot), "{}->{}[{}];", local_to_global_id(edge.predecessor), local_to_global_id(edge.successor), style);
}
print_dependencies<command_id>(recorder.get_dependencies(), main_dot, local_to_global_id);

std::string result_dot = make_graph_preamble(title);
for(auto& [_, sg_dot] : task_subgraph_dot) {
Expand Down
12 changes: 2 additions & 10 deletions src/recorders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,10 @@ reduction_list build_reduction_list(const task& tsk, const buffer_name_map& get_
return ret;
}

task_dependency_list build_task_dependency_list(const task& tsk) {
task_dependency_list ret;
for(const auto& dep : tsk.get_dependencies()) {
ret.push_back({dep.node->get_id(), dep.kind, dep.origin});
}
return ret;
}

task_record::task_record(const task& tsk, const buffer_name_map& get_buffer_debug_name)
: tid(tsk.get_id()), debug_name(tsk.get_debug_name()), cgid(tsk.get_collective_group_id()), type(tsk.get_type()), geometry(tsk.get_geometry()),
: id(tsk.get_id()), debug_name(tsk.get_debug_name()), cgid(tsk.get_collective_group_id()), type(tsk.get_type()), geometry(tsk.get_geometry()),
reductions(build_reduction_list(tsk, get_buffer_debug_name)), accesses(build_access_list(tsk, get_buffer_debug_name)),
side_effect_map(tsk.get_side_effect_map()), dependencies(build_task_dependency_list(tsk)) {}
side_effect_map(tsk.get_side_effect_map()) {}

// Commands

Expand Down
3 changes: 2 additions & 1 deletion src/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ namespace detail {
void task_manager::invoke_callbacks(const task* tsk) const {
if(m_delegate != nullptr) { m_delegate->task_created(tsk); }
if(m_task_recorder != nullptr) {
m_task_recorder->record(task_record(*tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }));
m_task_recorder->record(std::make_unique<task_record>(*tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }));
}
}

Expand All @@ -187,6 +187,7 @@ namespace detail {
depender.add_dependency({&dependee, kind, origin});
m_execution_front.erase(&dependee);
m_max_pseudo_critical_path_length = std::max(m_max_pseudo_critical_path_length, depender.get_pseudo_critical_path_length());
if(m_task_recorder != nullptr) { m_task_recorder->record_dependency({dependee.get_id(), depender.get_id(), kind, origin}); }
}

bool task_manager::need_new_horizon() const {
Expand Down
38 changes: 10 additions & 28 deletions test/accessor_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <celerity.h>

#include "task_graph_test_utils.h"
#include "test_utils.h"

namespace celerity {
Expand Down Expand Up @@ -230,34 +231,15 @@ namespace detail {
}

TEST_CASE("conflicts between producer-accessors and reductions are reported", "[task-manager]") {
test_utils::task_test_context tt;

auto buf_0 = tt.mbf.create_buffer(range<1>{1});

CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_conflict)>(tt.tm, [&](handler& cgh) {
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
}));

CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
buf_0.get_access<access_mode::read>(cgh, fixed<1>({0, 1}));
}));

CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
buf_0.get_access<access_mode::write>(cgh, fixed<1>({0, 1}));
}));

CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
buf_0.get_access<access_mode::read_write>(cgh, fixed<1>({0, 1}));
}));

CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
buf_0.get_access<access_mode::discard_write>(cgh, fixed<1>({0, 1}));
}));
test_utils::tdag_test_context tctx(1 /* num_collective_nodes */);

auto buf_0 = tctx.create_buffer(range<1>{1});

CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).reduce(buf_0, false).submit());
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).read(buf_0, all{}).submit());
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).write(buf_0, all{}).submit());
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).read_write(buf_0, all{}).submit());
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).discard_write(buf_0, all{}).submit());
}

template <access_mode>
Expand Down
31 changes: 14 additions & 17 deletions test/debug_naming_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <celerity.h>

#include "task_graph_test_utils.h"
#include "test_utils.h"

using namespace celerity;
Expand All @@ -15,34 +16,30 @@ using namespace celerity::detail;
TEST_CASE("debug names can be set and retrieved from tasks", "[debug]") {
const std::string task_name = "sample task";

auto tt = test_utils::task_test_context{};
test_utils::tdag_test_context tctx(1 /* num_collective_nodes */);

SECTION("Host Task") {
const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { celerity::debug::set_task_name(cgh, task_name); });
const auto tid_a = tctx.master_node_host_task().name(task_name).submit();
const auto tid_b = tctx.master_node_host_task().submit();

const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {});

CHECK(test_utils::get_task(tt.tdag, tid_a)->get_debug_name() == task_name);
CHECK(test_utils::get_task(tt.tdag, tid_b)->get_debug_name().empty());
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_a)->get_debug_name() == task_name);
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_b)->get_debug_name().empty());
}

SECTION("Compute Task") {
const auto tid_a = test_utils::add_compute_task<class compute_task>(tt.tm, [&](handler& cgh) { celerity::debug::set_task_name(cgh, task_name); });

const auto tid_b = test_utils::add_compute_task<class compute_task_unnamed>(tt.tm, [&](handler& cgh) {});
const auto tid_a = tctx.device_compute(range<1>(ones)).name(task_name).submit();
const auto tid_b = tctx.device_compute<class compute_task_unnamed>(range<1>(ones)).submit();

CHECK(test_utils::get_task(tt.tdag, tid_a)->get_debug_name() == task_name);
CHECK_THAT(test_utils::get_task(tt.tdag, tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("compute_task_unnamed"));
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_a)->get_debug_name() == task_name);
CHECK_THAT(test_utils::get_task(tctx.get_task_graph(), tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("compute_task_unnamed"));
}

SECTION("ND Range Task") {
const auto tid_a =
test_utils::add_nd_range_compute_task<class nd_range_task>(tt.tm, [&](handler& cgh) { celerity::debug::set_task_name(cgh, task_name); });

const auto tid_b = test_utils::add_compute_task<class nd_range_task_unnamed>(tt.tm, [&](handler& cgh) {});
const auto tid_a = tctx.device_compute(nd_range<1>{range<1>{1}, range<1>{1}}).name(task_name).submit();
const auto tid_b = tctx.device_compute<class nd_range_task_unnamed>(nd_range<1>{range<1>{1}, range<1>{1}}).submit();

CHECK(test_utils::get_task(tt.tdag, tid_a)->get_debug_name() == task_name);
CHECK_THAT(test_utils::get_task(tt.tdag, tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("nd_range_task_unnamed"));
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_a)->get_debug_name() == task_name);
CHECK_THAT(test_utils::get_task(tctx.get_task_graph(), tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("nd_range_task_unnamed"));
}
}

Expand Down
Loading
Loading