diff --git a/Cargo.toml b/Cargo.toml index 8fa4e547..8bcf39ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ license-file = "LICENCE" [workspace.dependencies] tket2 = { path = "./tket2" } -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "d0499ad" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "b256c2b" } portgraph = { version = "0.10" } pyo3 = { version = "0.20" } itertools = { version = "0.11.0" } diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index a89a59bf..e388c56c 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -5,19 +5,23 @@ use std::collections::HashMap; use super::json::op::JsonOp; -use crate::ops::EXTENSION as T2EXTENSION; +use crate::ops::load_all_ops; +use crate::T2Op; use hugr::extension::prelude::PRELUDE; use hugr::extension::{ExtensionId, ExtensionRegistry, SignatureError}; use hugr::hugr::IdentList; use hugr::ops::custom::{ExternalOp, OpaqueOp}; use hugr::ops::OpName; -use hugr::std_extensions::arithmetic::float_types::extension as float_extension; +use hugr::std_extensions::arithmetic::float_types::{extension as float_extension, FLOAT64_TYPE}; use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam}; use hugr::types::{CustomType, FunctionType, Type, TypeBound}; -use hugr::Extension; +use hugr::{type_row, Extension}; use lazy_static::lazy_static; use smol_str::SmolStr; +/// Definition for Angle ops and types. +pub mod angle; + /// The ID of the TKET1 extension. pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1"); @@ -68,7 +72,7 @@ pub static ref LINEAR_BIT: Type = { pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::from([ TKET1_EXTENSION.clone(), PRELUDE.clone(), - T2EXTENSION.clone(), + TKET2_EXTENSION.clone(), float_extension(), ]); @@ -125,3 +129,49 @@ fn json_op_signature(args: &[TypeArg]) -> Result { let op: JsonOp = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors! Ok(op.signature()) } + +/// Angle type with given log denominator. +pub fn angle_custom_type(log_denom: u8) -> CustomType { + angle::angle_custom_type(&TKET2_EXTENSION, angle::type_arg(log_denom)) +} + +/// Name of tket 2 extension. +pub const TKET2_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2"); + +/// The name of the symbolic expression opaque type arg. +pub const SYM_EXPR_NAME: SmolStr = SmolStr::new_inline("SymExpr"); + +/// The name of the symbolic expression opaque type arg. +pub const SYM_OP_ID: SmolStr = SmolStr::new_inline("symbolic_float"); + +lazy_static! { +/// The type of the symbolic expression opaque type arg. +pub static ref SYM_EXPR_T: CustomType = + TKET2_EXTENSION.get_type(&SYM_EXPR_NAME).unwrap().instantiate([]).unwrap(); + +/// The extension definition for TKET2 ops and types. +pub static ref TKET2_EXTENSION: Extension = { + let mut e = Extension::new(TKET2_EXTENSION_ID); + load_all_ops::(&mut e).expect("add fail"); + + let sym_expr_opdef = e.add_type( + SYM_EXPR_NAME, + vec![], + "Symbolic expression.".into(), + TypeBound::Eq.into(), + ) + .unwrap(); + let sym_expr_param = TypeParam::Opaque(sym_expr_opdef.instantiate([]).unwrap()); + + e.add_op_custom_sig_simple( + SYM_OP_ID, + "Store a sympy expression that can be evaluated to a float.".to_string(), + vec![sym_expr_param], + |_: &[TypeArg]| Ok(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])), + ) + .unwrap(); + + angle::add_to_extension(&mut e); + e +}; +} diff --git a/tket2/src/extension/angle.rs b/tket2/src/extension/angle.rs new file mode 100644 index 00000000..31292728 --- /dev/null +++ b/tket2/src/extension/angle.rs @@ -0,0 +1,317 @@ +use std::{cmp::max, num::NonZeroU64}; + +use hugr::{ + extension::{prelude::ERROR_TYPE, ExtensionRegistry, SignatureError, TypeDef, PRELUDE}, + types::{ + type_param::{TypeArgError, TypeParam}, + ConstTypeError, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeArg, + TypeBound, + }, + values::CustomConst, + Extension, +}; +use itertools::Itertools; +use smol_str::SmolStr; +use std::f64::consts::TAU; + +/// Identifier for the angle type. +const ANGLE_TYPE_ID: SmolStr = SmolStr::new_inline("angle"); + +pub(super) fn angle_custom_type(extension: &Extension, log_denom_arg: TypeArg) -> CustomType { + angle_def(extension).instantiate([log_denom_arg]).unwrap() +} + +fn angle_type(log_denom: u8) -> Type { + Type::new_extension(super::angle_custom_type(log_denom)) +} + +/// The largest permitted log-denominator. +pub const LOG_DENOM_MAX: u8 = 53; + +const fn is_valid_log_denom(n: u8) -> bool { + n <= LOG_DENOM_MAX +} + +/// Type parameter for the log-denominator of an angle. +pub const LOG_DENOM_TYPE_PARAM: TypeParam = + TypeParam::bounded_nat(NonZeroU64::MIN.saturating_add(LOG_DENOM_MAX as u64)); + +/// Get the log-denominator of the specified type argument or error if the argument is invalid. +fn get_log_denom(arg: &TypeArg) -> Result { + match arg { + TypeArg::BoundedNat { n } if is_valid_log_denom(*n as u8) => Ok(*n as u8), + _ => Err(TypeArgError::TypeMismatch { + arg: arg.clone(), + param: LOG_DENOM_TYPE_PARAM, + }), + } +} + +pub(super) const fn type_arg(log_denom: u8) -> TypeArg { + TypeArg::BoundedNat { + n: log_denom as u64, + } +} + +/// An angle +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct ConstAngle { + log_denom: u8, + value: u64, +} + +impl ConstAngle { + /// Create a new [`ConstAngle`] from a log-denominator and a numerator + pub fn new(log_denom: u8, value: u64) -> Result { + if !is_valid_log_denom(log_denom) { + return Err(ConstTypeError::CustomCheckFail( + hugr::types::CustomCheckFailure::Message( + "Invalid angle log-denominator.".to_owned(), + ), + )); + } + if value >= (1u64 << log_denom) { + return Err(ConstTypeError::CustomCheckFail( + hugr::types::CustomCheckFailure::Message( + "Invalid unsigned integer value.".to_owned(), + ), + )); + } + Ok(Self { log_denom, value }) + } + + /// Create a new [`ConstAngle`] from a log-denominator and a floating-point value in radians, + /// rounding to the nearest corresponding value. (Ties round away from zero.) + pub fn from_radians_rounding(log_denom: u8, theta: f64) -> Result { + if !is_valid_log_denom(log_denom) { + return Err(ConstTypeError::CustomCheckFail( + hugr::types::CustomCheckFailure::Message( + "Invalid angle log-denominator.".to_owned(), + ), + )); + } + let a = (((1u64 << log_denom) as f64) * theta / TAU).round() as i64; + Ok(Self { + log_denom, + value: a.rem_euclid(1i64 << log_denom) as u64, + }) + } + + /// Returns the value of the constant + pub fn value(&self) -> u64 { + self.value + } + + /// Returns the log-denominator of the constant + pub fn log_denom(&self) -> u8 { + self.log_denom + } +} + +#[typetag::serde] +impl CustomConst for ConstAngle { + fn name(&self) -> SmolStr { + format!("a(2π*{}/2^{})", self.value, self.log_denom).into() + } + fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { + if typ.clone() == super::angle_custom_type(self.log_denom) { + Ok(()) + } else { + Err(CustomCheckFailure::Message( + "Angle constant type mismatch.".into(), + )) + } + } + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + hugr::values::downcast_equal_consts(self, other) + } +} + +fn type_var(var_id: usize, extension: &Extension) -> Result { + Ok(Type::new_extension(angle_def(extension).instantiate( + vec![TypeArg::new_var_use(var_id, LOG_DENOM_TYPE_PARAM)], + )?)) +} +fn atrunc_sig(extension: &Extension) -> Result { + let in_angle = type_var(0, extension)?; + let out_angle = type_var(1, extension)?; + + Ok(FunctionType::new(vec![in_angle], vec![out_angle])) +} + +fn aconvert_sig(extension: &Extension) -> Result { + let in_angle = type_var(0, extension)?; + let out_angle = type_var(1, extension)?; + Ok(FunctionType::new( + vec![in_angle], + vec![Type::new_sum(vec![out_angle, ERROR_TYPE])], + )) +} + +/// Collect a vector into an array. +fn collect_array(arr: &[T]) -> [&T; N] { + arr.iter().collect_vec().try_into().unwrap() +} + +fn abinop_sig(arg_values: &[TypeArg]) -> Result { + let [arg0, arg1] = collect_array(arg_values); + let m: u8 = get_log_denom(arg0)?; + let n: u8 = get_log_denom(arg1)?; + let l: u8 = max(m, n); + Ok(FunctionType::new( + vec![angle_type(m), angle_type(n)], + vec![angle_type(l)], + )) +} + +fn aunop_sig(extension: &Extension) -> Result { + let angle = type_var(0, extension)?; + Ok(FunctionType::new_linear(vec![angle])) +} + +fn angle_def(extension: &Extension) -> &TypeDef { + extension.get_type(&ANGLE_TYPE_ID).unwrap() +} + +pub(super) fn add_to_extension(extension: &mut Extension) { + extension + .add_type( + ANGLE_TYPE_ID, + vec![LOG_DENOM_TYPE_PARAM], + "angle value with a given log-denominator".to_owned(), + TypeBound::Eq.into(), + ) + .unwrap(); + + let reg1: ExtensionRegistry = [PRELUDE.to_owned(), extension.to_owned()].into(); + extension + .add_op_type_scheme( + "atrunc".into(), + "truncate an angle to one with a lower log-denominator with the same value, rounding \ + down in [0, 2π) if necessary" + .to_owned(), + Default::default(), + vec![], + PolyFuncType::new_validated( + vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], + atrunc_sig(extension).unwrap(), + ®1, + ) + .unwrap(), + ) + .unwrap(); + + extension + .add_op_type_scheme( + "aconvert".into(), + "convert an angle to one with another log-denominator having the same value, if \ + possible, otherwise return an error" + .to_owned(), + Default::default(), + vec![], + PolyFuncType::new_validated( + vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], + aconvert_sig(extension).unwrap(), + ®1, + ) + .unwrap(), + ) + .unwrap(); + + extension + .add_op_custom_sig_simple( + "aadd".into(), + "addition of angles".to_owned(), + vec![LOG_DENOM_TYPE_PARAM], + abinop_sig, + ) + .unwrap(); + + extension + .add_op_custom_sig_simple( + "asub".into(), + "subtraction of the second angle from the first".to_owned(), + vec![LOG_DENOM_TYPE_PARAM], + abinop_sig, + ) + .unwrap(); + + extension + .add_op_type_scheme( + "aneg".into(), + "negation of an angle".to_owned(), + Default::default(), + vec![], + PolyFuncType::new_validated( + vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], + aunop_sig(extension).unwrap(), + ®1, + ) + .unwrap(), + ) + .unwrap(); +} + +#[cfg(test)] +mod test { + use super::*; + use hugr::types::TypeArg; + + #[test] + fn test_angle_log_denoms() { + let type_arg_53 = TypeArg::BoundedNat { n: 53 }; + assert_eq!(get_log_denom(&type_arg_53).unwrap(), 53); + + let type_arg_54 = TypeArg::BoundedNat { n: 54 }; + assert!(matches!( + get_log_denom(&type_arg_54), + Err(TypeArgError::TypeMismatch { .. }) + )); + } + + #[test] + fn test_angle_consts() { + let const_a32_7 = ConstAngle::new(5, 7).unwrap(); + let const_a33_7 = ConstAngle::new(6, 7).unwrap(); + let const_a32_8 = ConstAngle::new(6, 8).unwrap(); + assert_ne!(const_a32_7, const_a33_7); + assert_ne!(const_a32_7, const_a32_8); + assert_eq!(const_a32_7, ConstAngle::new(5, 7).unwrap()); + + assert!(const_a32_7 + .check_custom_type(&super::super::angle_custom_type(5)) + .is_ok()); + assert!(const_a32_7 + .check_custom_type(&super::super::angle_custom_type(6)) + .is_err()); + assert!(matches!( + ConstAngle::new(3, 256), + Err(ConstTypeError::CustomCheckFail(_)) + )); + assert!(matches!( + ConstAngle::new(54, 256), + Err(ConstTypeError::CustomCheckFail(_)) + )); + let const_af1 = ConstAngle::from_radians_rounding(5, 0.21874 * TAU).unwrap(); + assert_eq!(const_af1.value(), 7); + assert_eq!(const_af1.log_denom(), 5); + + assert!(ConstAngle::from_radians_rounding(54, 0.21874 * TAU).is_err()); + + assert!(const_a32_7.equal_consts(&ConstAngle::new(5, 7).unwrap())); + assert_ne!(const_a32_7, const_a33_7); + + assert_eq!(const_a32_8.name(), "a(2π*8/2^6)"); + } + #[test] + fn test_binop_sig() { + let sig = abinop_sig(&[type_arg(23), type_arg(42)]).unwrap(); + + assert_eq!( + sig, + FunctionType::new(vec![angle_type(23), angle_type(42)], vec![angle_type(42)]) + ); + + assert!(abinop_sig(&[type_arg(23), type_arg(89)]).is_err()); + } +} diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index 83b9a270..7d3161a9 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -1,5 +1,8 @@ use std::collections::HashMap; +use crate::extension::{ + SYM_EXPR_T, SYM_OP_ID, TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID, +}; use hugr::{ extension::{ prelude::{BOOL_T, QB_T}, @@ -9,14 +12,13 @@ use hugr::{ std_extensions::arithmetic::float_types::FLOAT64_TYPE, type_row, types::{ - type_param::{CustomTypeArg, TypeArg, TypeParam}, - CustomType, FunctionType, TypeBound, + type_param::{CustomTypeArg, TypeArg}, + FunctionType, }, Extension, }; -use lazy_static::lazy_static; + use serde::{Deserialize, Serialize}; -use smol_str::SmolStr; use std::str::FromStr; use strum::IntoEnumIterator; @@ -28,9 +30,6 @@ use pyo3::pyclass; use crate::extension::REGISTRY; -/// Name of tket 2 extension. -pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2"); - #[derive( Clone, Copy, @@ -95,7 +94,9 @@ pub enum Pauli { pub struct NotT2Op; // this trait could be implemented in Hugr -trait SimpleOpEnum: Into<&'static str> + FromStr + Copy + IntoEnumIterator { +pub(crate) trait SimpleOpEnum: + Into<&'static str> + FromStr + Copy + IntoEnumIterator +{ type LoadError: std::error::Error; fn signature(&self) -> FunctionType; @@ -255,42 +256,6 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> { } } -/// The name of the symbolic expression opaque type arg. -pub const SYM_EXPR_NAME: SmolStr = SmolStr::new_inline("SymExpr"); - -/// The name of the symbolic expression opaque type arg. -const SYM_OP_ID: SmolStr = SmolStr::new_inline("symbolic_float"); - -lazy_static! { -/// The type of the symbolic expression opaque type arg. -pub static ref SYM_EXPR_T: CustomType = - EXTENSION.get_type(&SYM_EXPR_NAME).unwrap().instantiate([]).unwrap(); - -pub static ref EXTENSION: Extension = { - let mut e = Extension::new(EXTENSION_ID); - load_all_ops::(&mut e).expect("add fail"); - - let sym_expr_opdef = e.add_type( - SYM_EXPR_NAME, - vec![], - "Symbolic expression.".into(), - TypeBound::Eq.into(), - ) - .unwrap(); - let sym_expr_param = TypeParam::Opaque(sym_expr_opdef.instantiate([]).unwrap()); - - e.add_op_custom_sig_simple( - SYM_OP_ID, - "Store a sympy expression that can be evaluated to a float.".to_string(), - vec![sym_expr_param], - |_: &[TypeArg]| Ok(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])), - ) - .unwrap(); - - e -}; -} - // From implementations could be made generic over SimpleOpEnum impl From for LeafOp { fn from(op: T2Op) -> Self { @@ -350,7 +315,9 @@ impl TryFrom for T2Op { } /// load all variants of a `SimpleOpEnum` in to an extension as op defs. -fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> { +pub(crate) fn load_all_ops( + extension: &mut Extension, +) -> Result<(), ExtensionBuildError> { for op in T::all_variants() { op.add_to_extension(extension)?; } @@ -364,9 +331,9 @@ pub(crate) mod test { use hugr::{extension::OpDef, Hugr}; use rstest::{fixture, rstest}; + use super::T2Op; + use crate::extension::{TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID}; use crate::{circuit::Circuit, ops::SimpleOpEnum, utils::build_simple_circuit}; - - use super::{T2Op, EXTENSION, EXTENSION_ID}; fn get_opdef(op: impl SimpleOpEnum) -> Option<&'static Arc> { EXTENSION.get_op(op.name()) }