Skip to content

Commit

Permalink
use polyfunctype for angles where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 9, 2023
1 parent d7fef9d commit b636a1e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 56 deletions.
5 changes: 5 additions & 0 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ fn json_op_signature(args: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(op.signature())
}

/// Angle type with given log denominator.
pub fn angle_custom_type(log_denom: u8) -> CustomType {
angle::angle_custom_type(&TKET2_EXTENSION, angle::type_arg(log_denom))
}

/// Name of tket 2 extension.
pub const TKET2_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2");

Expand Down
120 changes: 64 additions & 56 deletions tket2/src/extension/angle.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{cmp::max, num::NonZeroU64};

use hugr::{
extension::{prelude::ERROR_TYPE, SignatureError},
extension::{prelude::ERROR_TYPE, ExtensionRegistry, SignatureError, TypeDef, PRELUDE},
types::{
type_param::{TypeArgError, TypeParam},
ConstTypeError, CustomCheckFailure, CustomType, FunctionType, Type, TypeArg, TypeBound,
ConstTypeError, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeArg,
TypeBound,
},
values::CustomConst,
Extension,
Expand All @@ -13,26 +14,11 @@ use itertools::Itertools;
use smol_str::SmolStr;
use std::f64::consts::TAU;

use super::TKET2_EXTENSION_ID;

/// Identifier for the angle type.
const ANGLE_TYPE_ID: SmolStr = SmolStr::new_inline("angle");

fn angle_custom_type(log_denom_arg: TypeArg) -> CustomType {
CustomType::new(
ANGLE_TYPE_ID,
[log_denom_arg],
TKET2_EXTENSION_ID,
TypeBound::Eq,
)
}

/// Angle type with a given log-denominator (specified by the TypeArg).
///
/// This type is capable of representing angles that are multiples of 2π / 2^N where N is the
/// log-denominator.
pub(super) fn angle_type(log_denom_arg: TypeArg) -> Type {
Type::new_extension(angle_custom_type(log_denom_arg))
pub(super) fn angle_custom_type(extension: &Extension, log_denom_arg: TypeArg) -> CustomType {
angle_def(extension).instantiate([log_denom_arg]).unwrap()
}

/// The largest permitted log-denominator.
Expand All @@ -47,7 +33,7 @@ pub const LOG_DENOM_TYPE_PARAM: TypeParam =
TypeParam::bounded_nat(NonZeroU64::MIN.saturating_add(LOG_DENOM_MAX as u64));

/// Get the log-denominator of the specified type argument or error if the argument is invalid.
pub(super) fn get_log_denom(arg: &TypeArg) -> Result<u8, TypeArgError> {
fn get_log_denom(arg: &TypeArg) -> Result<u8, TypeArgError> {
match arg {
TypeArg::BoundedNat { n } if is_valid_log_denom(*n as u8) => Ok(*n as u8),
_ => Err(TypeArgError::TypeMismatch {
Expand Down Expand Up @@ -124,7 +110,7 @@ impl CustomConst for ConstAngle {
format!("a(2π*{}/2^{})", self.value, self.log_denom).into()
}
fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> {
if typ.clone() == angle_custom_type(type_arg(self.log_denom)) {
if typ.clone() == super::angle_custom_type(self.log_denom) {
Ok(())
} else {
Err(CustomCheckFailure::Message(
Expand All @@ -136,49 +122,52 @@ impl CustomConst for ConstAngle {
hugr::values::downcast_equal_consts(self, other)
}
}
/// Collect a vector into an array.
fn collect_array<const N: usize, T: std::fmt::Debug>(arr: &[T]) -> [&T; N] {
arr.iter().collect_vec().try_into().unwrap()

fn type_var(var_id: usize, extension: &Extension) -> Result<Type, SignatureError> {
Ok(Type::new_extension(angle_def(extension).instantiate(
vec![TypeArg::new_var_use(var_id, LOG_DENOM_TYPE_PARAM)],
)?))
}
fn atrunc_sig(extension: &Extension) -> Result<FunctionType, SignatureError> {
let in_angle = type_var(0, extension)?;
let out_angle = type_var(1, extension)?;

fn atrunc_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let m: u8 = get_log_denom(arg0)?;
let n: u8 = get_log_denom(arg1)?;
if m < n {
return Err(SignatureError::InvalidTypeArgs);
}
Ok(FunctionType::new(
vec![angle_type(arg0.clone())],
vec![angle_type(arg1.clone())],
))
Ok(FunctionType::new(vec![in_angle], vec![out_angle]))
}

fn aconvert_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
fn aconvert_sig(extension: &Extension) -> Result<FunctionType, SignatureError> {
let in_angle = type_var(0, extension)?;
let out_angle = type_var(1, extension)?;
Ok(FunctionType::new(
vec![angle_type(arg0.clone())],
vec![Type::new_sum(vec![angle_type(arg1.clone()), ERROR_TYPE])],
vec![in_angle],
vec![Type::new_sum(vec![out_angle, ERROR_TYPE])],
))
}

/// Collect a vector into an array.
fn collect_array<const N: usize, T: std::fmt::Debug>(arr: &[T]) -> [&T; N] {
arr.iter().collect_vec().try_into().unwrap()
}

fn abinop_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let m: u8 = get_log_denom(arg0)?;
let n: u8 = get_log_denom(arg1)?;
let l: u8 = max(m, n);
let ang_typ = |n| Type::new_extension(super::angle_custom_type(n));
Ok(FunctionType::new(
vec![
angle_type(TypeArg::BoundedNat { n: m as u64 }),
angle_type(TypeArg::BoundedNat { n: n as u64 }),
],
vec![angle_type(TypeArg::BoundedNat { n: l as u64 })],
vec![ang_typ(n), ang_typ(m)],
vec![ang_typ(l)],
))
}

fn aunop_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok(FunctionType::new_linear(vec![angle_type(arg.clone())]))
fn aunop_sig(extension: &Extension) -> Result<FunctionType, SignatureError> {
let angle = type_var(0, extension)?;
Ok(FunctionType::new_linear(vec![angle]))
}

fn angle_def(extension: &Extension) -> &TypeDef {
extension.get_type(&ANGLE_TYPE_ID).unwrap()
}

pub(super) fn add_to_extension(extension: &mut Extension) {
Expand All @@ -191,25 +180,38 @@ pub(super) fn add_to_extension(extension: &mut Extension) {
)
.unwrap();

let reg1: ExtensionRegistry = [PRELUDE.to_owned(), extension.to_owned()].into();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme(
"atrunc".into(),
"truncate an angle to one with a lower log-denominator with the same value, rounding \
down in [0, 2π) if necessary"
.to_owned(),
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
atrunc_sig,
Default::default(),
vec![],
PolyFuncType::new_validated(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
atrunc_sig(extension).unwrap(),
&reg1,
)
.unwrap(),
)
.unwrap();

extension
.add_op_custom_sig_simple(
.add_op_type_scheme(
"aconvert".into(),
"convert an angle to one with another log-denominator having the same value, if \
possible, otherwise return an error"
.to_owned(),
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
aconvert_sig,
Default::default(),
vec![],
PolyFuncType::new_validated(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
aconvert_sig(extension).unwrap(),
&reg1,
)
.unwrap(),
)
.unwrap();

Expand All @@ -232,11 +234,17 @@ pub(super) fn add_to_extension(extension: &mut Extension) {
.unwrap();

extension
.add_op_custom_sig_simple(
.add_op_type_scheme(
"aneg".into(),
"negation of an angle".to_owned(),
vec![LOG_DENOM_TYPE_PARAM],
aunop_sig,
Default::default(),
vec![],
PolyFuncType::new_validated(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
aunop_sig(extension).unwrap(),
&reg1,
)
.unwrap(),
)
.unwrap();
}
Expand Down

0 comments on commit b636a1e

Please sign in to comment.