From e59eacf837d70613d9a70c0c55207eb9e916f537 Mon Sep 17 00:00:00 2001 From: Fabian Knorr Date: Wed, 4 Dec 2024 11:38:36 +0100 Subject: [PATCH] Optimization: only update last-writers twice for overlapping writes This also avoids an unordered_map by transposing the perform_task_buffer_accesses loop. --- src/instruction_graph_generator.cc | 88 +++++++++++++++--------------- 1 file changed, 43 insertions(+), 45 deletions(-) diff --git a/src/instruction_graph_generator.cc b/src/instruction_graph_generator.cc index 37fcc670..b0821208 100644 --- a/src/instruction_graph_generator.cc +++ b/src/instruction_graph_generator.cc @@ -722,8 +722,8 @@ class generator_impl { instruction* launch_task_kernel(batch& command_batch, const execution_command& ecmd, const task& tsk, const localized_chunk& chunk); /// Add dependencies for all buffer accesses and reductions of a task, then update tracking structures accordingly. - void perform_task_buffer_accesses( - const task& tsk, const std::vector& concurrent_chunks, const std::vector& command_instructions); + void perform_task_buffer_accesses(buffer_id bid, const execution_command& ecmd, const task& tsk, const std::vector& concurrent_chunks, + const std::vector& command_instructions); /// If a task has side effects, serialize it with respect to the last task that shares a host object. void perform_task_side_effects( @@ -1867,35 +1867,30 @@ instruction* generator_impl::launch_task_kernel(batch& command_batch, const exec } } -void generator_impl::perform_task_buffer_accesses( - const task& tsk, const std::vector& concurrent_chunks, const std::vector& command_instructions) // +void generator_impl::perform_task_buffer_accesses(const buffer_id bid, const execution_command& ecmd, const task& tsk, + const std::vector& concurrent_chunks, const std::vector& command_instructions) // { CELERITY_DETAIL_TRACY_ZONE_SCOPED("iggen::perform_buffer_access", Red3); + assert(std::all_of(command_instructions.begin(), command_instructions.end(), [](const instruction* instr) { return instr != nullptr; })); + const auto& bam = tsk.get_buffer_access_map(); - if(bam.get_num_accesses() == 0 && tsk.get_reductions().empty()) return; + auto& buffer = m_buffers.at(bid); // 1. Collect the read-sets and write-sets of all concurrent chunks on all buffers (TODO this is what buffer_access_map should actually return) - struct read_write_sets { - region<3> reads; - region<3> writes; - }; - - std::vector> concurrent_read_write_sets(concurrent_chunks.size()); - - for(const auto bid : bam.get_accessed_buffers()) { - for(size_t i = 0; i < concurrent_chunks.size(); ++i) { - const auto sr = concurrent_chunks[i].execution_range.get_subrange(); - read_write_sets rw{bam.compute_consumed_region(bid, sr), bam.compute_produced_region(bid, sr)}; - concurrent_read_write_sets[i].emplace(bid, std::move(rw)); - } + std::vector> concurrent_reads(concurrent_chunks.size()); + std::vector> concurrent_writes(concurrent_chunks.size()); + for(size_t i = 0; i < concurrent_chunks.size(); ++i) { + const auto sr = concurrent_chunks[i].execution_range.get_subrange(); + concurrent_reads[i] = bam.compute_consumed_region(bid, sr); + concurrent_writes[i] = bam.compute_produced_region(bid, sr); } for(const auto& rinfo : tsk.get_reductions()) { - for(size_t i = 0; i < concurrent_chunks.size(); ++i) { - auto& rw_map = concurrent_read_write_sets[i][rinfo.bid]; // allow default-insert on `bid` - rw_map.writes = region_union(rw_map.writes, scalar_reduction_box); + if(rinfo.bid != bid) continue; + for(auto& writes : concurrent_writes) { + writes = region_union(writes, scalar_reduction_box); } } @@ -1903,27 +1898,29 @@ void generator_impl::perform_task_buffer_accesses( // or `perform_atomic_write_to_allocation` to avoid incorrect dependencies between our concurrent chunks by updating tracking structures too early. for(size_t i = 0; i < concurrent_chunks.size(); ++i) { - for(const auto& [bid, rw] : concurrent_read_write_sets[i]) { - auto& buffer = m_buffers.at(bid); - auto& memory = buffer.memories[concurrent_chunks[i].memory_id]; + auto& memory = buffer.memories[concurrent_chunks[i].memory_id]; - for(auto& allocation : memory.allocations) { - add_dependencies_on_last_writers(command_instructions[i], allocation, region_intersection(rw.reads, allocation.box)); - add_dependencies_on_last_concurrent_accesses( - command_instructions[i], allocation, region_intersection(rw.writes, allocation.box), instruction_dependency_origin::write_to_allocation); - } + for(auto& allocation : memory.allocations) { + add_dependencies_on_last_writers(command_instructions[i], allocation, region_intersection(concurrent_reads[i], allocation.box)); + add_dependencies_on_last_concurrent_accesses(command_instructions[i], allocation, region_intersection(concurrent_writes[i], allocation.box), + instruction_dependency_origin::write_to_allocation); } } - // 3. Clear tracking structures for all regions that are being written to. We gracefully handle overlapping writes by treating the set of all conflicting - // writers as last writers of an allocation. + // 3. To gracefully handle overlapping writes, clear tracking structures for all regions that are being written to, so we can treat the set of all + // conflicting writers as last writers of an allocation. - for(size_t i = 0; i < concurrent_chunks.size(); ++i) { - for(const auto& [bid, rw] : concurrent_read_write_sets[i]) { - assert(command_instructions[i] != nullptr); - auto& buffer = m_buffers.at(bid); + // Optimization: Quickly verify whether the task has overlapping writes on `bid` to avoid updating allocation last-writers twice. + // This is a simpler check than detail::detect_overlapping_writes() which does not allow diagnosing the affected region. + const size_t non_overlapping_written_elements = bam.compute_produced_region(bid, ecmd.get_execution_range()).get_area(); + const size_t total_written_elements = std::accumulate(concurrent_writes.begin(), concurrent_writes.end(), size_t(0), // + [](const size_t sum, const region<3>& writes) { return sum + writes.get_area(); }); + const bool has_overlapping_writes = total_written_elements > non_overlapping_written_elements; + + if(has_overlapping_writes) { + for(size_t i = 0; i < concurrent_chunks.size(); ++i) { for(auto& alloc : buffer.memories[concurrent_chunks[i].memory_id].allocations) { - alloc.begin_concurrent_writes(region_intersection(alloc.box, rw.writes)); + alloc.begin_concurrent_writes(region_intersection(alloc.box, concurrent_writes[i])); } } } @@ -1931,16 +1928,15 @@ void generator_impl::perform_task_buffer_accesses( // 4. Update data locations and last writers resulting from all concurrent reads and overlapping writes for(size_t i = 0; i < concurrent_chunks.size(); ++i) { - for(const auto& [bid, rw] : concurrent_read_write_sets[i]) { - assert(command_instructions[i] != nullptr); - auto& buffer = m_buffers.at(bid); - - for(auto& alloc : buffer.memories[concurrent_chunks[i].memory_id].allocations) { - alloc.track_concurrent_read(region_intersection(alloc.box, rw.reads), command_instructions[i]); - alloc.track_concurrent_write(region_intersection(alloc.box, rw.writes), command_instructions[i]); + for(auto& alloc : buffer.memories[concurrent_chunks[i].memory_id].allocations) { + alloc.track_concurrent_read(region_intersection(alloc.box, concurrent_reads[i]), command_instructions[i]); + if(has_overlapping_writes) { + alloc.track_concurrent_write(region_intersection(alloc.box, concurrent_writes[i]), command_instructions[i]); + } else { + alloc.track_atomic_write(region_intersection(alloc.box, concurrent_writes[i]), command_instructions[i]); } - buffer.track_original_write(rw.writes, command_instructions[i], concurrent_chunks[i].memory_id); } + buffer.track_original_write(concurrent_writes[i], command_instructions[i], concurrent_chunks[i].memory_id); } } @@ -2012,7 +2008,9 @@ void generator_impl::compile_execution_command(batch& command_batch, const execu } // 7. Compute dependencies and update tracking data structures - perform_task_buffer_accesses(tsk, concurrent_chunks, command_instructions); + for(const auto bid : accessed_bids) { + perform_task_buffer_accesses(bid, ecmd, tsk, concurrent_chunks, command_instructions); + } perform_task_side_effects(tsk, concurrent_chunks, command_instructions); perform_task_collective_operations(tsk, concurrent_chunks, command_instructions);