Skip to content

refactor: Per ingredient sync table #650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ pub(crate) use maybe_changed_after::VerifyResult;
use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues};
use crate::cycle::{CycleHeadKind, CycleRecoveryAction, CycleRecoveryStrategy};
use crate::function::delete::DeletedEntries;
use crate::function::sync::{ClaimResult, SyncTable};
use crate::ingredient::{fmt_index, Ingredient};
use crate::key::DatabaseKeyIndex;
use crate::plumbing::MemoIngredientMap;
use crate::salsa_struct::SalsaStructInDb;
use crate::table::memo::MemoTableTypes;
use crate::table::sync::ClaimResult;
use crate::table::Table;
use crate::views::DatabaseDownCaster;
use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa};
Expand All @@ -31,6 +31,7 @@ mod lru;
mod maybe_changed_after;
mod memo;
mod specify;
mod sync;

pub type Memo<C> = memo::Memo<<C as Configuration>::Output<'static>>;

Expand Down Expand Up @@ -120,6 +121,8 @@ pub struct IngredientImpl<C: Configuration> {
/// instances that this downcaster was derived from.
view_caster: DatabaseDownCaster<C::DbView>,

sync_table: SyncTable,

/// When `fetch` and friends executes, they return a reference to the
/// value stored in the memo that is extended to live as long as the `&self`
/// reference we start with. This means that whenever we remove something
Expand Down Expand Up @@ -161,6 +164,7 @@ where
lru: lru::Lru::new(lru),
deleted_entries: Default::default(),
view_caster,
sync_table: SyncTable::new(index),
}
}

Expand Down Expand Up @@ -269,12 +273,7 @@ where
/// Attempts to claim `key_index`, returning `false` if a cycle occurs.
fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool {
let zalsa = db.zalsa();
match zalsa.sync_table_for(key_index).claim(
db,
zalsa,
self.database_key_index(key_index),
self.memo_ingredient_index(zalsa, key_index),
) {
match self.sync_table.try_claim(db, zalsa, key_index) {
ClaimResult::Retry | ClaimResult::Claimed(_) => true,
ClaimResult::Cycle => false,
}
Expand Down
36 changes: 12 additions & 24 deletions src/function/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::cycle::{CycleHeads, CycleRecoveryStrategy};
use crate::function::memo::Memo;
use crate::function::sync::ClaimResult;
use crate::function::{Configuration, IngredientImpl, VerifyResult};
use crate::table::sync::ClaimResult;
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
use crate::zalsa_local::QueryRevisions;
use crate::Id;
Expand Down Expand Up @@ -96,19 +96,11 @@ where
id: Id,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<&'db Memo<C::Output<'db>>> {
let database_key_index = self.database_key_index(id);

// Try to claim this query: if someone else has claimed it already, go back and start again.
let _claim_guard = match zalsa.sync_table_for(id).claim(
db,
zalsa,
database_key_index,
memo_ingredient_index,
) {
ClaimResult::Retry => {
return None;
}
let _claim_guard = match self.sync_table.try_claim(db, zalsa, id) {
ClaimResult::Retry => return None,
ClaimResult::Cycle => {
let database_key_index = self.database_key_index(id);
// check if there's a provisional value for this query
// Note we don't `validate_may_be_provisional` the memo here as we want to reuse an
// existing provisional memo if it exists
Expand Down Expand Up @@ -151,12 +143,10 @@ where
database_key_index,
zalsa.current_revision(),
);
let initial_value = self
.initial_value(db, database_key_index.key_index())
.expect(
"`CycleRecoveryStrategy::Fixpoint` \
let initial_value = self.initial_value(db, id).expect(
"`CycleRecoveryStrategy::Fixpoint` \
should have initial_value",
);
);
Some(self.insert_memo(
zalsa,
id,
Expand All @@ -169,12 +159,10 @@ where
"hit a `FallbackImmediate` cycle at {database_key_index:#?}"
);
let active_query = db.zalsa_local().push_query(database_key_index, 0);
let fallback_value = self
.initial_value(db, database_key_index.key_index())
.expect(
"`CycleRecoveryStrategy::FallbackImmediate` \
let fallback_value = self.initial_value(db, id).expect(
"`CycleRecoveryStrategy::FallbackImmediate` \
should have initial_value",
);
);
let mut revisions = active_query.pop();
revisions.cycle_heads = CycleHeads::initial(database_key_index);
// We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`.
Expand All @@ -196,7 +184,7 @@ where
if let Some(old_memo) = opt_old_memo {
if old_memo.value.is_some() {
if let VerifyResult::Unchanged(_, cycle_heads) =
self.deep_verify_memo(db, zalsa, old_memo, database_key_index)
self.deep_verify_memo(db, zalsa, old_memo, self.database_key_index(id))
{
if cycle_heads.is_empty() {
// SAFETY: memo is present in memo_map and we have verified that it is
Expand All @@ -209,7 +197,7 @@ where

let memo = self.execute(
db,
db.zalsa_local().push_query(database_key_index, 0),
db.zalsa_local().push_query(self.database_key_index(id), 0),
opt_old_memo,
);

Expand Down
9 changes: 2 additions & 7 deletions src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::sync::atomic::Ordering;
use crate::accumulator::accumulated_map::InputAccumulatedValues;
use crate::cycle::{CycleHeadKind, CycleHeads, CycleRecoveryStrategy};
use crate::function::memo::Memo;
use crate::function::sync::ClaimResult;
use crate::function::{Configuration, IngredientImpl};
use crate::key::DatabaseKeyIndex;
use crate::table::sync::ClaimResult;
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
use crate::zalsa_local::{QueryEdge, QueryOrigin};
use crate::{AsDynDatabase as _, Id, Revision};
Expand Down Expand Up @@ -102,12 +102,7 @@ where
) -> Option<VerifyResult> {
let database_key_index = self.database_key_index(key_index);

let _claim_guard = match zalsa.sync_table_for(key_index).claim(
db,
zalsa,
database_key_index,
memo_ingredient_index,
) {
let _claim_guard = match self.sync_table.try_claim(db, zalsa, key_index) {
ClaimResult::Retry => return None,
ClaimResult::Cycle => match C::CYCLE_STRATEGY {
CycleRecoveryStrategy::Panic => db.zalsa_local().with_query_stack(|stack| {
Expand Down
101 changes: 52 additions & 49 deletions src/table/sync.rs → src/function/sync.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
use std::thread::ThreadId;

use parking_lot::Mutex;
use rustc_hash::FxHashMap;

use crate::key::DatabaseKeyIndex;
use crate::runtime::{BlockResult, WaitResult};
use crate::table::util;
use crate::zalsa::{MemoIngredientIndex, Zalsa};
use crate::Database;
use crate::{
key::DatabaseKeyIndex,
runtime::{BlockResult, WaitResult},
zalsa::Zalsa,
Database, Id, IngredientIndex,
};

/// Tracks the keys that are currently being processed; used to coordinate between
/// worker threads.
#[derive(Default)]
pub(crate) struct SyncTable {
syncs: Mutex<Vec<Option<SyncState>>>,
syncs: Mutex<FxHashMap<Id, SyncState>>,
ingredient: IngredientIndex,
}

pub(crate) enum ClaimResult<'a> {
Retry,
Cycle,
Claimed(ClaimGuard<'a>),
}

struct SyncState {
Expand All @@ -23,59 +31,56 @@ struct SyncState {
anyone_waiting: bool,
}

pub(crate) enum ClaimResult<'a> {
Retry,
Cycle,
Claimed(ClaimGuard<'a>),
}

impl SyncTable {
#[inline]
pub(crate) fn claim<'me>(
pub(crate) fn new(ingredient: IngredientIndex) -> Self {
Self {
syncs: Default::default(),
ingredient,
}
}

pub(crate) fn try_claim<'me>(
&'me self,
db: &'me (impl ?Sized + Database),
zalsa: &'me Zalsa,
database_key_index: DatabaseKeyIndex,
memo_ingredient_index: MemoIngredientIndex,
key_index: Id,
) -> ClaimResult<'me> {
let mut syncs = self.syncs.lock();
let thread_id = std::thread::current().id();

util::ensure_vec_len(&mut syncs, memo_ingredient_index.as_usize() + 1);

match &mut syncs[memo_ingredient_index.as_usize()] {
None => {
syncs[memo_ingredient_index.as_usize()] = Some(SyncState {
id: thread_id,
anyone_waiting: false,
});
ClaimResult::Claimed(ClaimGuard {
database_key_index,
memo_ingredient_index,
zalsa,
sync_table: self,
_padding: false,
})
}
Some(SyncState {
id: other_id,
anyone_waiting,
}) => {
let mut write = self.syncs.lock();
match write.entry(key_index) {
std::collections::hash_map::Entry::Occupied(occupied_entry) => {
let &mut SyncState {
id,
ref mut anyone_waiting,
} = occupied_entry.into_mut();
// NB: `Ordering::Relaxed` is sufficient here,
// as there are no loads that are "gated" on this
// value. Everything that is written is also protected
// by a lock that must be acquired. The role of this
// boolean is to decide *whether* to acquire the lock,
// not to gate future atomic reads.
*anyone_waiting = true;
match zalsa
.runtime()
.block_on(db, database_key_index, *other_id, syncs)
{
match zalsa.runtime().block_on(
db,
DatabaseKeyIndex::new(self.ingredient, key_index),
id,
write,
) {
BlockResult::Completed => ClaimResult::Retry,
BlockResult::Cycle => ClaimResult::Cycle,
}
}
std::collections::hash_map::Entry::Vacant(vacant_entry) => {
vacant_entry.insert(SyncState {
id: std::thread::current().id(),
anyone_waiting: false,
});
ClaimResult::Claimed(ClaimGuard {
key_index,
zalsa,
sync_table: self,
_padding: false,
})
}
}
}
}
Expand All @@ -84,8 +89,7 @@ impl SyncTable {
/// released when this value is dropped.
#[must_use]
pub(crate) struct ClaimGuard<'me> {
database_key_index: DatabaseKeyIndex,
memo_ingredient_index: MemoIngredientIndex,
key_index: Id,
zalsa: &'me Zalsa,
sync_table: &'me SyncTable,
// Reduce the size of ClaimResult by making more niches available in ClaimGuard; this fits into
Expand All @@ -97,12 +101,11 @@ impl ClaimGuard<'_> {
fn remove_from_map_and_unblock_queries(&self) {
let mut syncs = self.sync_table.syncs.lock();

let SyncState { anyone_waiting, .. } =
syncs[self.memo_ingredient_index.as_usize()].take().unwrap();
let SyncState { anyone_waiting, .. } = syncs.remove(&self.key_index).unwrap();

if anyone_waiting {
self.zalsa.runtime().unblock_queries_blocked_on(
self.database_key_index,
DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index),
if std::thread::panicking() {
WaitResult::Panicked
} else {
Expand Down
10 changes: 0 additions & 10 deletions src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use crate::input::singleton::{Singleton, SingletonChoice};
use crate::key::DatabaseKeyIndex;
use crate::plumbing::{Jar, Stamp};
use crate::table::memo::{MemoTable, MemoTableTypes};
use crate::table::sync::SyncTable;
use crate::table::{Slot, Table};
use crate::zalsa::{IngredientIndex, Zalsa};
use crate::{Database, Durability, Id, Revision, Runtime};
Expand Down Expand Up @@ -107,7 +106,6 @@ impl<C: Configuration> IngredientImpl<C> {
fields,
stamps,
memos: Default::default(),
syncs: Default::default(),
})
});

Expand Down Expand Up @@ -252,9 +250,6 @@ where

/// Memos
memos: MemoTable,

/// Syncs
syncs: SyncTable,
}

impl<C> Value<C>
Expand Down Expand Up @@ -288,9 +283,4 @@ where
fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable {
&mut self.memos
}

#[inline(always)]
unsafe fn syncs(&self, _current_revision: Revision) -> &SyncTable {
&self.syncs
}
}
8 changes: 0 additions & 8 deletions src/interned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use crate::ingredient::{fmt_index, Ingredient};
use crate::plumbing::{IngredientIndices, Jar};
use crate::revision::AtomicRevision;
use crate::table::memo::{MemoTable, MemoTableTypes};
use crate::table::sync::SyncTable;
use crate::table::Slot;
use crate::zalsa::{IngredientIndex, Zalsa};
use crate::{Database, DatabaseKeyIndex, Event, EventKind, Id, Revision};
Expand Down Expand Up @@ -78,7 +77,6 @@ where
{
fields: C::Fields<'static>,
memos: MemoTable,
syncs: SyncTable,

/// The revision the value was first interned in.
first_interned_at: Revision,
Expand Down Expand Up @@ -311,7 +309,6 @@ where
let id = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value::<C> {
fields: unsafe { self.to_internal_data(assemble(id, key)) },
memos: Default::default(),
syncs: Default::default(),
durability: AtomicU8::new(durability.as_u8()),
// Record the revision we are interning in.
first_interned_at: current_revision,
Expand Down Expand Up @@ -463,11 +460,6 @@ where
fn memos_mut(&mut self) -> &mut MemoTable {
&mut self.memos
}

#[inline(always)]
unsafe fn syncs(&self, _current_revision: Revision) -> &crate::table::sync::SyncTable {
&self.syncs
}
}

/// A trait for types that hash and compare like `O`.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl Runtime {
other_id: ThreadId,
query_mutex_guard: QueryMutexGuard,
) -> BlockResult {
let mut dg = self.dependency_graph.lock();
let dg = self.dependency_graph.lock();
let thread_id = std::thread::current().id();

if dg.depends_on(other_id, thread_id) {
Expand Down
Loading