Optimization: only update last-writers twice for overlapping writes
This also avoids an unordered_map by transposing the
perform_task_buffer_accesses loop.
fknorr committed Dec 4, 2024
1 parent 60e9155 commit e59eacf
Showing 1 changed file with 43 additions and 45 deletions.
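The core of the optimization is the overlap check added in step 3 of perform_task_buffer_accesses (see the hunk below): if the per-chunk written areas sum to more than the area the task as a whole produces, at least two chunks write the same elements, and only then is the two-step last-writer update needed. A minimal sketch of that check, assuming 1-D half-open integer intervals in place of Celerity's region<3>:

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// stand-in for one chunk's write region; region<3>::get_area() plays the role of area()
struct interval {
    std::size_t begin, end;
    std::size_t area() const { return end - begin; }
};

// If the summed per-chunk areas exceed the area produced by the task as a whole,
// at least two chunks wrote the same elements.
bool has_overlapping_writes(const std::vector<interval>& chunk_writes, const interval& task_writes) {
    const std::size_t total = std::accumulate(chunk_writes.begin(), chunk_writes.end(), std::size_t(0),
        [](const std::size_t sum, const interval& w) { return sum + w.area(); });
    return total > task_writes.area();
}

int main() {
    // chunks writing [0,5) and [3,8) overlap on [3,5); the task as a whole produces [0,8)
    assert(has_overlapping_writes({{0, 5}, {3, 8}}, {0, 8}));
    // two disjoint halves of [0,8) do not overlap
    assert(!has_overlapping_writes({{0, 4}, {4, 8}}, {0, 8}));
}

As the new in-code comment notes, this check is coarser than detail::detect_overlapping_writes() and cannot name the affected region, but it is cheap enough to run per buffer on every task.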
88 changes: 43 additions & 45 deletions src/instruction_graph_generator.cc
@@ -722,8 +722,8 @@ class generator_impl {
instruction* launch_task_kernel(batch& command_batch, const execution_command& ecmd, const task& tsk, const localized_chunk& chunk);

/// Add dependencies for all buffer accesses and reductions of a task, then update tracking structures accordingly.
void perform_task_buffer_accesses(
const task& tsk, const std::vector<localized_chunk>& concurrent_chunks, const std::vector<instruction*>& command_instructions);
void perform_task_buffer_accesses(buffer_id bid, const execution_command& ecmd, const task& tsk, const std::vector<localized_chunk>& concurrent_chunks,
const std::vector<instruction*>& command_instructions);

/// If a task has side effects, serialize it with respect to the last task that shares a host object.
void perform_task_side_effects(
@@ -1867,80 +1867,76 @@ instruction* generator_impl::launch_task_kernel(batch& command_batch, const exec
}
}

void generator_impl::perform_task_buffer_accesses(
const task& tsk, const std::vector<localized_chunk>& concurrent_chunks, const std::vector<instruction*>& command_instructions) //
void generator_impl::perform_task_buffer_accesses(const buffer_id bid, const execution_command& ecmd, const task& tsk,
const std::vector<localized_chunk>& concurrent_chunks, const std::vector<instruction*>& command_instructions) //
{
CELERITY_DETAIL_TRACY_ZONE_SCOPED("iggen::perform_buffer_access", Red3);

assert(std::all_of(command_instructions.begin(), command_instructions.end(), [](const instruction* instr) { return instr != nullptr; }));

const auto& bam = tsk.get_buffer_access_map();
if(bam.get_num_accesses() == 0 && tsk.get_reductions().empty()) return;
auto& buffer = m_buffers.at(bid);

// 1. Collect the read-sets and write-sets of all concurrent chunks on all buffers (TODO this is what buffer_access_map should actually return)

struct read_write_sets {
region<3> reads;
region<3> writes;
};

std::vector<std::unordered_map<buffer_id, read_write_sets>> concurrent_read_write_sets(concurrent_chunks.size());

for(const auto bid : bam.get_accessed_buffers()) {
for(size_t i = 0; i < concurrent_chunks.size(); ++i) {
const auto sr = concurrent_chunks[i].execution_range.get_subrange();
read_write_sets rw{bam.compute_consumed_region(bid, sr), bam.compute_produced_region(bid, sr)};
concurrent_read_write_sets[i].emplace(bid, std::move(rw));
}
std::vector<region<3>> concurrent_reads(concurrent_chunks.size());
std::vector<region<3>> concurrent_writes(concurrent_chunks.size());
for(size_t i = 0; i < concurrent_chunks.size(); ++i) {
const auto sr = concurrent_chunks[i].execution_range.get_subrange();
concurrent_reads[i] = bam.compute_consumed_region(bid, sr);
concurrent_writes[i] = bam.compute_produced_region(bid, sr);
}

for(const auto& rinfo : tsk.get_reductions()) {
for(size_t i = 0; i < concurrent_chunks.size(); ++i) {
auto& rw_map = concurrent_read_write_sets[i][rinfo.bid]; // allow default-insert on `bid`
rw_map.writes = region_union(rw_map.writes, scalar_reduction_box);
if(rinfo.bid != bid) continue;
for(auto& writes : concurrent_writes) {
writes = region_union(writes, scalar_reduction_box);
}
}

// 2. Insert all true-dependencies for reads and anti-dependencies for writes. We do this en-bloc instead of using `perform_concurrent_read_from_allocation`
// or `perform_atomic_write_to_allocation` to avoid incorrect dependencies between our concurrent chunks by updating tracking structures too early.

for(size_t i = 0; i < concurrent_chunks.size(); ++i) {
for(const auto& [bid, rw] : concurrent_read_write_sets[i]) {
auto& buffer = m_buffers.at(bid);
auto& memory = buffer.memories[concurrent_chunks[i].memory_id];
auto& memory = buffer.memories[concurrent_chunks[i].memory_id];

for(auto& allocation : memory.allocations) {
add_dependencies_on_last_writers(command_instructions[i], allocation, region_intersection(rw.reads, allocation.box));
add_dependencies_on_last_concurrent_accesses(
command_instructions[i], allocation, region_intersection(rw.writes, allocation.box), instruction_dependency_origin::write_to_allocation);
}
for(auto& allocation : memory.allocations) {
add_dependencies_on_last_writers(command_instructions[i], allocation, region_intersection(concurrent_reads[i], allocation.box));
add_dependencies_on_last_concurrent_accesses(command_instructions[i], allocation, region_intersection(concurrent_writes[i], allocation.box),
instruction_dependency_origin::write_to_allocation);
}
}

// 3. Clear tracking structures for all regions that are being written to. We gracefully handle overlapping writes by treating the set of all conflicting
// writers as last writers of an allocation.
// 3. To gracefully handle overlapping writes, clear tracking structures for all regions that are being written to, so we can treat the set of all
// conflicting writers as last writers of an allocation.

for(size_t i = 0; i < concurrent_chunks.size(); ++i) {
for(const auto& [bid, rw] : concurrent_read_write_sets[i]) {
assert(command_instructions[i] != nullptr);
auto& buffer = m_buffers.at(bid);
// Optimization: Quickly verify whether the task has overlapping writes on `bid` to avoid updating allocation last-writers twice.
// This is a simpler check than detail::detect_overlapping_writes() which does not allow diagnosing the affected region.
const size_t non_overlapping_written_elements = bam.compute_produced_region(bid, ecmd.get_execution_range()).get_area();
const size_t total_written_elements = std::accumulate(concurrent_writes.begin(), concurrent_writes.end(), size_t(0), //
[](const size_t sum, const region<3>& writes) { return sum + writes.get_area(); });
const bool has_overlapping_writes = total_written_elements > non_overlapping_written_elements;

if(has_overlapping_writes) {
for(size_t i = 0; i < concurrent_chunks.size(); ++i) {
for(auto& alloc : buffer.memories[concurrent_chunks[i].memory_id].allocations) {
alloc.begin_concurrent_writes(region_intersection(alloc.box, rw.writes));
alloc.begin_concurrent_writes(region_intersection(alloc.box, concurrent_writes[i]));
}
}
}

// 4. Update data locations and last writers resulting from all concurrent reads and overlapping writes

for(size_t i = 0; i < concurrent_chunks.size(); ++i) {
for(const auto& [bid, rw] : concurrent_read_write_sets[i]) {
assert(command_instructions[i] != nullptr);
auto& buffer = m_buffers.at(bid);

for(auto& alloc : buffer.memories[concurrent_chunks[i].memory_id].allocations) {
alloc.track_concurrent_read(region_intersection(alloc.box, rw.reads), command_instructions[i]);
alloc.track_concurrent_write(region_intersection(alloc.box, rw.writes), command_instructions[i]);
for(auto& alloc : buffer.memories[concurrent_chunks[i].memory_id].allocations) {
alloc.track_concurrent_read(region_intersection(alloc.box, concurrent_reads[i]), command_instructions[i]);
if(has_overlapping_writes) {
alloc.track_concurrent_write(region_intersection(alloc.box, concurrent_writes[i]), command_instructions[i]);
} else {
alloc.track_atomic_write(region_intersection(alloc.box, concurrent_writes[i]), command_instructions[i]);
}
buffer.track_original_write(rw.writes, command_instructions[i], concurrent_chunks[i].memory_id);
}
buffer.track_original_write(concurrent_writes[i], command_instructions[i], concurrent_chunks[i].memory_id);
}
}

@@ -2012,7 +2008,9 @@ void generator_impl::compile_execution_command(batch& command_batch, const execu
}

// 7. Compute dependencies and update tracking data structures
perform_task_buffer_accesses(tsk, concurrent_chunks, command_instructions);
for(const auto bid : accessed_bids) {
perform_task_buffer_accesses(bid, ecmd, tsk, concurrent_chunks, command_instructions);
}
perform_task_side_effects(tsk, concurrent_chunks, command_instructions);
perform_task_collective_operations(tsk, concurrent_chunks, command_instructions);

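For context on the commit title, here is a simplified stand-in (not the real allocation-tracking classes) for why overlapping writes take two last-writer updates while the common non-overlapping case now takes one: track_atomic_write can replace the recorded writer in a single step, whereas conflicting writers are first cleared via begin_concurrent_writes and then accumulated via track_concurrent_write so that all of them remain recorded as last writers.

#include <cassert>
#include <set>
#include <string>
#include <unordered_map>

using element = int;                  // stand-in for a point of a region<3>
using instruction_name = std::string; // stand-in for instruction*

struct last_writer_tracker {
    std::unordered_map<element, std::set<instruction_name>> last_writers;

    // single update: the new writer replaces whatever was recorded before
    void track_atomic_write(element e, const instruction_name& instr) { last_writers[e] = {instr}; }

    // update 1 of 2: forget the previous writers of the overlapping elements
    void begin_concurrent_writes(element e) { last_writers[e].clear(); }
    // update 2 of 2: record each conflicting writer without evicting its siblings
    void track_concurrent_write(element e, const instruction_name& instr) { last_writers[e].insert(instr); }
};

int main() {
    last_writer_tracker t;
    t.track_atomic_write(0, "kernel_a"); // non-overlapping case: one update suffices

    // overlapping case: two chunks write element 1, both must stay recorded
    t.begin_concurrent_writes(1);
    t.track_concurrent_write(1, "kernel_b");
    t.track_concurrent_write(1, "kernel_c");

    assert(t.last_writers[0] == std::set<instruction_name>{"kernel_a"});
    assert((t.last_writers[1] == std::set<instruction_name>{"kernel_b", "kernel_c"}));
}

In the diff above, the two-step path is only taken when has_overlapping_writes is true; otherwise track_atomic_write handles the write in a single pass.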
