diff --git a/Cargo.toml b/Cargo.toml index b6c886f0..69767903 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ lto = "thin" [workspace] resolver = "2" -members = ["tket2", "tket2-py", "compile-rewriter", "taso-optimiser"] +members = ["tket2", "tket2-py", "compile-rewriter", "taso-optimiser", "hugr2phir"] default-members = ["tket2"] [workspace.package] diff --git a/hugr2phir/Cargo.toml b/hugr2phir/Cargo.toml new file mode 100644 index 00000000..5d6c7eaa --- /dev/null +++ b/hugr2phir/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "hugr2phir" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +clap = { version = "4.4.2", features = ["derive"] } +tket2 = { workspace = true } +quantinuum-hugr = { workspace = true } +rmp-serde = "1.1.2" +serde_json = "1.0.107" +itertools.workspace = true diff --git a/hugr2phir/src/main.rs b/hugr2phir/src/main.rs new file mode 100644 index 00000000..3525c09f --- /dev/null +++ b/hugr2phir/src/main.rs @@ -0,0 +1,81 @@ +mod normalize; + +use std::fs::File; +use std::io::BufReader; +use std::path::PathBuf; + +use clap::Parser; + +use hugr::{ + hugr::views::{DescendantsGraph, HierarchyView}, + ops::{OpTag, OpTrait, OpType}, + Hugr, HugrView, +}; + +use tket2::phir::circuit_to_phir; + +#[derive(Parser, Debug)] +#[clap(version = "1.0", long_about = None)] +#[clap(about = "Convert from hugr msgpack serialized form to PHIR JSON.")] +#[command(long_about = "Sets the input file to use. It must be serialized HUGR.")] +struct CmdLineArgs { + /// Name of input file/folder + input: PathBuf, + /// Name of output file/folder + #[arg( + short, + long, + value_name = "FILE", + default_value = None, + help = "Sets the output file or folder. Defaults to the same as the input file with a .json extension." + )] + output: Option, +} + +fn main() { + let CmdLineArgs { input, output } = CmdLineArgs::parse(); + + let reader = BufReader::new(File::open(&input).unwrap()); + let output = output.unwrap_or_else(|| { + let mut output = input.clone(); + output.set_extension("json"); + output + }); + + let mut hugr: Hugr = rmp_serde::from_read(reader).unwrap(); + normalize::remove_identity_tuples(&mut hugr); + // DescendantsGraph::try_new(&hugr, root).unwrap() + let root = hugr.root(); + let root_op_tag = hugr.get_optype(root).tag(); + let circ: DescendantsGraph = if OpTag::DataflowParent.is_superset(root_op_tag) { + // Some dataflow graph + DescendantsGraph::try_new(&hugr, root).unwrap() + } else if OpTag::ModuleRoot.is_superset(root_op_tag) { + // Assume Guppy generated module + + // just take the first function + let main_node = hugr + .children(hugr.root()) + .find(|n| matches!(hugr.get_optype(*n), OpType::FuncDefn(_))) + .expect("Module contains no functions."); + // just take the first node again...assume guppy source so always top + // level CFG + let cfg_node = hugr + .children(main_node) + .find(|n| matches!(hugr.get_optype(*n), OpType::CFG(_))) + .expect("Function contains no cfg."); + + // Now is a bit sketchy...assume only one basic block in CFG + let block_node = hugr + .children(cfg_node) + .find(|n| matches!(hugr.get_optype(*n), OpType::BasicBlock(_))) + .expect("CFG contains no basic block."); + DescendantsGraph::try_new(&hugr, block_node).unwrap() + } else { + panic!("HUGR Root Op type {root_op_tag:?} not supported"); + }; + + let phir = circuit_to_phir(&circ).unwrap(); + + serde_json::to_writer(File::create(output).unwrap(), &phir).unwrap(); +} diff --git a/hugr2phir/src/normalize.rs b/hugr2phir/src/normalize.rs new file mode 100644 index 00000000..d5e6c047 --- /dev/null +++ b/hugr2phir/src/normalize.rs @@ -0,0 +1,118 @@ +use hugr::builder::Dataflow; +use hugr::builder::DataflowHugr; +use hugr::HugrView; +use hugr::SimpleReplacement; + +use hugr::hugr::hugrmut::HugrMut; +use hugr::hugr::views::SiblingSubgraph; +use hugr::ops::OpType; +use itertools::Itertools; +use tket2::extension::REGISTRY; + +use hugr::ops::LeafOp; + +use hugr::types::FunctionType; + +use hugr::builder::DFGBuilder; + +use hugr::types::TypeRow; + +use hugr::Hugr; + +fn identity_dfg(type_combination: TypeRow) -> Hugr { + let identity_build = DFGBuilder::new(FunctionType::new( + type_combination.clone(), + type_combination, + )) + .unwrap(); + let inputs = identity_build.input_wires(); + identity_build + .finish_hugr_with_outputs(inputs, ®ISTRY) + .unwrap() +} + +fn find_make_unmake(hugr: &impl HugrView) -> impl Iterator + '_ { + hugr.nodes().filter_map(|n| { + let op = hugr.get_optype(n); + + let OpType::LeafOp(LeafOp::MakeTuple { tys }) = op else { + return None; + }; + + let Ok(neighbour) = hugr.output_neighbours(n).exactly_one() else { + return None; + }; + + let OpType::LeafOp(LeafOp::UnpackTuple { .. }) = hugr.get_optype(neighbour) else { + return None; + }; + + let sibling_graph = SiblingSubgraph::try_from_nodes([n, neighbour], hugr) + .expect("Make unmake should be valid subgraph."); + + let replacement = identity_dfg(tys.clone()); + sibling_graph + .create_simple_replacement(hugr, replacement) + .ok() + }) +} + +/// Remove any pairs of MakeTuple immediately followed by UnpackTuple (an +/// identity operation) +pub(crate) fn remove_identity_tuples(circ: &mut impl HugrMut) { + let rewrites: Vec<_> = find_make_unmake(circ).collect(); + // should be able to apply all in parallel unless there are copies... + + for rw in rewrites { + circ.apply_rewrite(rw).unwrap(); + } +} + +#[cfg(test)] +mod test { + use super::*; + use hugr::extension::prelude::BOOL_T; + use hugr::extension::prelude::QB_T; + use hugr::type_row; + use hugr::HugrView; + + fn make_unmake_tuple(type_combination: TypeRow) -> Hugr { + let mut b = DFGBuilder::new(FunctionType::new( + type_combination.clone(), + type_combination.clone(), + )) + .unwrap(); + let input_wires = b.input_wires(); + + let tuple = b + .add_dataflow_op( + LeafOp::MakeTuple { + tys: type_combination.clone(), + }, + input_wires, + ) + .unwrap(); + + let unpacked = b + .add_dataflow_op( + LeafOp::UnpackTuple { + tys: type_combination, + }, + tuple.outputs(), + ) + .unwrap(); + + b.finish_hugr_with_outputs(unpacked.outputs(), ®ISTRY) + .unwrap() + } + #[test] + fn test_remove_id_tuple() { + let mut h = make_unmake_tuple(type_row![QB_T, BOOL_T]); + + assert_eq!(h.node_count(), 5); + + remove_identity_tuples(&mut h); + + assert_eq!(h.node_count(), 3); + } +} diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index 16c20ea0..9bb4ec59 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -11,7 +11,11 @@ use hugr::extension::{ExtensionId, ExtensionRegistry, SignatureError}; use hugr::hugr::IdentList; use hugr::ops::custom::{ExternalOp, OpaqueOp}; use hugr::ops::OpName; -use hugr::std_extensions::arithmetic::float_types::extension as float_extension; +use hugr::std_extensions::arithmetic::{ + float_types::extension as float_extension, int_ops::extension as int_ops_extension, + int_types::extension as int_types_extension, +}; + use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam}; use hugr::types::{CustomType, FunctionType, Type, TypeBound}; use hugr::Extension; @@ -70,6 +74,9 @@ pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::from([ PRELUDE.clone(), T2EXTENSION.clone(), float_extension(), + int_ops_extension(), + int_types_extension(), + ]); diff --git a/tket2/src/lib.rs b/tket2/src/lib.rs index 6c6db224..f48dd7a9 100644 --- a/tket2/src/lib.rs +++ b/tket2/src/lib.rs @@ -18,6 +18,7 @@ pub mod rewrite; #[cfg(feature = "portmatching")] pub mod portmatching; +pub mod phir; mod utils; pub use circuit::Circuit; diff --git a/tket2/src/phir.rs b/tket2/src/phir.rs new file mode 100644 index 00000000..d83f2247 --- /dev/null +++ b/tket2/src/phir.rs @@ -0,0 +1,7 @@ +//! Rust struct for PHIR and conversion from HUGR. + +mod convert; +mod model; + +pub use convert::circuit_to_phir; +pub use model::PHIRModel; diff --git a/tket2/src/phir/convert.rs b/tket2/src/phir/convert.rs new file mode 100644 index 00000000..0de240ec --- /dev/null +++ b/tket2/src/phir/convert.rs @@ -0,0 +1,581 @@ +use std::{collections::HashMap, str::FromStr}; + +use super::model::{Bit, PHIRModel}; +use crate::{ + circuit::Command, + phir::model::{ + COp, COpArg, CVarDefine, CopReturn, Data, ExportVar, Metadata, QOp, QOpArg, QVarDefine, + }, + Circuit, T2Op, +}; +use derive_more::From; +use hugr::{ + extension::prelude::{BOOL_T, QB_T}, + ops::{custom::ExternalOp, Const, LeafOp, OpTag, OpTrait, OpType}, + std_extensions::arithmetic::{ + float_types::ConstF64, + int_types::{ConstIntS, INT_TYPES}, + }, + types::{EdgeKind, TypeEnum}, + values::{CustomConst, PrimValue, Value}, + CircuitUnit, HugrView, Node, Wire, +}; +use itertools::{Either, Itertools}; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; +use thiserror::Error; + +const QUBIT_ID: &str = "q"; +fn q_arg(index: usize) -> (String, u64) { + (QUBIT_ID.to_string(), index as u64) +} +/// Convert Circuit-like HUGR to PHIR. +pub fn circuit_to_phir(circ: &impl Circuit) -> Result { + let mut ph = PHIRModel::new(); + + // Define quantum and classical variables for inputs + let init_arg_map = init_input_variables(circ, &mut ph); + + // add commands to Phir + let final_arg_map = add_commands(circ, &mut ph, init_arg_map)?; + + // get all classical output wires + export_outputs(circ, final_arg_map, &mut ph); + + // TODO: Add DFG as SeqBlock + + // TODO: Add conditional as IfBlock + + // TODO: Add wasm calls + + Ok(ph) +} + +/// Assign variables to output values and export them. +fn export_outputs(circ: &impl Circuit, arg_map: HashMap, ph: &mut PHIRModel) { + let c_out_args = in_neighbour_wires(circ, circ.output()) + .filter_map(|wire| get_export_cop(circ, wire, &arg_map)); + let mut temp_var_count = 0..; + + for variable in c_out_args { + let variable = match variable { + COpArg::Sym(s) => s, + COpArg::Bit((s, _)) => s, + _ => { + // assign the value to a newly defined variable so it can be exported + let out_var_name = format!("__temp{}", temp_var_count.next().unwrap()); + + // TODO expand to 64? + let def = def_int_var(out_var_name.clone(), 32); + let assign = crate::phir::model::Op { + op_enum: COp { + cop: "=".to_string(), + args: vec![variable], + returns: Some(vec![CopReturn::Sym(out_var_name.clone())]), + } + .into(), + metadata: Metadata::default(), + }; + ph.append_op(def); + ph.append_op(assign); + + out_var_name + } + }; + let export = Data { + data: ExportVar { + variables: vec![variable], + to: None, + } + .into(), + metadata: Metadata::default(), + }; + ph.append_op(export); + } +} + +fn get_export_cop( + circ: &impl HugrView, + wire: Wire, + arg_map: &HashMap, +) -> Option { + // we only care about certain copyable Value out edges. + let Some(EdgeKind::Value(t)) = circ.get_optype(wire.node()).port_kind(wire.source()) else { + return None; + }; + // Ignore sums and tuples + if t == BOOL_T || !matches!(t.as_type_enum(), TypeEnum::Sum(_) | TypeEnum::Tuple(_)) { + get_c_op(arg_map, circ, wire) + } else { + None + } +} + +/// Add quantum (acting on qubits) commands ot PHIR +fn add_commands( + circ: &impl Circuit, + ph: &mut PHIRModel, + mut arg_map: HashMap, +) -> Result, &'static str> { + // Add commands + for com in circ.commands() { + let optype = com.optype(); + + let qop = match t2op_name(optype) { + Ok(PhirOp::QOp(s)) => s.to_string(), + Ok(_) => continue, + Err(s) => return Err(s), + }; + let mut angles = vec![]; + let args: Vec = com + .inputs() + .filter_map(|(u, _, _)| match u { + CircuitUnit::Wire(w) => { + // TODO: constant folding angles + + let angle: ConstF64 = get_value(get_const(w, circ).unwrap()) + .expect("Only constant angles supported as QOP inputs."); + angles.push(angle.value()); + None + } + CircuitUnit::Linear(i) => Some(q_arg(i)), + }) + .collect(); + + let args: Vec = if args.len() == 1 { + let [arg]: [Bit; 1] = args.try_into().unwrap(); + vec![QOpArg::Bit(arg)] + } else { + vec![QOpArg::ListBit(args)] + }; + + let returns = if qop == "Measure" { + let (bit, wire) = measure_out_arg(com); + let def = def_int_var(bit.0.clone(), 1); + + ph.insert_op(0, def); + arg_map.insert(wire, COpArg::Sym(bit.0.clone())); + + Some(vec![bit]) + } else { + None + }; + let phir_op = crate::phir::model::Op { + op_enum: QOp { + qop, + args, + returns, + angles: (!angles.is_empty()).then_some((angles.clone(), "rad".to_string())), + } + .into(), + // TODO once PECOS no longer requires angles in the metadata + metadata: match angles.len() { + 0 => Metadata::default(), + 1 => Metadata::from_iter([("angle".to_string(), angles[0].into())]), + _ => Metadata::from_iter([("angles".to_string(), angles.into())]), + }, + }; + + ph.append_op(phir_op); + } + Ok(arg_map) +} + +/// Initialize phir variables for input qubits/integers +/// Returning a map from the input wire to the corresponding PHIR variable. +fn init_input_variables(circ: &impl Circuit, ph: &mut PHIRModel) -> HashMap { + let mut qubit_count = 0; + let mut input_int_count = 0; + let arg_map: HashMap = circ + .units() + .filter_map(|(cu, _, t)| match (cu, t) { + (CircuitUnit::Wire(wire), t) if t == INT_TYPES[6] => { + let variable = format!("i{input_int_count}"); + let cvar_def: Data = Data { + data: CVarDefine { + data_type: "i64".to_string(), + variable: variable.clone(), + size: None, + } + .into(), + metadata: Metadata::default(), + }; + input_int_count += 1; + ph.append_op(cvar_def); + Some((wire, COpArg::Sym(variable))) + } + (CircuitUnit::Linear(_), t) if t == QB_T => { + qubit_count += 1; + None + } + _ => unimplemented!("Non-int64 input wires not supported"), + }) + .collect(); + + let qvar_def: Data = Data { + data: QVarDefine { + data_type: Some("qubits".to_string()), + variable: "q".to_string(), + size: qubit_count, + } + .into(), + metadata: Metadata::default(), + }; + ph.append_op(qvar_def); + arg_map +} + +/// Iterate all Value wires inbound on a node +fn in_neighbour_wires(circ: &impl HugrView, node: Node) -> impl Iterator + '_ { + let node_type = circ.get_optype(node); + circ.node_inputs(node) + .filter(|port| { + node_type.port_kind(*port).is_some_and(|k| match k { + EdgeKind::Value(t) => t.copyable(), + _ => false, + }) + }) + .flat_map(move |port| circ.linked_ports(node, port)) + .map(|(n, p)| Wire::new(n, p)) +} + +/// Recursively walk back from a classical wire, turning each operation that +/// generates it in to a PHIR classical op (COp). +// TODO: Should be made non-recursive for scaling. +fn get_c_op(arg_map: &HashMap, circ: &impl HugrView, wire: Wire) -> Option { + // the wire is a known variable, return it + if let Some(cop) = arg_map.get(&wire) { + return Some(cop.clone()); + } + + // the wire comes from a constant + if let Some(c) = get_const(wire, circ) { + return Some(if c == Const::true_val() { + COpArg::IntValue(1) + } else if c == Const::false_val() { + COpArg::IntValue(0) + } else if let Some(int) = get_value::(c) { + COpArg::IntValue(int.value()) + } else { + panic!("Unknown constant."); + }); + } + + // the wire is a known classical operation + if let Ok(PhirOp::Cop(cop)) = t2op_name(circ.get_optype(wire.node())) { + let mut args: Vec = in_neighbour_wires(circ, wire.node()) + .flat_map(|prev_wire| get_c_op(arg_map, circ, prev_wire)) + .collect(); + Some(if cop == PhirCop::FromBool { + // cast is an identity + args.remove(0) + } else { + COpArg::COp(COp { + cop: cop.phir_name().to_string(), + args, + returns: None, + }) + }) + } else { + // don't know how to generate this wire in PHIR + None + } +} + +/// For a measure operation, generate a new variable name for it and +/// return the measurement result wire. +fn measure_out_arg(com: Command<'_, impl Circuit>) -> (Bit, Wire) { + let (wires, qb_indices): (Vec<_>, Vec<_>) = com.outputs().partition_map(|(c, _, _)| match c { + CircuitUnit::Wire(w) => Either::Left(w), + CircuitUnit::Linear(i) => Either::Right(i), + }); + + let [measure_wire]: [Wire; 1] = wires + .try_into() + .expect("Should only be one classical wire from measure."); + let [qb_index]: [usize; 1] = qb_indices + .try_into() + .expect("Should only be one quantum wire from measure."); + + // variable name marked with qubit index being measured + let variable = format!("c{}", qb_index); + + // declare a width-1 register per measurement + // TODO what if qubit measured multiple times? + + let arg = (variable.clone(), 0); + + (arg, measure_wire) +} + +/// Generate PHIR integer variable definition +fn def_int_var(variable: String, size: u64) -> Data { + Data { + data: CVarDefine { + data_type: "i64".to_string(), + variable, + size: Some(size), + } + .into(), + metadata: Metadata::default(), + } +} + +//. If a Wire comes from a LoadConst, return the original Const op holding the value. +fn get_const(wire: Wire, circ: &impl HugrView) -> Option { + if circ.get_optype(wire.node()).tag() != OpTag::LoadConst { + return None; + } + + circ.input_neighbours(wire.node()).find_map(|n| { + let const_op = circ.get_optype(n); + + const_op.clone().try_into().ok() + }) +} + +/// For a Const holding a CustomConst, extract the CustomConst by downcasting. +fn get_value(op: Const) -> Option { + // impl TryFrom for T in Hugr crate + if let Value::Prim { + val: PrimValue::Extension { c: (custom,) }, + } = op.value() + { + let c: T = *(custom.clone()).downcast().ok()?; + Some(c) + } else { + None + } +} + +#[derive(From)] +enum OpConvertError { + Skip, + Other(&'static str), +} + +enum PhirOp { + QOp(&'static str), + Cop(PhirCop), + Skip, +} + +/// Get the PHIR name for a quantum operation +fn t2op_name(op: &OpType) -> Result { + let err = Err("Unknown op"); + if let OpTag::Const | OpTag::LoadConst = op.tag() { + return Ok(PhirOp::Skip); + } + let OpType::LeafOp(leaf) = op else { + return err; + }; + + if let Ok(t2op) = leaf.try_into() { + // https://github.com/CQCL/phir/blob/main/phir_spec_qasm.md + Ok(PhirOp::QOp(match t2op { + T2Op::H => "H", + T2Op::CX => "CX", + T2Op::T => "T", + T2Op::S => "SZ", + T2Op::X => "X", + T2Op::Y => "Y", + T2Op::Z => "Z", + T2Op::Tdg => "Tdg", + T2Op::Sdg => "SZdg", + T2Op::ZZMax => "SZZ", + T2Op::Measure => "Measure", + T2Op::RzF64 => "RZ", + T2Op::RxF64 => "RX", + T2Op::PhasedX => "R1XY", + T2Op::ZZPhase => "RZZ", + T2Op::CZ => "CZ", + T2Op::AngleAdd | T2Op::TK1 => return err, + })) + } else if let Ok(phir_cop) = leaf.try_into() { + Ok(PhirOp::Cop(phir_cop)) + } else { + match leaf { + LeafOp::Tag { .. } | LeafOp::MakeTuple { .. } | LeafOp::UnpackTuple { .. } => { + Ok(PhirOp::Skip) + } + _ => err, + } + } +} + +#[derive( + Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, EnumIter, IntoStaticStr, EnumString, +)] +enum PhirCop { + #[strum(serialize = "iadd")] + Add, + #[strum(serialize = "isub")] + Sub, + #[strum(serialize = "imul")] + Mul, + #[strum(serialize = "idiv")] + Div, + #[strum(serialize = "imod_s")] + Mod, + #[strum(serialize = "ieq")] + Eq, + #[strum(serialize = "ine")] + Neq, + #[strum(serialize = "ilt_s")] + Gt, + #[strum(serialize = "igt_s")] + Lt, + #[strum(serialize = "ige_s")] + Ge, + #[strum(serialize = "ile_s")] + Le, + #[strum(serialize = "iand")] + And, + #[strum(serialize = "ior")] + Or, + #[strum(serialize = "ixor")] + Xor, + #[strum(serialize = "inot")] + Not, + #[strum(serialize = "ishl")] + Lsh, + #[strum(serialize = "ishr")] + Rsh, + #[strum(serialize = "ifrombool")] + FromBool, +} + +impl PhirCop { + fn phir_name(&self) -> &'static str { + match self { + PhirCop::Add => "+", + PhirCop::Sub => "-", + PhirCop::Mul => "*", + PhirCop::Div => "/", + PhirCop::Mod => "%", + PhirCop::Eq => "==", + PhirCop::Neq => "!=", + PhirCop::Gt => ">", + PhirCop::Lt => "<", + PhirCop::Ge => ">=", + PhirCop::Le => "<=", + PhirCop::And => "&", + PhirCop::Or => "|", + PhirCop::Xor => "^", + PhirCop::Not => "~", + PhirCop::Lsh => "<<", + PhirCop::Rsh => ">>", + PhirCop::FromBool => panic!("{:?} not a valid phir op.", self), + } + } +} + +#[derive(Error, Debug, Clone)] +#[error("Not a Phir classical op.")] +struct NotPhirCop; + +impl TryFrom for PhirCop { + type Error = NotPhirCop; + + fn try_from(op: OpType) -> Result { + Self::try_from(&op) + } +} + +impl TryFrom<&OpType> for PhirCop { + type Error = NotPhirCop; + + fn try_from(op: &OpType) -> Result { + let OpType::LeafOp(leaf) = op else { + return Err(NotPhirCop); + }; + leaf.try_into() + } +} + +impl TryFrom<&LeafOp> for PhirCop { + type Error = NotPhirCop; + + fn try_from(op: &LeafOp) -> Result { + match op { + LeafOp::CustomOp(b) => { + let name = match b.as_ref() { + ExternalOp::Extension(e) => e.def().name(), + ExternalOp::Opaque(o) => o.name(), + }; + + PhirCop::from_str(name).map_err(|_| NotPhirCop) + } + _ => Err(NotPhirCop), + } + } +} + +impl TryFrom for PhirCop { + type Error = NotPhirCop; + + fn try_from(op: LeafOp) -> Result { + Self::try_from(&op) + } +} +#[cfg(test)] +mod test { + + use hugr::{ + builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, + extension::{prelude::BOOL_T, ExtensionSet}, + std_extensions::arithmetic::float_types::{ConstF64, EXTENSION_ID}, + types::FunctionType, + Hugr, + }; + use rstest::{fixture, rstest}; + + use crate::extension::REGISTRY; + + use super::*; + + #[fixture] + // A commutation forward exists but depth doesn't change + fn sample() -> Hugr { + { + let num_qubits = 2; + let num_measured_bools = 2; + let inputs = vec![QB_T; num_qubits]; + let outputs = [inputs.clone(), vec![BOOL_T; num_measured_bools]].concat(); + + let mut h = DFGBuilder::new(FunctionType::new(inputs, outputs)).unwrap(); + let angle_const = ConstF64::new(1.2); + + let angle = h + .add_load_const(angle_const.into(), ExtensionSet::from_iter([EXTENSION_ID])) + .unwrap(); + let qbs = h.input_wires(); + + let mut circ = h.as_circuit(qbs.into_iter().collect()); + + let o: Result, BuildError> = (|| { + circ.append(T2Op::H, [1])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::Z, [0])?; + circ.append(T2Op::X, [1])?; + circ.append_and_consume( + T2Op::RzF64, + [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], + )?; + let mut c0 = circ.append_with_outputs(T2Op::Measure, [0])?; + let c1 = circ.append_with_outputs(T2Op::Measure, [1])?; + c0.extend(c1); + Ok(c0) + })(); + let o = o.unwrap(); + + let qbs = circ.finish(); + h.finish_hugr_with_outputs([qbs, o].concat(), ®ISTRY) + } + .unwrap() + } + #[rstest] + fn test_sample(sample: Hugr) { + let ph = circuit_to_phir(&sample).unwrap(); + assert_eq!(ph.num_ops(), 12); + } +} diff --git a/tket2/src/phir/model.rs b/tket2/src/phir/model.rs new file mode 100644 index 00000000..25c94edb --- /dev/null +++ b/tket2/src/phir/model.rs @@ -0,0 +1,316 @@ +// PHIR JSON schema: https://github.com/CQCL/phir/blob/main/schema.json + +use derive_more::From; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; + +pub(super) type Metadata = Map; + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct Data { + #[serde(flatten)] + pub(super) data: DataEnum, + #[serde(default)] + pub(super) metadata: Metadata, +} +fn default_cvar_def_data() -> String { + "i64".to_string() +} + +fn default_qvar_def_data() -> Option { + Some("qubits".to_string()) +} +fn default_ffcall_cop() -> String { + "ffcall".to_string() +} + +fn default_format() -> String { + "PHIR/JSON".to_string() +} + +fn default_version() -> String { + "0.1.0".to_string() +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct CVarDefine { + #[serde(default = "default_cvar_def_data")] + pub(super) data_type: String, + pub(super) variable: Sym, + pub(super) size: Option, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct QVarDefine { + #[serde(default = "default_qvar_def_data")] + pub(super) data_type: Option, + pub(super) variable: Sym, + pub(super) size: u64, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct ExportVar { + pub(super) variables: Vec, + pub(super) to: Option>, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, From)] +#[serde(tag = "data")] +pub(super) enum DataEnum { + #[serde(rename = "cvar_define")] + CVarDefine(CVarDefine), + #[serde(rename = "qvar_define")] + QVarDefine(QVarDefine), + #[serde(rename = "cvar_export")] + ExportVar(ExportVar), +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct Op { + #[serde(flatten)] + pub op_enum: OpEnum, + #[serde(default)] + pub metadata: Metadata, +} +pub type Bit = (String, u64); +pub type Sym = String; +#[derive(Serialize, Deserialize, Debug, PartialEq, From, Clone)] +#[serde(untagged)] +pub(super) enum COpArg { + Sym(Sym), + IntValue(i64), + Bit(Bit), + // Variadic(Vec), + COp(COp), + // Other(Value), +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, From)] +#[serde(untagged)] +pub(super) enum CopReturn { + Sym(Sym), + Bit(Bit), +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub(super) struct COp { + pub cop: String, + pub args: Vec, + pub returns: Option>, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, From)] +#[serde(untagged)] +pub(super) enum QOpArg { + ListBit(Vec), + Bit(Bit), +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct QOp { + pub qop: String, + pub args: Vec, + pub returns: Option>, + pub angles: Option<(Vec, String)>, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct FFCall { + #[serde(default = "default_ffcall_cop")] + pub cop: String, + pub function: String, + pub args: Vec, + pub returns: Option>, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct MOp { + pub mop: String, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, From)] +#[serde(untagged)] +pub(super) enum OpEnum { + Qop(QOp), + Cop(COp), + FFCall(FFCall), + Mop(MOp), +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct Block { + #[serde(flatten)] + pub block_enum: BlockEnum, + pub metadata: Option>, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct If { + pub(super) condition: COp, + pub(super) true_branch: Vec, + pub(super) false_branch: Option>, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct Seq { + pub(super) ops: Vec, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, From)] +#[serde(tag = "block")] +pub(super) enum BlockEnum { + #[serde(rename = "sequence")] + Seq(Seq), + #[serde(rename = "if")] + If(If), +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, From)] +#[serde(untagged)] +pub(super) enum BlockElems { + Op(Op), + Block(Block), +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub(super) struct Comment { + #[serde(rename = "//")] + c: String, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, From)] +#[serde(untagged)] +pub(super) enum OpListElems { + Comment(Comment), + Op(Op), + Block(Block), + Data(Data), +} + +/// Rust encapsulation of [PHIR](https://github.com/CQCL/phir) spec. +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct PHIRModel { + #[serde(default = "default_format")] + format: String, + #[serde(default = "default_version")] + version: String, + #[serde(default)] + metadata: Metadata, + ops: Vec, +} + +impl PHIRModel { + /// Creates a new [`PHIRModel`]. + pub fn new() -> Self { + Self { + format: default_format(), + version: default_version(), + metadata: Map::new(), + ops: vec![], + } + } + + /// . + pub(super) fn append_op(&mut self, op: impl Into) { + self.ops.push(op.into()); + } + + pub(super) fn insert_op(&mut self, index: usize, op: impl Into) { + self.ops.insert(index, op.into()); + } + + /// Returns the number of ops of this [`PHIRModel`]. + pub fn num_ops(&self) -> usize { + self.ops.len() + } +} + +impl Default for PHIRModel { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod test { + use std::{fs::File, io::BufReader}; + + use super::*; + #[test] + fn test_data() { + let example = r#" + { + "data": "cvar_define", + "data_type": "i64", + "variable": "a", + "size": 32 + } + "#; + + let _: Data = serde_json::from_str(example).unwrap(); + } + + #[test] + fn test_op() { + let example = r#" + { + "qop": "Measure", + "args": [["q", 0], ["q", 1]], + "returns": [["m", 0], ["m", 1]] + } + "#; + + let _: Op = serde_json::from_str(example).unwrap(); + } + + #[test] + fn test_block() { + let example = r#" + { + "block": "if", + "condition": {"cop": "==", "args": ["m", 1]}, + "true_branch": [ + { + "cop": "=", + "args": [ + { + "cop": "|", + "args": [ + {"cop": "^", "args": [["c", 2], "d"]}, + { + "cop": "+", + "args": [ + {"cop": "-", "args": ["e", 2]}, + {"cop": "&", "args": ["f", "g"]} + ] + } + ] + } + ], + "returns": ["a"] + } + ] + } + "#; + + let _: Block = serde_json::from_str(example).unwrap(); + } + #[test] + fn test_comment() { + let example = r#"{"//": "measure q -> m;"}"#; + + let _: Comment = serde_json::from_str(example).unwrap(); + } + + #[test] + fn test_all() { + let reader = BufReader::new(File::open("./src/phir/test.json").unwrap()); + let p: PHIRModel = serde_json::from_reader(reader).unwrap(); + assert_eq!(p.ops.len(), 50); + let s = serde_json::to_string(&p).unwrap(); + + let p2: PHIRModel = serde_json::from_str(&s).unwrap(); + + assert_eq!(p, p2); + } +} diff --git a/tket2/src/phir/test.json b/tket2/src/phir/test.json new file mode 100644 index 00000000..b159145f --- /dev/null +++ b/tket2/src/phir/test.json @@ -0,0 +1,187 @@ +{ + "format": "PHIR/JSON", + "version": "0.1.0", + "metadata": { + "program_name": "example_prog", + "description": "Program showing off PHIR", + "num_qubits": 10 + }, + "ops": [ + {"//": "qreg q[2];"}, + {"//": "qreg w[3];"}, + {"//": "qreg d[5];"}, + { + "data": "qvar_define", + "data_type": "qubits", + "variable": "q", + "size": 2 + }, + { + "data": "qvar_define", + "data_type": "qubits", + "variable": "w", + "size": 3 + }, + { + "data": "qvar_define", + "data_type": "qubits", + "variable": "d", + "size": 5 + }, + + {"//": "creg m[2];"}, + {"//": "creg a[32];"}, + {"//": "creg b[32];"}, + {"//": "creg c[12];"}, + {"//": "creg d[10];"}, + {"//": "creg e[30];"}, + {"//": "creg f[5];"}, + {"//": "creg g[32];"}, + {"data": "cvar_define", "data_type": "i64", "variable": "m", "size": 2}, + { + "data": "cvar_define", + "data_type": "i64", + "variable": "a", + "size": 32 + }, + { + "data": "cvar_define", + "data_type": "i64", + "variable": "b", + "size": 32 + }, + { + "data": "cvar_define", + "data_type": "i64", + "variable": "c", + "size": 12 + }, + { + "data": "cvar_define", + "data_type": "i64", + "variable": "d", + "size": 10 + }, + { + "data": "cvar_define", + "data_type": "i64", + "variable": "e", + "size": 30 + }, + {"data": "cvar_define", "data_type": "i64", "variable": "f", "size": 5}, + { + "data": "cvar_define", + "data_type": "i64", + "variable": "g", + "size": 32 + }, + + {"//": "h q[0];"}, + {"qop": "H", "args": [["q", 0]]}, + + {"//": "CX q[0], q[1];"}, + {"qop": "CX", "args": [[["q", 0], ["q", 1]]]}, + + {"//": "measure q -> m;"}, + { + "qop": "Measure", + "args": [["q", 0], ["q", 1]], + "returns": [["m", 0], ["m", 1]] + }, + + {"//": "b = 5;"}, + {"cop": "=", "args": [5], "returns": ["b"]}, + + {"//": "c = 3;"}, + {"cop": "=", "args": [3], "returns": ["c"]}, + + {"//": "a[0] = add(b, c); // FF call, e.g., Wasm call"}, + { + "cop": "ffcall", + "function": "add", + "args": ["b", "c"], + "returns": [["a", 0]] + }, + + {"//": "if(m==1) a = (c[2] ^ d) | (e - 2 + (f & g));"}, + { + "block": "if", + "condition": {"cop": "==", "args": ["m", 1]}, + "true_branch": [ + { + "cop": "=", + "args": [ + { + "cop": "|", + "args": [ + {"cop": "^", "args": [["c", 2], "d"]}, + { + "cop": "+", + "args": [ + {"cop": "-", "args": ["e", 2]}, + {"cop": "&", "args": ["f", "g"]} + ] + } + ] + } + ], + "returns": ["a"] + } + ] + }, + + {"//": "if(m==2) sub(d, e); // Conditioned void FF call. Void calls are assumed to update a separate classical state running asynchronously/in parallel."}, + { + "block": "if", + "condition": {"cop": "==", "args": ["m", 2]}, + "true_branch": [ + {"cop": "ffcall", "function": "sub", "args": ["d", "e"]} + ] + }, + + {"//": "if(a > 2) c = 7;"}, + {"//": "if(a > 2) x w[0];"}, + {"//": "if(a > 2) h w[1];"}, + {"//": "if(a > 2) CX w[1], w[2];"}, + {"//": "if(a > 2) measure w[1] -> g[0];"}, + {"//": "if(a > 2) measure w[2] -> g[1];"}, + { + "block": "if", + "condition": {"cop": ">", "args": ["a", 2]}, + "true_branch": [ + {"cop": "=", "args": [7], "returns": ["c"]}, + {"qop": "X", "args": [["w", 0]]}, + {"qop": "H", "args": [["w", 1]]}, + {"qop": "CX", "args": [[["w", 1], ["w", 2]]]}, + { + "qop": "Measure", + "args": [["w", 1], ["w", 2]], + "returns": [["g", 0], ["g", 1]] + } + ] + }, + + {"//": "if(a[3]==1) h d;"}, + { + "block": "if", + "condition": {"cop": "==", "args": [["a", 3], 1]}, + "true_branch": [ + { + "qop": "H", + "args": [["d", 0], ["d", 1], ["d", 2], ["d", 3], ["d", 4]] + } + ] + }, + + {"//": "measure d -> f;"}, + { + "qop": "Measure", + "args": [["d", 0], ["d", 1], ["d", 2], ["d", 3], ["d", 4]], + "returns": [["f", 0], ["f", 1], ["f", 2], ["f", 3], ["f", 4]] + }, + { + "data": "cvar_export", + "variables": ["m", "a", "b", "c", "d", "e", "f", "g"] + } + ] +} \ No newline at end of file diff --git a/tket2/src/utils.rs b/tket2/src/utils.rs index 22456624..a4f06795 100644 --- a/tket2/src/utils.rs +++ b/tket2/src/utils.rs @@ -1,5 +1,6 @@ //! Utility functions for the library. +use hugr::extension::prelude::BOOL_T; use hugr::extension::PRELUDE_REGISTRY; use hugr::types::{Type, TypeBound}; use hugr::{ @@ -19,17 +20,32 @@ pub(crate) fn build_simple_circuit( num_qubits: usize, f: impl FnOnce(&mut CircuitBuilder>) -> Result<(), BuildError>, ) -> Result { - let qb_row = vec![QB_T; num_qubits]; - let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?; + build_simple_measure_circuit(num_qubits, 0, |c| { + f(c); + Ok(vec![]) + }) +} + +// utility for building simple qubit-only circuits with some measure outputs. +#[allow(unused)] +pub(crate) fn build_simple_measure_circuit( + num_qubits: usize, + num_measured_bools: usize, + f: impl FnOnce(&mut CircuitBuilder>) -> Result, BuildError>, +) -> Result { + let inputs = vec![QB_T; num_qubits]; + let outputs = [inputs.clone(), vec![BOOL_T; num_measured_bools]].concat(); + + let mut h = DFGBuilder::new(FunctionType::new(inputs, outputs))?; let qbs = h.input_wires(); let mut circ = h.as_circuit(qbs.into_iter().collect()); - f(&mut circ)?; + let o = f(&mut circ)?; let qbs = circ.finish(); - h.finish_hugr_with_outputs(qbs, &PRELUDE_REGISTRY) + h.finish_hugr_with_outputs([qbs, o].concat(), &PRELUDE_REGISTRY) } // Test only utils