Skip to content

Commit

Permalink
Fix: Results from single-node reductions must be await-pushed on othe…
Browse files Browse the repository at this point in the history
…r nodes
  • Loading branch information
fknorr committed Oct 25, 2023
1 parent bf03a5d commit 0b1bf48
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/distributed_graph_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk)
const node_id reduction_initializer_nid = 0;

const box<3> empty_box({0, 0, 0}, {0, 0, 0});
const box<3> scalar_box({0, 0, 0}, {1, 1, 1});
const box<3> scalar_reduction_box({0, 0, 0}, {1, 1, 1});

// Iterate over all chunks, distinguish between local / remote chunks and normal / reduction access.
//
Expand Down Expand Up @@ -216,7 +216,7 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk)
assert(requirements[reduction.bid].count(pmode) == 0); // task_manager verifies that there are no reduction <-> write-access conflicts
}
#endif
requirements[reduction.bid][rmode] = scalar_box;
requirements[reduction.bid][rmode] = scalar_reduction_box;
}

abstract_command* cmd = nullptr;
Expand Down Expand Up @@ -356,7 +356,7 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk)
if(generate_reduction) {
const auto& reduction = *buffer_state.pending_reduction;

const auto local_last_writer = buffer_state.local_last_writer.get_region_values(scalar_box);
const auto local_last_writer = buffer_state.local_last_writer.get_region_values(scalar_reduction_box);
assert(local_last_writer.size() == 1);

if(is_local_chunk) {
Expand All @@ -367,35 +367,35 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk)
m_cdag.add_dependency(reduce_cmd, m_cdag.get(local_last_writer[0].second), dependency_kind::true_dep, dependency_origin::dataflow);
}

auto* const ap_cmd = create_command<await_push_command>(bid, reduction.rid, trid, scalar_box.get_subrange());
auto* const ap_cmd = create_command<await_push_command>(bid, reduction.rid, trid, scalar_reduction_box.get_subrange());
m_cdag.add_dependency(reduce_cmd, ap_cmd, dependency_kind::true_dep, dependency_origin::dataflow);
generate_epoch_dependencies(ap_cmd);

m_cdag.add_dependency(cmd, reduce_cmd, dependency_kind::true_dep, dependency_origin::dataflow);

// Reduction command becomes the last writer (this may be overriden if this task also writes to the reduction buffer)
post_reduction_buffer_states.at(bid).local_last_writer.update_box(scalar_box, reduce_cmd->get_cid());
post_reduction_buffer_states.at(bid).local_last_writer.update_box(scalar_reduction_box, reduce_cmd->get_cid());
} else {
// Push an empty range if we don't have any fresh data on this node
const bool notification_only = !local_last_writer[0].second.is_fresh();
const auto push_box = notification_only ? empty_box : scalar_box;
const auto push_box = notification_only ? empty_box : scalar_reduction_box;

auto* const push_cmd = create_command<push_command>(bid, reduction.rid, nid, trid, push_box.get_subrange());
generated_pushes.push_back(push_cmd);

if(notification_only) {
generate_epoch_dependencies(push_cmd);
} else {
m_command_buffer_reads[push_cmd->get_cid()][bid] = region_union(m_command_buffer_reads[push_cmd->get_cid()][bid], scalar_box);
m_command_buffer_reads[push_cmd->get_cid()][bid] = region_union(m_command_buffer_reads[push_cmd->get_cid()][bid], scalar_reduction_box);
m_cdag.add_dependency(push_cmd, m_cdag.get(local_last_writer[0].second), dependency_kind::true_dep, dependency_origin::dataflow);
}

// Mark the reduction result as replicated so we don't generate data transfers to this node
// TODO: We need a way of updating regions in place! E.g. apply_to_values(box, callback)
const auto replicated_box = post_reduction_buffer_states.at(bid).replicated_regions.get_region_values(scalar_box);
const auto replicated_box = post_reduction_buffer_states.at(bid).replicated_regions.get_region_values(scalar_reduction_box);
assert(replicated_box.size() == 1);
for(const auto& [_, nodes] : replicated_box) {
post_reduction_buffer_states.at(bid).replicated_regions.update_box(scalar_box, node_bitset{nodes}.set(nid));
post_reduction_buffer_states.at(bid).replicated_regions.update_box(scalar_reduction_box, node_bitset{nodes}.set(nid));
}
}
}
Expand Down Expand Up @@ -484,6 +484,11 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk)

// Determine which local data is fresh/stale based on task-level writes.
auto requirements = get_buffer_requirements_for_mapped_access(tsk, subrange<3>(tsk.get_global_offset(), tsk.get_global_size()), tsk.get_global_size());
// Add requirements for reductions
for(const auto& reduction : tsk.get_reductions()) {
// the actual mode is irrelevant as long as it's a producer - TODO have a better query API for task buffer requirements
requirements[reduction.bid][access_mode::write] = scalar_reduction_box;
}
for(auto& [bid, reqs_by_mode] : requirements) {
box_vector<3> global_write_boxes;
for(const auto mode : access::producer_modes) {
Expand Down
16 changes: 16 additions & 0 deletions test/graph_gen_reduction_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,19 @@ TEST_CASE("reduction commands anti-depend on their partial-result push commands"
.assert_count_per_node(1)
.have_successors(dctx.query(command_type::reduction).assert_count_per_node(1), dependency_kind::anti_dep));
}

TEST_CASE("reduction in a single-node task does not generate a reduction command, but the result is await-pushed on other nodes",
"[distributed_graph_generator][command-graph][reductions]") {
const size_t num_nodes = 3;
dist_cdag_test_context dctx(num_nodes);
auto buf = dctx.create_buffer(range<1>(1));

const auto tid_producer = dctx.device_compute(range<1>(1)).reduce(buf, false /* include_current_buffer_value */).submit();
const auto tid_consumer = dctx.device_compute(range<1>(num_nodes)).read(buf, acc::all()).submit();

CHECK(dctx.query(command_type::reduction).count() == 0);
CHECK(dctx.query(tid_producer).assert_count(1).have_successors(dctx.query(node_id(0), command_type::push).assert_count(2)));
for(node_id nid_await : {node_id(1), node_id(2)}) {
CHECK(dctx.query(nid_await, command_type::await_push).assert_count(1).have_successors(dctx.query(nid_await, tid_consumer)));
}
}

0 comments on commit 0b1bf48

Please sign in to comment.