Skip to content

Commit

Permalink
Add: compose_transform logic in rust
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Sep 10, 2024
1 parent f9e97ea commit 3a3e734
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 206 deletions.
232 changes: 126 additions & 106 deletions crates/accelerate/src/basis/basis_translator/compose_transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@

use hashbrown::{HashMap, HashSet};
use once_cell::sync::Lazy;
use pyo3::types::PyTuple;
use pyo3::{exceptions::PyTypeError, prelude::*};
use qiskit_circuit::imports::{CIRCUIT_TO_DAG, QUANTUM_CIRCUIT};
use qiskit_circuit::imports::{
CIRCUIT_TO_DAG, GATE, PARAMETER_VECTOR, QUANTUM_CIRCUIT, QUANTUM_REGISTER,
};
use qiskit_circuit::operations::OperationRef;
use qiskit_circuit::parameter_table::ParameterUuid;
use qiskit_circuit::{
circuit_data::CircuitData,
dag_circuit::{DAGCircuit, NodeType},
Expand All @@ -37,39 +41,15 @@ static CONTROL_FLOW_OP_NAMES: Lazy<HashSet<&'static str>> = Lazy::new(|| {
/// Representation of QuantumCircuit which the original circuit object + an
/// instance of `CircuitData`.
#[derive(Debug, Clone)]
pub struct CircuitRep {
object: PyObject,
pub num_qubits: usize,
pub num_clbits: usize,
pub data: CircuitData,
}

impl CircuitRep {
/// Performs a shallow cloning of the structure by using `clone_ref()`.
pub fn py_clone(&self, py: Python) -> Self {
Self {
object: self.object.clone_ref(py),
num_qubits: self.num_qubits,
num_clbits: self.num_clbits,
data: self.data.clone(),
}
}
}
pub struct CircuitRep(pub CircuitData);

impl FromPyObject<'_> for CircuitRep {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if ob.is_instance(QUANTUM_CIRCUIT.get_bound(ob.py()))? {
let data: Bound<PyAny> = ob.getattr("_data")?;
let data_downcast: Bound<CircuitData> = data.downcast_into()?;
let data_extract: CircuitData = data_downcast.extract()?;
let num_qubits: usize = data_extract.num_qubits();
let num_clbits: usize = data_extract.num_clbits();
Ok(Self {
object: ob.into_py(ob.py()),
num_qubits,
num_clbits,
data: data_extract,
})
Ok(Self(data_extract))
} else {
Err(PyTypeError::new_err(
"Provided object was not an instance of QuantumCircuit",
Expand All @@ -79,97 +59,137 @@ impl FromPyObject<'_> for CircuitRep {
}

impl IntoPy<PyObject> for CircuitRep {
fn into_py(self, _py: Python<'_>) -> PyObject {
self.object
fn into_py(self, py: Python<'_>) -> PyObject {
QUANTUM_CIRCUIT
.get_bound(py)
.call_method1("_from_circuit_data", (self.0,))
.unwrap()
.unbind()
}
}

impl ToPyObject for CircuitRep {
fn to_object(&self, py: Python<'_>) -> PyObject {
self.object.clone_ref(py)
self.clone().into_py(py)
}
}

#[pyfunction(name = "compose_transforms")]
fn py_compose_transforms(
_py: Python,
_basis_transforms: BasisTransforms,
_source_basis: HashSet<(String, u32)>,
_source_dag: &DAGCircuit,
) -> HashMap<String, (SmallVec<[Param; 3]>, DAGCircuit)> {
todo!()
pub(super) fn py_compose_transforms(
py: Python,
basis_transforms: BasisTransforms,
source_basis: HashSet<(String, u32)>,
source_dag: &DAGCircuit,
) -> PyResult<HashMap<(String, u32), (SmallVec<[Param; 3]>, DAGCircuit)>> {
compose_transforms(py, &basis_transforms, &source_basis, source_dag).map(|ret| {
ret.into_iter()
.map(|((name, num_qubits), (param, equiv))| ((name, num_qubits), (param, equiv)))
.collect()
})
}

fn compose_transforms(
_basis_transforms: &BasisTransforms,
_source_basis: &HashSet<(String, u32)>,
_source_dag: DAGCircuit,
) -> HashMap<String, (SmallVec<[Param; 3]>, DAGCircuit)> {
/*
example_gates = _get_example_gates(source_dag)
mapped_instrs = {}
for gate_name, gate_num_qubits in source_basis:
# Need to grab a gate instance to find num_qubits and num_params.
# Can be removed following https://github.com/Qiskit/qiskit-terra/pull/3947 .
example_gate = example_gates[gate_name, gate_num_qubits]
num_params = len(example_gate.params)
placeholder_params = ParameterVector(gate_name, num_params)
placeholder_gate = Gate(gate_name, gate_num_qubits, list(placeholder_params))
placeholder_gate.params = list(placeholder_params)
dag = DAGCircuit()
qr = QuantumRegister(gate_num_qubits)
dag.add_qreg(qr)
dag.apply_operation_back(placeholder_gate, qr, (), check=False)
mapped_instrs[gate_name, gate_num_qubits] = placeholder_params, dag
for gate_name, gate_num_qubits, equiv_params, equiv in basis_transforms:
logger.debug(
"Composing transform step: %s/%s %s =>\n%s",
gate_name,
pub(super) fn compose_transforms<'a>(
py: Python,
basis_transforms: &'a BasisTransforms,
source_basis: &'a HashSet<(String, u32)>,
source_dag: &'a DAGCircuit,
) -> PyResult<HashMap<(String, u32), (SmallVec<[Param; 3]>, DAGCircuit)>> {
let example_gates = *get_example_gates(py, source_dag, None)?;
let mut mapped_instructions: HashMap<(String, u32), (SmallVec<[Param; 3]>, DAGCircuit)> =
HashMap::new();

for (gate_name, gate_num_qubits) in source_basis.iter().cloned() {
// Need to grab a gate instance to find num_qubits and num_params.
// Can be removed following https://github.com/Qiskit/qiskit-terra/pull/3947 .
let Some(NodeType::Operation(example_gate)) = source_dag
.dag
.node_weight(example_gates[&(gate_name.clone(), gate_num_qubits)])
else {
panic!(
"Nodeindex {:?} should be in the source_dag",
example_gates[&(gate_name, gate_num_qubits)]
)
};
let num_params = example_gate
.params
.as_ref()
.map(|x| x.len())
.unwrap_or_default();

let placeholder_params: SmallVec<[Param; 3]> = PARAMETER_VECTOR
.get_bound(py)
.call1((&gate_name, num_params))?
.extract()?;

let mut dag = DAGCircuit::new(py)?;
// Create the mock gate and add to the circuit, use Python for this.
let qubits = QUANTUM_REGISTER.get_bound(py).call1((gate_num_qubits,))?;
dag.add_qreg(py, &qubits)?;

let gate = GATE.get_bound(py).call1((
&gate_name,
gate_num_qubits,
equiv_params,
equiv,
)
for mapped_instr_name, (dag_params, dag) in mapped_instrs.items():
doomed_nodes = [
node
for node in dag.op_nodes()
if (node.name, node.num_qubits) == (gate_name, gate_num_qubits)
]
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Updating transform for mapped instr %s %s from \n%s",
mapped_instr_name,
dag_params,
dag_to_circuit(dag, copy_operations=False),
)
for node in doomed_nodes:
replacement = equiv.assign_parameters(dict(zip_longest(equiv_params, node.params)))
replacement_dag = circuit_to_dag(replacement)
dag.substitute_node_with_dag(node, replacement_dag)
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Updated transform for mapped instr %s %s to\n%s",
mapped_instr_name,
dag_params,
dag_to_circuit(dag, copy_operations=False),
)
return mapped_instrs
*/
todo!()
placeholder_params
.iter()
.map(|x| x.clone_ref(py))
.collect::<SmallVec<[Param; 3]>>(),
))?;

dag.py_apply_operation_back(
py,
gate,
Some(PyTuple::new_bound(py, 0..gate_num_qubits).extract()?),
None,
true,
)?;
mapped_instructions.insert(
(gate_name.clone(), gate_num_qubits),
(placeholder_params, dag),
);

for (_gate_name, _gate_num_qubitss, equiv_params, equiv) in basis_transforms {
for ((_mapped_instr_name, _), (_dag_params, dag)) in &mut mapped_instructions {
let doomed_nodes = dag
.op_nodes(true)
.filter_map(|node| {
if let Some(NodeType::Operation(op)) = dag.dag.node_weight(node) {
Some((
node,
op.params_view()
.iter()
.map(|x| x.clone_ref(py))
.collect::<SmallVec<[Param; 3]>>(),
))
} else {
None
}
})
.collect::<Vec<_>>();
for (node, params) in doomed_nodes {
let param_mapping: HashMap<ParameterUuid, Param> = equiv_params
.iter()
.map(|x| ParameterUuid::from_parameter(x.to_object(py).bind(py)))
.zip(params)
.map(|(uuid, param)| -> PyResult<(ParameterUuid, Param)> {
Ok((uuid?, param.clone_ref(py)))
})
.collect::<PyResult<_>>()?;
let mut replacement = equiv.clone();
replacement
.0
.assign_parameters_from_mapping(py, param_mapping)?;
let replace_dag: DAGCircuit = CIRCUIT_TO_DAG
.get_bound(py)
.call1((replacement,))?
.downcast_into::<DAGCircuit>()?
.extract()?;
let op_node = dag.get_node(py, node)?;
dag.substitute_node_with_dag(py, op_node.bind(py), &replace_dag, None, true)?;
}
}
}
}
Ok(mapped_instructions)
}

fn get_example_gates(
Expand Down
3 changes: 2 additions & 1 deletion crates/accelerate/src/basis/basis_translator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use pyo3::prelude::*;
mod compose_transforms;

#[pymodule]
pub fn basis_translator(_m: &Bound<PyModule>) -> PyResult<()> {
pub fn basis_translator(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(compose_transforms::py_compose_transforms))?;
Ok(())
}
6 changes: 3 additions & 3 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def _format(operand):
}

/// Add all wires in a quantum register.
fn add_qreg(&mut self, py: Python, qreg: &Bound<PyAny>) -> PyResult<()> {
pub fn add_qreg(&mut self, py: Python, qreg: &Bound<PyAny>) -> PyResult<()> {
if !qreg.is_instance(imports::QUANTUM_REGISTER.get_bound(py))? {
return Err(DAGCircuitError::new_err("not a QuantumRegister instance."));
}
Expand Down Expand Up @@ -1668,7 +1668,7 @@ def _format(operand):
/// Raises:
/// DAGCircuitError: if a leaf node is connected to multiple outputs
#[pyo3(name = "apply_operation_back", signature = (op, qargs=None, cargs=None, *, check=true))]
fn py_apply_operation_back(
pub fn py_apply_operation_back(
&mut self,
py: Python,
op: Bound<PyAny>,
Expand Down Expand Up @@ -2870,7 +2870,7 @@ def _format(operand):
/// Raises:
/// DAGCircuitError: if met with unexpected predecessor/successors
#[pyo3(signature = (node, input_dag, wires=None, propagate_condition=true))]
fn substitute_node_with_dag(
pub fn substitute_node_with_dag(
&mut self,
py: Python,
node: &Bound<PyAny>,
Expand Down
2 changes: 2 additions & 0 deletions crates/circuit/src/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ pub static CLASSICAL_REGISTER: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit.classicalregister", "ClassicalRegister");
pub static PARAMETER_EXPRESSION: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit.parameterexpression", "ParameterExpression");
pub static PARAMETER_VECTOR: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit.parametervector", "ParameterVector");
pub static QUANTUM_CIRCUIT: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit.quantumcircuit", "QuantumCircuit");
pub static SINGLETON_GATE: ImportOnceCell =
Expand Down
2 changes: 2 additions & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
# and not have to rely on attribute access. No action needed for top-level extension packages.
sys.modules["qiskit._accelerate.circuit"] = _accelerate.circuit
sys.modules["qiskit._accelerate.circuit_library"] = _accelerate.circuit_library
sys.modules["qiskit._accelerate.basis"] = _accelerate.basis
sys.modules["qiskit._accelerate.basis.basis_translator"] = _accelerate.basis.basis_translator
sys.modules["qiskit._accelerate.convert_2q_block_matrix"] = _accelerate.convert_2q_block_matrix
sys.modules["qiskit._accelerate.dense_layout"] = _accelerate.dense_layout
sys.modules["qiskit._accelerate.error_map"] = _accelerate.error_map
Expand Down
Loading

0 comments on commit 3a3e734

Please sign in to comment.