diff --git a/crates/circuit/src/bit_data.rs b/crates/circuit/src/bit_data.rs
index 62e4d2b35f85..30d888dfbaf0 100644
--- a/crates/circuit/src/bit_data.rs
+++ b/crates/circuit/src/bit_data.rs
@@ -21,7 +21,7 @@ use pyo3::prelude::*;
 use pyo3::types::{PyDict, PyList};
 use std::fmt::Debug;
 use std::hash::{Hash, Hasher};
-use std::sync::OnceLock;
+use std::sync::{OnceLock, RwLock};
 
 /// Private wrapper for Python-side Bit instances that implements
 /// [Hash] and [Eq], allowing them to be used in Rust hash-based
@@ -236,14 +236,15 @@ where
     }
 }
 
-#[derive(Clone, Debug)]
+#[derive(Debug)]
 pub struct NewBitData<T: From<BitType>, R: Register + Hash + Eq> {
     /// The public field name (i.e. `qubits` or `clbits`).
     description: String,
     /// Registered Python bits.
     bits: Vec<OnceLock<PyObject>>,
-    /// Maps Python bits to native type.
-    indices: HashMap<BitAsKey, T>,
+    /// Maps Python bits to native type, should be modifiable upon
+    /// retrieval.
+    indices: RwLock<HashMap<BitAsKey, T>>,
     /// Maps Register keys to indices
     reg_keys: HashMap<R::Key, u32>,
     /// Mapping between bit index and its register info
@@ -273,7 +274,7 @@ where
         NewBitData {
             description,
             bits: Vec::new(),
-            indices: HashMap::new(),
+            indices: HashMap::new().into(),
             bit_info: Vec::new(),
             registry: Vec::new(),
             registers: Vec::new(),
@@ -287,7 +288,7 @@ where
         NewBitData {
             description,
             bits: Vec::with_capacity(bit_capacity),
-            indices: HashMap::with_capacity(bit_capacity),
+            indices: HashMap::with_capacity(bit_capacity).into(),
             bit_info: Vec::with_capacity(bit_capacity),
             registry: Vec::with_capacity(reg_capacity),
             registers: Vec::with_capacity(reg_capacity),
@@ -454,35 +455,15 @@ where
     /// Finds the native bit index of the given Python bit.
     #[inline]
     pub fn py_find_bit(&mut self, bit: &Bound<PyAny>) -> PyResult<Option<T>> {
-        let key = &BitAsKey::new(bit);
-        if let Some(value) = self.indices.get(key).copied() {
-            Ok(Some(value))
-        } else if self.indices.len() < self.len() {
-            let py = bit.py();
-            // TODO: Make sure all bits have been mapped during addition.
-            // The only case in which bits may not be mapped is when a circuit
-            // is created entirely from Rust. For which case the Python representations
-            // of the bits still don't exist.
-            // Ideally, we'd want to initialize the mapping with a value representing the
-            // future bit, but it is hard to come up with a way to represent an object
-            // that doesn't exist yet.
-            // So for now, perform a check if the length of the mapped indices differs
-            // from the number of bits available.
-            for index in 0..self.len() {
-                let index = T::from(index.try_into().map_err(|_| {
-                    CircuitError::new_err(format!(
-                        "This circuit's {} has exceeded its length limit",
-                        self.description
-                    ))
-                })?);
-                let bit = self.py_get_bit(py, index)?.unwrap();
-                let bit_as_key = BitAsKey::new(bit.bind(py));
-                self.indices.entry(bit_as_key).or_insert(index);
-            }
-            Ok(self.indices.get(key).copied())
-        } else {
-            Ok(None)
-        }
+        self.indices
+            .try_read()
+            .map(|op| op.get(&BitAsKey::new(bit)).copied())
+            .map_err(|_| {
+                PyRuntimeError::new_err(format!(
+                    "Could not map {}. Error accessing index mapping.",
+                    &bit
+                ))
+            })
     }
 
     /// Gets a reference to the cached Python list, with the bits maintained by
@@ -498,21 +479,16 @@ where
                 .into()
         });
 
-        // If the length is different from the stored bits, rebuild cache
+        // If the length is different from the stored bits, append to the cache.
+        // Indices are guaranteed to follow the list positions, so only the tail is missing.
         let res_as_bound = res.bind(py);
         if res_as_bound.len() < self.len() {
             let current_length = res_as_bound.len();
-            for index in 0..self.len() {
-                if index < current_length {
-                    res_as_bound.set_item(index, self.py_get_bit(py, (index as u32).into())?)?
-                } else {
-                    res_as_bound.append(self.py_get_bit(py, (index as u32).into())?)?
-                }
+            for index in current_length..self.len() {
+                res_as_bound.append(self.py_get_bit(py, (index as u32).into())?)?
             }
-            Ok(self.cached_py_bits.get().unwrap())
-        } else {
-            Ok(res)
         }
+        Ok(res)
     }
 
     /// Gets a reference to the cached Python list, with the registers maintained by
@@ -532,24 +508,11 @@ where
         let res_as_bound = res.bind(py);
         if res_as_bound.len() < self.len_regs() {
             let current_length = res_as_bound.len();
-            for index in 0..self.len_regs() {
-                if index < current_length {
-                    res_as_bound.set_item(index, self.py_get_register(py, index as u32)?)?
-                } else {
-                    res_as_bound.append(self.py_get_register(py, index as u32)?)?
-                }
+            for index in current_length..self.len_regs() {
+                res_as_bound.append(self.py_get_register(py, index as u32)?)?
             }
-            let trimmed = res_as_bound.get_slice(0, self.len_regs()).unbind();
-            self.cached_py_regs.set(trimmed).map_err(|_| {
-                PyRuntimeError::new_err(format!(
-                    "Tried to initialized {} register cache while another thread was initializing",
-                    self.description
-                ))
-            })?;
-            Ok(self.cached_py_regs.get().unwrap())
-        } else {
-            Ok(res)
         }
+        Ok(res)
     }
 
     /// Gets a reference to the underlying vector of Python bits.
@@ -644,6 +607,10 @@ where
             // A register index is guaranteed to exist in the instance of `BitData`.
             let py_reg = self.py_get_register(py, bit_info.register_index())?;
             let res = py_reg.unwrap().bind(py).get_item(bit_info.index())?;
+            self.indices
+                .try_write()
+                .map(|mut indices| indices.insert(BitAsKey::new(&res), index))
+                .map_err(|err| PyRuntimeError::new_err(format!("{:?}", err)))?;
             self.bits[index_as_usize]
                 .set(res.into())
                 .map_err(|_| PyRuntimeError::new_err("Could not set the OnceCell correctly"))?;
@@ -653,8 +620,13 @@ where
         } else if let Some(bit) = self.bits[index_as_usize].get() {
             Ok(Some(bit))
         } else {
+            let new_bit = T::to_py_bit(py)?;
+            self.indices
+                .try_write()
+                .map(|mut indices| indices.insert(BitAsKey::new(new_bit.bind(py)), index))
+                .map_err(|err| PyRuntimeError::new_err(format!("{:?}", err)))?;
             self.bits[index_as_usize]
-                .set(T::to_py_bit(py)?)
+                .set(new_bit)
                 .map_err(|_| PyRuntimeError::new_err("Could not set the OnceCell correctly"))?;
             Ok(self.bits[index_as_usize].get())
         }
@@ -737,8 +709,9 @@ where
         })?;
         if self
             .indices
-            .try_insert(BitAsKey::new(bit), idx.into())
-            .is_ok()
+            .try_write()
+            .map(|mut res| res.try_insert(BitAsKey::new(bit), idx.into()).is_ok())
+            .is_ok_and(|res| res)
         {
             // Append to cache before bits to avoid rebuilding cache.
             self.py_cached_bits(py)?.bind(py).append(bit)?;
@@ -850,9 +823,14 @@ where
         indices_sorted.sort();
 
         for index in indices_sorted.into_iter().rev() {
-            self.py_cached_bits(py)?.bind(py).del_item(index)?;
+            self.cached_py_bits.take();
             let bit = self.py_get_bit(py, (index as BitType).into())?.unwrap();
-            self.indices.remove(&BitAsKey::new(bit.bind(py)));
+            self.indices
+                .try_write()
+                .map(|mut op| op.remove(&BitAsKey::new(bit.bind(py))))
+                .map_err(|_| {
+                    PyRuntimeError::new_err("Could not remove bit from cache".to_string())
+                })?;
             self.bits.remove(index);
             self.bit_info.remove(index);
         }
@@ -860,7 +838,11 @@ where
         for i in 0..self.bits.len() {
             let bit = self.py_get_bit(py, (i as BitType).into())?.unwrap();
             self.indices
-                .insert(BitAsKey::new(bit.bind(py)), (i as BitType).into());
+                .try_write()
+                .map(|mut op| op.insert(BitAsKey::new(bit.bind(py)), (i as BitType).into()))
+                .map_err(|_| {
+                    PyRuntimeError::new_err("Could not re-map bit in cache".to_string())
+                })?;
         }
         Ok(())
     }
@@ -883,12 +865,16 @@ where
 
     /// Called during Python garbage collection, only!.
     /// Note: INVALIDATES THIS INSTANCE.
-    pub fn dispose(&mut self) {
-        self.indices.clear();
+    pub fn dispose(&mut self) -> PyResult<()> {
+        self.indices
+            .try_write()
+            .map(|mut op| op.clear())
+            .map_err(|err| PyRuntimeError::new_err(format!("{:?}", err)))?;
         self.bits.clear();
         self.registers.clear();
         self.bit_info.clear();
         self.registry.clear();
+        Ok(())
     }
 
     /// To convert [BitData] into [NewBitData]. If the structure the original comes from contains register
@@ -901,7 +887,7 @@ where
                 .iter()
                 .map(|bit| bit.clone_ref(py).into())
                 .collect(),
-            indices: bit_data.indices.clone(),
+            indices: bit_data.indices.clone().into(),
             reg_keys: HashMap::new(),
             bit_info: (0..bit_data.len()).map(|_| BitInfo::new(None)).collect(),
             registry: Vec::new(),
@@ -911,3 +897,29 @@ where
         }
     }
 }
+
+// Custom implementation of Clone due to RwLock usage.
+impl<T, R> Clone for NewBitData<T, R>
+where
+    T: From<BitType> + Copy,
+    R: Register + Hash + Eq,
+{
+    fn clone(&self) -> Self {
+        Self {
+            description: self.description.clone(),
+            bits: self.bits.clone(),
+            indices: self
+                .indices
+                .try_read()
+                .map(|indices| indices.clone())
+                .unwrap_or_default()
+                .into(),
+            reg_keys: self.reg_keys.clone(),
+            bit_info: self.bit_info.clone(),
+            registry: self.registry.clone(),
+            registers: self.registers.clone(),
+            cached_py_bits: self.cached_py_bits.clone(),
+            cached_py_regs: self.cached_py_regs.clone(),
+        }
+    }
+}
diff --git a/crates/circuit/src/circuit_data.rs b/crates/circuit/src/circuit_data.rs
index f8fe8a34efb4..c47b96785976 100644
--- a/crates/circuit/src/circuit_data.rs
+++ b/crates/circuit/src/circuit_data.rs
@@ -987,12 +987,13 @@ impl CircuitData {
         Ok(())
     }
 
-    fn __clear__(&mut self) {
+    fn __clear__(&mut self) -> PyResult<()> {
         // Clear anything that could have a reference cycle.
         self.data.clear();
-        self.qubits.dispose();
-        self.clbits.dispose();
+        self.qubits.dispose()?;
+        self.clbits.dispose()?;
         self.param_table.clear();
+        Ok(())
     }
 
     /// Set the global phase of the circuit.
@@ -1102,7 +1103,7 @@ impl CircuitData {
             Some(
                 num_clbits
                     .try_into()
-                    .expect("The number of qubits provided exceeds the limit for a circuit."),
+                    .expect("The number of clbits provided exceeds the limit for a circuit."),
             ),
             None,
         );
@@ -1920,7 +1921,7 @@ mod pytest {
     use super::*;
 
     // Test Rust native circuit construction when accessed through Python, without
-    // adding resgisters to the circuit.
+    // adding registers to the circuit.
     #[test]
     fn test_circuit_construction_py_no_regs() {
         let num_qubits = 4;
@@ -1971,6 +1972,7 @@ mod pytest {
         assert!(result);
     }
 
+    // Test Rust native circuit construction when accessed through Python.
     #[test]
     fn test_circuit_construction() {
         let num_qubits = 4;
@@ -2016,7 +2018,7 @@ mod pytest {
         assert!(converted_global_phase.eq(&expected_global_phase)?);
 
         // TODO: Figure out why this fails
-        // converted_circuit.eq(expected_circuit)
+        // converted_circuit.eq(&expected_circuit)
         Ok(true)
     })
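
Note for reviewers: the diff above wraps the `indices` map in an `RwLock` and maps `try_read`/`try_write` failures into Python exceptions instead of blocking, which is also why `Clone` now has to be written by hand. Below is a minimal, self-contained sketch of that same idiom with illustrative names only (`IndexMap`, `find`, and `insert` are not part of the Qiskit crates), using plain `String` errors instead of `PyErr` so it compiles without dependencies.

```rust
use std::collections::HashMap;
use std::sync::RwLock;

/// Illustrative stand-in for a struct holding a lock-guarded index mapping.
#[derive(Debug, Default)]
struct IndexMap {
    indices: RwLock<HashMap<String, u32>>,
}

impl IndexMap {
    /// Non-blocking read: a poisoned or write-held lock surfaces as an error,
    /// mirroring the `try_read().map(..).map_err(..)` chain in the diff.
    fn find(&self, key: &str) -> Result<Option<u32>, String> {
        self.indices
            .try_read()
            .map(|map| map.get(key).copied())
            .map_err(|_| format!("could not read the index mapping for {key}"))
    }

    /// Non-blocking write follows the same shape with `try_write`.
    fn insert(&self, key: String, index: u32) -> Result<(), String> {
        self.indices
            .try_write()
            .map(|mut map| {
                map.insert(key, index);
            })
            .map_err(|err| format!("could not update the index mapping: {err:?}"))
    }
}

// `RwLock` is not `Clone`, so cloning snapshots the guarded map by hand,
// falling back to an empty map if the lock cannot be acquired.
impl Clone for IndexMap {
    fn clone(&self) -> Self {
        Self {
            indices: self
                .indices
                .try_read()
                .map(|map| map.clone())
                .unwrap_or_default()
                .into(),
        }
    }
}

fn main() -> Result<(), String> {
    let map = IndexMap::default();
    map.insert("q0".to_owned(), 0)?;
    assert_eq!(map.find("q0")?, Some(0));
    // The manual `Clone` deep-copies the current contents of the guarded map.
    assert_eq!(map.clone().find("q0")?, Some(0));
    Ok(())
}
```

One trade-off worth flagging: falling back to an empty map when the lock cannot be taken during `clone` (as the diff's `Clone` impl also does via `unwrap_or_default`) silently drops entries; this is presumably acceptable only because the Python-facing mapping can be rebuilt lazily.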