Skip to content

Commit

Permalink
Fix: Use RwLock for Bit indices mappings.
Browse files Browse the repository at this point in the history
- Re-purpose `NewBitData.indices` to be modifiable even when if a mutable references is not granted by using `RWLock`. This is done to stay consistent with the behavior of `OnceLock` which allows us to initialize bits upon request. We need to make sure to map a bit once it has been initialized. Otherwise, the circuit will have to regenerate this mapping multiple times during runtime.
- Optimize pre-caching of Python bits and registers by using the cache size. Bits and registers are guaranteed to be initialized in ther order they were added/last accessed, so in the case a cache is initialized prematurely, add the missing bits or registers to the back of the cache while following the same index order.
  • Loading branch information
raynelfss committed Feb 3, 2025
1 parent 7690065 commit c0f3267
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 76 deletions.
152 changes: 82 additions & 70 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use std::sync::{OnceLock, RwLock};

/// Private wrapper for Python-side Bit instances that implements
/// [Hash] and [Eq], allowing them to be used in Rust hash-based
Expand Down Expand Up @@ -236,14 +236,15 @@ where
}
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct NewBitData<T: From<BitType>, R: Register + Hash + Eq> {
/// The public field name (i.e. `qubits` or `clbits`).
description: String,
/// Registered Python bits.
bits: Vec<OnceLock<PyObject>>,
/// Maps Python bits to native type.
indices: HashMap<BitAsKey, T>,
/// Maps Python bits to native type, should be modifiable upon
/// retrieval.
indices: RwLock<HashMap<BitAsKey, T>>,
/// Maps Register keys to indices
reg_keys: HashMap<RegisterAsKey, u32>,
/// Mapping between bit index and its register info
Expand Down Expand Up @@ -273,7 +274,7 @@ where
NewBitData {
description,
bits: Vec::new(),
indices: HashMap::new(),
indices: HashMap::new().into(),
bit_info: Vec::new(),
registry: Vec::new(),
registers: Vec::new(),
Expand All @@ -287,7 +288,7 @@ where
NewBitData {
description,
bits: Vec::with_capacity(bit_capacity),
indices: HashMap::with_capacity(bit_capacity),
indices: HashMap::with_capacity(bit_capacity).into(),
bit_info: Vec::with_capacity(bit_capacity),
registry: Vec::with_capacity(reg_capacity),
registers: Vec::with_capacity(reg_capacity),
Expand Down Expand Up @@ -454,35 +455,15 @@ where
/// Finds the native bit index of the given Python bit.
#[inline]
pub fn py_find_bit(&mut self, bit: &Bound<PyAny>) -> PyResult<Option<T>> {
let key = &BitAsKey::new(bit);
if let Some(value) = self.indices.get(key).copied() {
Ok(Some(value))
} else if self.indices.len() < self.len() {
let py = bit.py();
// TODO: Make sure all bits have been mapped during addition.
// The only case in which bits may not be mapped is when a circuit
// is created entirely from Rust. For which case the Python representations
// of the bits still don't exist.
// Ideally, we'd want to initialize the mapping with a value representing the
// future bit, but it is hard to come up with a way to represent an object
// that doesn't exist yet.
// So for now, perform a check if the length of the mapped indices differs
// from the number of bits available.
for index in 0..self.len() {
let index = T::from(index.try_into().map_err(|_| {
CircuitError::new_err(format!(
"This circuit's {} has exceeded its length limit",
self.description
))
})?);
let bit = self.py_get_bit(py, index)?.unwrap();
let bit_as_key = BitAsKey::new(bit.bind(py));
self.indices.entry(bit_as_key).or_insert(index);
}
Ok(self.indices.get(key).copied())
} else {
Ok(None)
}
self.indices
.try_read()
.map(|op| op.get(&BitAsKey::new(bit)).copied())
.map_err(|_| {
PyRuntimeError::new_err(format!(
"Could not map {}. Error accessing index mapping.",
&bit
))
})
}

/// Gets a reference to the cached Python list, with the bits maintained by
Expand All @@ -498,21 +479,16 @@ where
.into()
});

// If the length is different from the stored bits, rebuild cache
// If the length is different from the stored bits, append to cache
// Indices are guaranteed to follow
let res_as_bound = res.bind(py);
if res_as_bound.len() < self.len() {
let current_length = res_as_bound.len();
for index in 0..self.len() {
if index < current_length {
res_as_bound.set_item(index, self.py_get_bit(py, (index as u32).into())?)?
} else {
res_as_bound.append(self.py_get_bit(py, (index as u32).into())?)?
}
for index in current_length.checked_sub(1).unwrap_or_default()..self.len() {
res_as_bound.append(self.py_get_bit(py, (index as u32).into())?)?
}
Ok(self.cached_py_bits.get().unwrap())
} else {
Ok(res)
}
Ok(res)
}

/// Gets a reference to the cached Python list, with the registers maintained by
Expand All @@ -532,24 +508,11 @@ where
let res_as_bound = res.bind(py);
if res_as_bound.len() < self.len_regs() {
let current_length = res_as_bound.len();
for index in 0..self.len_regs() {
if index < current_length {
res_as_bound.set_item(index, self.py_get_register(py, index as u32)?)?
} else {
res_as_bound.append(self.py_get_register(py, index as u32)?)?
}
for index in (current_length - 1)..self.len_regs() {
res_as_bound.append(self.py_get_register(py, index as u32)?)?
}
let trimmed = res_as_bound.get_slice(0, self.len_regs()).unbind();
self.cached_py_regs.set(trimmed).map_err(|_| {
PyRuntimeError::new_err(format!(
"Tried to initialized {} register cache while another thread was initializing",
self.description
))
})?;
Ok(self.cached_py_regs.get().unwrap())
} else {
Ok(res)
}
Ok(res)
}

/// Gets a reference to the underlying vector of Python bits.
Expand Down Expand Up @@ -644,6 +607,10 @@ where
// A register index is guaranteed to exist in the instance of `BitData`.
let py_reg = self.py_get_register(py, bit_info.register_index())?;
let res = py_reg.unwrap().bind(py).get_item(bit_info.index())?;
self.indices
.try_write()
.map(|mut indices| indices.insert(BitAsKey::new(&res), index))
.map_err(|err| PyRuntimeError::new_err(format!("{:?}", err)))?;
self.bits[index_as_usize]
.set(res.into())
.map_err(|_| PyRuntimeError::new_err("Could not set the OnceCell correctly"))?;
Expand All @@ -653,8 +620,13 @@ where
} else if let Some(bit) = self.bits[index_as_usize].get() {
Ok(Some(bit))
} else {
let new_bit = T::to_py_bit(py)?;
self.indices
.try_write()
.map(|mut indices| indices.insert(BitAsKey::new(new_bit.bind(py)), index))
.map_err(|err| PyRuntimeError::new_err(format!("{:?}", err)))?;
self.bits[index_as_usize]
.set(T::to_py_bit(py)?)
.set(new_bit)
.map_err(|_| PyRuntimeError::new_err("Could not set the OnceCell correctly"))?;
Ok(self.bits[index_as_usize].get())
}
Expand Down Expand Up @@ -737,8 +709,9 @@ where
})?;
if self
.indices
.try_insert(BitAsKey::new(bit), idx.into())
.is_ok()
.try_write()
.map(|mut res| res.try_insert(BitAsKey::new(bit), idx.into()).is_ok())
.is_ok_and(|res| res)
{
// Append to cache before bits to avoid rebuilding cache.
self.py_cached_bits(py)?.bind(py).append(bit)?;
Expand Down Expand Up @@ -850,17 +823,26 @@ where
indices_sorted.sort();

for index in indices_sorted.into_iter().rev() {
self.py_cached_bits(py)?.bind(py).del_item(index)?;
self.cached_py_bits.take();
let bit = self.py_get_bit(py, (index as BitType).into())?.unwrap();
self.indices.remove(&BitAsKey::new(bit.bind(py)));
self.indices
.try_write()
.map(|mut op| op.remove(&BitAsKey::new(bit.bind(py))))
.map_err(|_| {
PyRuntimeError::new_err("Could not remove bit from cache".to_string())
})?;
self.bits.remove(index);
self.bit_info.remove(index);
}
// Update indices.
for i in 0..self.bits.len() {
let bit = self.py_get_bit(py, (i as BitType).into())?.unwrap();
self.indices
.insert(BitAsKey::new(bit.bind(py)), (i as BitType).into());
.try_write()
.map(|mut op| op.insert(BitAsKey::new(bit.bind(py)), (i as BitType).into()))
.map_err(|_| {
PyRuntimeError::new_err("Could not re-map bit in cache".to_string())
})?;
}
Ok(())
}
Expand All @@ -883,12 +865,16 @@ where

/// Called during Python garbage collection, only!.
/// Note: INVALIDATES THIS INSTANCE.
pub fn dispose(&mut self) {
self.indices.clear();
pub fn dispose(&mut self) -> PyResult<()> {
self.indices
.try_write()
.map(|mut op| op.clear())
.map_err(|err| PyRuntimeError::new_err(format!("{:?}", err)))?;
self.bits.clear();
self.registers.clear();
self.bit_info.clear();
self.registry.clear();
Ok(())
}

/// To convert [BitData] into [NewBitData]. If the structure the original comes from contains register
Expand All @@ -901,7 +887,7 @@ where
.iter()
.map(|bit| bit.clone_ref(py).into())
.collect(),
indices: bit_data.indices.clone(),
indices: bit_data.indices.clone().into(),
reg_keys: HashMap::new(),
bit_info: (0..bit_data.len()).map(|_| BitInfo::new(None)).collect(),
registry: Vec::new(),
Expand All @@ -911,3 +897,29 @@ where
}
}
}

// Custom implementation of Clone due to RWLock usage.
impl<T: Clone, R: Clone> Clone for NewBitData<T, R>
where
T: From<u32> + Copy,
R: Register + Hash + Eq,
{
fn clone(&self) -> Self {
Self {
description: self.description.clone(),
bits: self.bits.clone(),
indices: self
.indices
.try_read()
.map(|indices| indices.clone())
.unwrap_or_default()
.into(),
reg_keys: self.reg_keys.clone(),
bit_info: self.bit_info.clone(),
registry: self.registry.clone(),
registers: self.registers.clone(),
cached_py_bits: self.cached_py_bits.clone(),
cached_py_regs: self.cached_py_regs.clone(),
}
}
}
14 changes: 8 additions & 6 deletions crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -987,12 +987,13 @@ impl CircuitData {
Ok(())
}

fn __clear__(&mut self) {
fn __clear__(&mut self) -> PyResult<()> {
// Clear anything that could have a reference cycle.
self.data.clear();
self.qubits.dispose();
self.clbits.dispose();
self.qubits.dispose()?;
self.clbits.dispose()?;
self.param_table.clear();
Ok(())
}

/// Set the global phase of the circuit.
Expand Down Expand Up @@ -1102,7 +1103,7 @@ impl CircuitData {
Some(
num_clbits
.try_into()
.expect("The number of qubits provided exceeds the limit for a circuit."),
.expect("The number of clbits provided exceeds the limit for a circuit."),
),
None,
);
Expand Down Expand Up @@ -1920,7 +1921,7 @@ mod pytest {
use super::*;

// Test Rust native circuit construction when accessed through Python, without
// adding resgisters to the circuit.
// adding registers to the circuit.
#[test]
fn test_circuit_construction_py_no_regs() {
let num_qubits = 4;
Expand Down Expand Up @@ -1971,6 +1972,7 @@ mod pytest {
assert!(result);
}

// Test Rust native circuit construction when accessed through Python.
#[test]
fn test_circuit_construction() {
let num_qubits = 4;
Expand Down Expand Up @@ -2016,7 +2018,7 @@ mod pytest {
assert!(converted_global_phase.eq(&expected_global_phase)?);

// TODO: Figure out why this fails
// converted_circuit.eq(expected_circuit)
// converted_circuit.eq(&expected_circuit)

Ok(true)
})
Expand Down

0 comments on commit c0f3267

Please sign in to comment.