Skip to content

Commit

Permalink
Tests: Add rust-side python testing for rust-native circuit creation
Browse files Browse the repository at this point in the history
- Temporarily fix incorrect null assignment of `Bit` object to indices. A different solution can be achieved if more time is spent.
- Incorrect python circuit creation in Python due to misusage of the bit's `BitInfo` property inside of `BitData`.
- Fix incorrect register and bit creation from `CircuitData::new()`
- Add provisional `qubits_mut()` and `clbits_mut()` methods while we address the issue with `indices`.
  • Loading branch information
raynelfss committed Feb 1, 2025
1 parent f1ffc16 commit ba8bcf8
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 28 deletions.
61 changes: 45 additions & 16 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,36 @@ where
{
/// Finds the native bit index of the given Python bit.
#[inline]
pub fn py_find_bit(&self, bit: &Bound<PyAny>) -> Option<T> {
self.indices.get(&BitAsKey::new(bit)).copied()
pub fn py_find_bit(&mut self, bit: &Bound<PyAny>) -> PyResult<Option<T>> {
let key = &BitAsKey::new(bit);
if let Some(value) = self.indices.get(key).copied() {
Ok(Some(value))
} else if self.indices.len() != self.len() {
let py = bit.py();
// TODO: Make sure all bits have been mapped during addition.
// The only case in which bits may not be mapped is when a circuit
// is created entirely from Rust. For which case the Python representations
// of the bits still don't exist.
// Ideally, we'd want to initialize the mapping with a value representing the
// future bit, but it is hard to come up with a way to represent an object
// that doesn't exist yet.
// So for now, perform a check if the length of the mapped indices differs
// from the number of bits available.
for index in 0..self.len() {
let index = T::from(index.try_into().map_err(|_| {
CircuitError::new_err(format!(
"This circuit's {} has exceeded its length limit",
self.description
))
})?);
let bit = self.py_get_bit(py, index)?.unwrap();
let bit_as_key = BitAsKey::new(bit.bind(py));
self.indices.entry(bit_as_key).or_insert(index);
}
Ok(self.indices.get(key).copied())
} else {
Ok(None)
}
}

/// Gets a reference to the cached Python list, with the bits maintained by
Expand Down Expand Up @@ -530,11 +558,11 @@ where

/// Gets the location of a bit within the circuit
pub fn py_get_bit_location(
&self,
&mut self,
bit: &Bound<PyAny>,
) -> PyResult<(u32, Vec<(&PyObject, u32)>)> {
let py = bit.py();
let index = self.py_find_bit(bit).ok_or(PyKeyError::new_err(format!(
let index = self.py_find_bit(bit)?.ok_or(PyKeyError::new_err(format!(
"The provided {} is not part of this circuit",
self.description
)))?;
Expand Down Expand Up @@ -649,12 +677,19 @@ where
RegisterAsKey::Quantum(_) => QUANTUM_REGISTER.get_bound(py),
RegisterAsKey::Classical(_) => CLASSICAL_REGISTER.get_bound(py),
};
// Check if any indices have been initialized, if such is the case
// Treat the rest of indices as new `Bits``
if register
.bits()
.any(|bit| self.bits[BitType::from(bit) as usize].get().is_some())
{
// Check if all indices have been initialized from this register, if such is the case
// Treat the rest of indices as old `Bits``
if register.bits().all(|bit| {
self.bit_info[BitType::from(bit) as usize]
.orig_register_index()
.is_some_and(|idx| idx.register_index() == index)
}) {
let reg = reg_type.call1((register.len(), register.name()))?;
self.registers[index_as_usize]
.set(reg.into())
.map_err(|_| PyRuntimeError::new_err("Could not set the OnceCell correctly"))?;
Ok(self.registers[index_as_usize].get())
} else {
let bits: Vec<PyObject> = register
.bits()
.map(|bit| -> PyResult<PyObject> {
Expand All @@ -677,12 +712,6 @@ where
.set(reg.into())
.map_err(|_| PyRuntimeError::new_err("Could not set the OnceCell correctly"))?;
Ok(self.registers[index_as_usize].get())
} else {
let reg = reg_type.call1((register.len(), register.name()))?;
self.registers[index_as_usize]
.set(reg.into())
.map_err(|_| PyRuntimeError::new_err("Could not set the OnceCell correctly"))?;
Ok(self.registers[index_as_usize].get())
}
} else {
Ok(self.registers[index_as_usize].get())
Expand Down
157 changes: 148 additions & 9 deletions crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl CircuitData {
/// Gets the location of the bit inside of the circuit
#[pyo3(name = "get_qubit_location")]
pub fn py_get_qubit_location(
&self,
&mut self,
bit: &Bound<PyAny>,
) -> PyResult<(u32, Vec<(&PyObject, u32)>)> {
self.qubits.py_get_bit_location(bit)
Expand Down Expand Up @@ -266,7 +266,7 @@ impl CircuitData {
/// Gets the location of the bit inside of the circuit
#[pyo3(name = "get_clbit_location")]
pub fn py_get_clbit_location(
&self,
&mut self,
bit: &Bound<PyAny>,
) -> PyResult<(u32, Vec<(&PyObject, u32)>)> {
self.clbits.py_get_bit_location(bit)
Expand Down Expand Up @@ -799,7 +799,7 @@ impl CircuitData {
.map(|b| {
Ok(self
.qubits
.py_find_bit(other.qubits.py_get_bit(py, *b)?.unwrap().bind(py))
.py_find_bit(other.qubits.py_get_bit(py, *b)?.unwrap().bind(py))?
.unwrap())
})
.collect::<PyResult<Vec<Qubit>>>()?;
Expand All @@ -810,7 +810,7 @@ impl CircuitData {
.map(|b| {
Ok(self
.clbits
.py_find_bit(other.clbits.py_get_bit(py, *b)?.unwrap().bind(py))
.py_find_bit(other.clbits.py_get_bit(py, *b)?.unwrap().bind(py))?
.unwrap())
})
.collect::<PyResult<Vec<Clbit>>>()?;
Expand Down Expand Up @@ -1078,20 +1078,34 @@ impl CircuitData {
};
// Add all the bits into a register
if add_qreg {
let indices: Vec<Qubit> = (0..num_qubits).map(|_| data.add_qubit()).collect();
data.add_qreg(Some("q".to_string()), None, Some(&indices));
data.add_qreg(
Some("q".to_string()),
Some(
num_qubits
.try_into()
.expect("The number of qubits provided exceeds the limit for a circuit."),
),
None,
);
} else {
(0..num_qubits).for_each(|_| {
data.add_qubit();
});
}
// Add all the bits into a register
if add_creg {
let indices: Vec<Clbit> = (0..num_clbits).map(|_| data.add_clbit()).collect();
data.add_creg(Some("c".to_string()), None, Some(&indices));
data.add_creg(
Some("c".to_string()),
Some(
num_clbits
.try_into()
.expect("The number of qubits provided exceeds the limit for a circuit."),
),
None,
);
} else {
(0..num_clbits).for_each(|_| {
data.add_qubit();
data.add_clbit();
});
}
data
Expand Down Expand Up @@ -1554,6 +1568,18 @@ impl CircuitData {
&self.clbits
}

// TODO: Remove
/// Returns a mutable view of the Qubits registered in the circuit
pub fn qubits_mut(&mut self) -> &mut NewBitData<Qubit, QuantumRegister> {
&mut self.qubits
}

// TODO: Remove
/// Returns a mutable view of the Classical bits registered in the circuit
pub fn clbits_mut(&mut self) -> &mut NewBitData<Clbit, ClassicalRegister> {
&mut self.clbits
}

/// Unpacks from interned value to `[Qubit]`
pub fn get_qargs(&self, index: Interned<[Qubit]>) -> &[Qubit] {
self.qargs_interner().get(index)
Expand Down Expand Up @@ -1869,3 +1895,116 @@ mod test {
assert_eq!(cregs, expected_cregs)
}
}

#[cfg(all(test, not(miri)))]
// #[cfg(all(test))]
mod pytest {
use pyo3::PyTypeInfo;

use super::*;

// Test Rust native circuit construction when accessed through Python, without
// adding resgisters to the circuit.
#[test]
fn test_circuit_construction_py_no_regs() {
let num_qubits = 4;
let num_clbits = 3;
let circuit_data =
CircuitData::new(num_qubits, num_clbits, Param::Float(0.0), false, false);
let result = Python::with_gil(|py| -> PyResult<bool> {
let quantum_circuit = QUANTUM_CIRCUIT.get_bound(py).clone();

let converted_circuit =
quantum_circuit.call_method1("_from_circuit_data", (circuit_data,))?;

let converted_qregs = converted_circuit.getattr("qregs")?;
println!("{}", converted_qregs);
assert!(converted_qregs.is_instance(&PyList::type_object(py))?);
assert!(
converted_qregs.downcast::<PyList>()?.len() == 0,
"The quantum registers list returned a non-empty value"
);

let converted_qubits = converted_circuit.getattr("qubits")?;
println!("{:?}", converted_qubits);
assert!(converted_qubits.is_instance(&PyList::type_object(py))?);
assert!(
converted_qubits.downcast::<PyList>()?.len() == (num_qubits as usize),
"The qubits has the wrong length"
);

let converted_qregs = converted_circuit.getattr("qregs")?;
println!("{}", converted_qregs);
assert!(converted_qregs.is_instance(&PyList::type_object(py))?);
assert!(
converted_qregs.downcast::<PyList>()?.len() == 0,
"The classical registers list returned a non-empty value"
);

let converted_clbits = converted_circuit.getattr("clbits")?;
println!("{:?}", converted_clbits);
assert!(converted_clbits.is_instance(&PyList::type_object(py))?);
assert!(
converted_clbits.downcast::<PyList>()?.len() == (num_clbits as usize),
"The clbits has the wrong length"
);

Ok(true)
})
.is_ok_and(|res| res);
assert!(result);
}

#[test]
fn test_circuit_construction() {
let num_qubits = 4;
let num_clbits = 3;
let circuit_data = CircuitData::new(num_qubits, num_clbits, Param::Float(0.0), true, true);
let result = Python::with_gil(|py| -> PyResult<bool> {
let quantum_circuit = QUANTUM_CIRCUIT.get_bound(py).clone();

let converted_circuit =
quantum_circuit.call_method1("_from_circuit_data", (circuit_data,))?;
let expected_circuit = quantum_circuit.call1((num_qubits, num_clbits))?;

let converted_qregs = converted_circuit.getattr("qregs")?;
let expected_qregs = expected_circuit.getattr("qregs")?;

println!("{:?} vs {:?}", converted_qregs, expected_qregs);

assert!(converted_qregs.eq(expected_qregs)?);

let converted_cregs = converted_circuit.getattr("cregs")?;
let expected_cregs = expected_circuit.getattr("cregs")?;

println!("{:?} vs {:?}", converted_cregs, expected_cregs);

assert!(converted_cregs.eq(expected_cregs)?);

let converted_qubits = converted_circuit.getattr("qubits")?;
let expected_qubits = expected_circuit.getattr("qubits")?;
println!("{:?} vs {:?}", converted_qubits, expected_qubits);
assert!(converted_qubits.eq(&expected_qubits)?);

let converted_clbits = converted_circuit.getattr("clbits")?;
let expected_clbits = expected_circuit.getattr("clbits")?;
println!("{:?} vs {:?}", converted_clbits, expected_clbits);
assert!(converted_clbits.eq(&expected_clbits)?);

let converted_global_phase = converted_circuit.getattr("global_phase")?;
let expected_global_phase = expected_circuit.getattr("global_phase")?;
println!(
"{:?} vs {:?}",
converted_global_phase, expected_global_phase
);
assert!(converted_global_phase.eq(&expected_global_phase)?);

// TODO: Figure out why this fails
// converted_circuit.eq(expected_circuit)

Ok(true)
})
.is_ok_and(|res| res);
assert!(result);
}
}
6 changes: 3 additions & 3 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6624,7 +6624,7 @@ impl DAGCircuit {
clbit_order: Option<Vec<Bound<PyAny>>>,
) -> PyResult<DAGCircuit> {
// Extract necessary attributes
let qc_data = qc.data;
let mut qc_data = qc.data;
let num_qubits = qc_data.num_qubits();
let num_clbits = qc_data.num_clbits();
let num_ops = qc_data.__len__();
Expand Down Expand Up @@ -6668,7 +6668,7 @@ impl DAGCircuit {
&qubit
)));
}
let qubit_index = qc_data.qubits().py_find_bit(&qubit).unwrap();
let qubit_index = qc_data.qubits_mut().py_find_bit(&qubit)?.unwrap();
ordered_vec[qubit_index.index()] = new_dag.add_qubit_unchecked(py, &qubit)?;
Ok(())
})?;
Expand Down Expand Up @@ -6701,7 +6701,7 @@ impl DAGCircuit {
&clbit
)));
};
let clbit_index = qc_data.clbits().py_find_bit(&clbit).unwrap();
let clbit_index = qc_data.clbits_mut().py_find_bit(&clbit)?.unwrap();
ordered_vec[clbit_index.index()] = new_dag.add_clbit_unchecked(py, &clbit)?;
Ok(())
})?;
Expand Down

0 comments on commit ba8bcf8

Please sign in to comment.