Generate single await-push command per buffer for all local chunks
This brings await-pushes in line with pushes, where we already compute
the union of all regions required by remote chunks executed on the same
node.
psalz committed Dec 20, 2024
1 parent 6a2b416 commit a47f1a0
Showing 3 changed files with 49 additions and 11 deletions.
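
Before the diffs, a minimal standalone sketch of the union-of-regions idea the commit message describes. The types here (box1d, buffer_id, the min/max merge) are simplified stand-ins, not Celerity's actual box/region classes, and the merge is only valid because the example boxes are contiguous:

#include <algorithm>
#include <map>
#include <utility>
#include <vector>

// Simplified stand-in for a 1D box; Celerity's real boxes are n-dimensional.
struct box1d {
	int min, max; // half-open interval [min, max)
};

using buffer_id = int;

int main() {
	// Regions required by four local chunks, touching two buffers.
	const std::vector<std::pair<buffer_id, box1d>> chunk_requirements = {
	    {0, {64, 80}}, {0, {80, 96}}, {0, {96, 112}}, {0, {112, 128}},
	    {1, {64, 96}}, {1, {96, 128}},
	};

	// Phase 1: collect boxes per buffer. Previously, every entry here would
	// have produced its own await-push command.
	std::map<buffer_id, std::vector<box1d>> per_buffer_boxes;
	for(const auto& [bid, box] : chunk_requirements) {
		per_buffer_boxes[bid].push_back(box);
	}

	// Phase 2: emit a single command per buffer covering the merged region.
	for(const auto& [bid, boxes] : per_buffer_boxes) {
		box1d merged = boxes.front();
		for(const auto& b : boxes) {
			merged.min = std::min(merged.min, b.min);
			merged.max = std::max(merged.max, b.max);
		}
		// emit_await_push(bid, merged); // one command instead of boxes.size()
	}
}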
7 changes: 7 additions & 0 deletions include/grid.h
@@ -273,6 +273,13 @@ class region_builder {
 		m_boxes.push_back(box);
 	}
 
+	void add(const box_vector<Dims>& boxes) & {
+		m_boxes.reserve(m_boxes.size() + boxes.size());
+		for(const auto& b : boxes) {
+			add(b);
+		}
+	}
+
 	void add(const region<Dims>& region) & {
 		if(region.empty()) return;
 		m_normalized = m_boxes.empty();
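
A quick usage sketch of the new overload. This is a hypothetical call site: it assumes box_vector supports initializer-list construction and that box<1>{min, max} is valid, as the test code below suggests; into_region() on the moved builder is taken from the diff itself:

// Hypothetical call site: a batch of boxes is added in one call,
// which reserves capacity once up front.
box_vector<1> boxes{box<1>{0, 32}, box<1>{64, 96}};

region_builder<1> builder;
builder.add(boxes); // equivalent to calling add(b) for each box
auto merged = std::move(builder).into_region();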
29 changes: 18 additions & 11 deletions src/command_graph_generator.cc
@@ -388,33 +388,40 @@ void command_graph_generator::generate_pushes(batch& current_batch, const task&
 	}
 }
 
-// TODO: We currently generate an await push command for each local chunk, whereas we only generate a single push command for all remote chunks
 void command_graph_generator::generate_await_pushes(batch& current_batch, const task& tsk, const assigned_chunks_with_requirements& chunks_with_requirements) {
+	std::unordered_map<buffer_id, region_builder<3>> per_buffer_required_boxes;
+
 	for(auto& [a_chunk, requirements] : chunks_with_requirements.local_chunks) {
 		for(auto& [bid, consumed, _] : requirements) {
 			if(consumed.empty()) continue;
 			auto& buffer = m_buffers.at(bid);
 
 			const auto local_sources = buffer.local_last_writer.get_region_values(consumed);
-			region_builder<3> missing_part_boxes;
+			box_vector<3> missing_parts_boxes;
 			for(const auto& [box, wcs] : local_sources) {
 				// Note that we initialize all buffers as fresh, so this doesn't trigger for uninitialized reads
				if(!box.empty() && !wcs.is_fresh()) { missing_part_boxes.add(box); }
				if(!box.empty() && !wcs.is_fresh()) { missing_parts_boxes.push_back(box); }
 			}
 
 			// There is data we don't yet have locally. Generate an await push command for it.
-			if(!missing_part_boxes.empty()) {
-				const auto missing_parts = std::move(missing_part_boxes).into_region();
+			if(!missing_parts_boxes.empty()) {
 				assert(m_num_nodes > 1);
-				auto* const ap_cmd = create_command<await_push_command>(current_batch, transfer_id(tsk.get_id(), bid, no_reduction_id), missing_parts,
-				    [&](const auto& record_debug_info) { record_debug_info(buffer.debug_name); });
-				generate_anti_dependencies(tsk, bid, buffer.local_last_writer, missing_parts, ap_cmd);
-				generate_epoch_dependencies(ap_cmd);
-				// Remember that we have this data now
-				buffer.local_last_writer.update_region(missing_parts, {ap_cmd, true /* is_replicated */});
+				auto& required_boxes = per_buffer_required_boxes[bid]; // allow default-insert
+				required_boxes.add(missing_parts_boxes);
 			}
 		}
 	}
+
+	for(auto& [bid, boxes] : per_buffer_required_boxes) {
+		auto& buffer = m_buffers.at(bid);
+		auto region = std::move(boxes).into_region(); // moved-from after next line!
+		auto* const ap_cmd = create_command<await_push_command>(current_batch, transfer_id(tsk.get_id(), bid, no_reduction_id), std::move(region),
+		    [&](const auto& record_debug_info) { record_debug_info(buffer.debug_name); });
+		generate_anti_dependencies(tsk, bid, buffer.local_last_writer, ap_cmd->get_region(), ap_cmd);
+		generate_epoch_dependencies(ap_cmd);
+		// Remember that we have this data now
+		buffer.local_last_writer.update_region(ap_cmd->get_region(), {ap_cmd, true /* is_replicated */});
+	}
 }
 
 void command_graph_generator::update_local_buffer_fresh_regions(const task& tsk, const std::unordered_map<buffer_id, region<3>>& per_buffer_local_writes) {
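
The "moved-from after next line!" comment in the new loop flags a subtlety worth spelling out: once the region is moved into the command, the local variable must not be read again, which is why the follow-up calls query ap_cmd->get_region() instead. A self-contained illustration with stand-in types (not Celerity's):

#include <string>
#include <utility>

// Stand-in for a command that takes ownership of its region on construction.
struct await_push {
	std::string region;
	const std::string& get_region() const { return region; }
};

int main() {
	std::string region = "[64,128)";
	await_push cmd{std::move(region)}; // `region` is moved-from past this point

	// Wrong: reading `region` now would observe a valid-but-unspecified value.
	// Right: read the data back from its new owner, as the diff does via
	// ap_cmd->get_region().
	const std::string& r = cmd.get_region();
	(void)r; // suppress unused-variable warning
}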
24 changes: 24 additions & 0 deletions test/command_graph_transfer_tests.cc
@@ -115,6 +115,30 @@ TEST_CASE("command_graph_generator generates a single push command per buffer an
 	CHECK(cctx.query<push_command_record>(buf1.get_id()).on(1)[1]->target_regions == push_regions<1>({{0, region<1>{{box<1>{96, 128}}}}}));
 }
 
+TEST_CASE("command_graph_generator generates a single await_push command per buffer and task", "[command_graph_generator][command-graph]") { //
+	cdag_test_context cctx(2);
+
+	const range<1> test_range = {128};
+	auto buf0 = cctx.create_buffer(test_range);
+	auto buf1 = cctx.create_buffer(test_range);
+
+	// Initialize buffers across both nodes
+	cctx.device_compute(test_range).name("init").discard_write(buf0, acc::one_to_one{}).discard_write(buf1, acc::one_to_one{}).submit();
+
+	// Read in reverse order, but split task into 4 chunks each
+	cctx.set_test_chunk_multiplier(4);
+	cctx.device_compute(test_range).read(buf0, test_utils::access::reverse_one_to_one{}).read(buf1, test_utils::access::reverse_one_to_one{}).submit();
+
+	CHECK(cctx.query<push_command_record>().count_per_node() == 2);
+	CHECK(cctx.query<await_push_command_record>().count_per_node() == 2);
+
+	// The union of the required regions is just the full other half
+	CHECK(cctx.query<await_push_command_record>().on(0).iterate()[0]->await_region == region_cast<3>(region<1>{box<1>{64, 128}}));
+	CHECK(cctx.query<await_push_command_record>().on(0).iterate()[1]->await_region == region_cast<3>(region<1>{box<1>{64, 128}}));
+	CHECK(cctx.query<await_push_command_record>().on(1).iterate()[0]->await_region == region_cast<3>(region<1>{box<1>{0, 64}}));
+	CHECK(cctx.query<await_push_command_record>().on(1).iterate()[1]->await_region == region_cast<3>(region<1>{box<1>{0, 64}}));
+}
+
 TEST_CASE("command_graph_generator doesn't generate data transfer commands for the same buffer and range more than once",
     "[command_graph_generator][command-graph]") {
 	cdag_test_context cctx(2);
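
Why the expected regions are exactly the "other half": assuming test_utils::access::reverse_one_to_one mirrors the 1D index space (i maps to N - 1 - i), which the expected regions are consistent with, node 0's chunks in [0, 64) read data at [64, 128), written on node 1 by the init kernel. A tiny sketch of that arithmetic:

#include <cassert>

// Assumed index mapping behind reverse_one_to_one: mirror the 1D range.
constexpr int reverse(int i, int n) { return n - 1 - i; }

int main() {
	constexpr int n = 128;
	// Node 0 executes chunks covering [0, 64); the mirrored reads fall in
	// [64, 128), node 1's half — so node 0 awaits exactly that region,
	// no matter how many chunks the task is split into.
	assert(reverse(0, n) == 127);
	assert(reverse(63, n) == 64);
	return 0;
}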
