Skip to content

Commit

Permalink
fix: don't normalise half turns (#137)
Browse files Browse the repository at this point in the history
Closes #136

drive-by: cargo update (hugr 0.13.2)

I have tested the test in
quantinuum-dev/guppy-integration#19 now passes and the
example has expected behaviour with crz oracle.
  • Loading branch information
ss2165 authored Oct 23, 2024
1 parent 2d9bfce commit c73f36b
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 131 deletions.
63 changes: 32 additions & 31 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 9 additions & 31 deletions src/extension/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use lazy_static::lazy_static;

use crate::{
custom::CodegenExtsBuilder,
emit::{emit_value, get_intrinsic, EmitFuncContext, EmitOpArgs},
emit::{emit_value, EmitFuncContext, EmitOpArgs},
types::TypingSession,
CodegenExtension,
};
Expand Down Expand Up @@ -138,9 +138,7 @@ impl<PCG: PreludeCodegen> RotationCodegenExtension<PCG> {
op: RotationOp,
) -> Result<()> {
let ts = context.typing_session();
let module = context.get_current_module();
let builder = context.builder();
let angle_ty = llvm_angle_type(&ts);

match op {
RotationOp::radd => {
Expand Down Expand Up @@ -204,27 +202,7 @@ impl<PCG: PreludeCodegen> RotationCodegenExtension<PCG> {
.map_err(|_| anyhow!("RotationOp::tohalfturns expects one argument"))?;
let half_turns = half_turns.into_float_value();

// normalised_half_turns is in the interval 0..2
let normalised_half_turns = {
// normalised_rads = (half_turns/2 - floor(half_turns/2)) * 2
// note that floor(x) gives the largest integral value less
// than or equal to x so this deals with both positive and
// negative rads.
let turns =
builder.build_float_div(half_turns, angle_ty.const_float(2.0), "")?;
let floor_turns = {
let floor = get_intrinsic(module, "llvm.floor", [angle_ty.into()])?;
builder
.build_call(floor, &[turns.into()], "")?
.try_as_basic_value()
.left()
.ok_or(anyhow!("llvm.floor has no return value"))?
.into_float_value()
};
let normalised_turns = builder.build_float_sub(turns, floor_turns, "")?;
builder.build_float_mul(normalised_turns, angle_ty.const_float(2.0), "")?
};
args.outputs.finish(builder, [normalised_half_turns.into()])
args.outputs.finish(builder, [half_turns.into()])
}
op => bail!("Unsupported op: {op:?}"),
}
Expand Down Expand Up @@ -314,7 +292,7 @@ mod test {

#[rstest]
#[case(ConstRotation::new(1.0).unwrap(), ConstRotation::new(0.5).unwrap(), 1.5)]
#[case(ConstRotation::PI, ConstRotation::new(1.5).unwrap(), 0.5)]
#[case(ConstRotation::PI, ConstRotation::new(1.5).unwrap(), 2.5)]
fn exec_aadd(
mut exec_ctx: TestContext,
#[case] angle1: ConstRotation,
Expand Down Expand Up @@ -350,7 +328,7 @@ mod test {

#[rstest]
#[case(ConstRotation::PI, 1.0)]
#[case(ConstRotation::TAU, 0.0)]
#[case(ConstRotation::TAU, 2.0)]
#[case(ConstRotation::PI_2, 0.5)]
#[case(ConstRotation::PI_4, 0.25)]
fn exec_to_halfturns(
Expand Down Expand Up @@ -420,13 +398,13 @@ mod test {

#[rstest]
#[case(1.0, Some(1.0))]
#[case(-1.0, Some(1.0))]
#[case(-1.0, Some (-1.0))]
#[case(0.5, Some(0.5))]
#[case(-0.5, Some(1.5))]
#[case(-0.5, Some (-0.5))]
#[case(0.25, Some(0.25))]
#[case(-0.25, Some(1.75))]
#[case(13.5, Some(1.5))]
#[case(-13.5, Some(0.5))]
#[case(-0.25, Some (-0.25))]
#[case(13.5, Some(13.5))]
#[case(-13.5, Some (-13.5))]
#[case(f64::NAN, None)]
#[case(f64::INFINITY, None)]
#[case(f64::NEG_INFINITY, None)]
Expand Down
57 changes: 24 additions & 33 deletions src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -29,54 +29,45 @@ entry_block: ; preds = %alloca_block
unreachable

9: ; preds = %entry_block
%10 = fdiv double %0, 2.000000e+00
%11 = call double @llvm.floor.f64(double %10)
%12 = fsub double %10, %11
%13 = fmul double %12, 2.000000e+00
%14 = fcmp oeq double %13, 0x7FF0000000000000
%15 = fcmp oeq double %13, 0xFFF0000000000000
%16 = fcmp uno double %13, 0.000000e+00
%17 = or i1 %14, %15
%18 = or i1 %17, %16
%19 = xor i1 %18, true
%20 = insertvalue { double } undef, double %13, 0
%21 = insertvalue { i32, {}, { double } } { i32 1, {} poison, { double } poison }, { double } %20, 2
%22 = select i1 %19, { i32, {}, { double } } %21, { i32, {}, { double } } { i32 0, {} undef, { double } poison }
%23 = extractvalue { i32, {}, { double } } %22, 0
switch i32 %23, label %24 [
i32 1, label %26
%10 = fcmp oeq double %0, 0x7FF0000000000000
%11 = fcmp oeq double %0, 0xFFF0000000000000
%12 = fcmp uno double %0, 0.000000e+00
%13 = or i1 %10, %11
%14 = or i1 %13, %12
%15 = xor i1 %14, true
%16 = insertvalue { double } undef, double %0, 0
%17 = insertvalue { i32, {}, { double } } { i32 1, {} poison, { double } poison }, { double } %16, 2
%18 = select i1 %15, { i32, {}, { double } } %17, { i32, {}, { double } } { i32 0, {} undef, { double } poison }
%19 = extractvalue { i32, {}, { double } } %18, 0
switch i32 %19, label %20 [
i32 1, label %22
]

24: ; preds = %9
%25 = extractvalue { i32, {}, { double } } %22, 1
20: ; preds = %9
%21 = extractvalue { i32, {}, { double } } %18, 1
br label %cond_7_case_0

26: ; preds = %9
%27 = extractvalue { i32, {}, { double } } %22, 2
%28 = extractvalue { double } %27, 0
22: ; preds = %9
%23 = extractvalue { i32, {}, { double } } %18, 2
%24 = extractvalue { double } %23, 0
br label %cond_7_case_1

cond_7_case_0: ; preds = %24
%29 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 0
%30 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 1
%31 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %29, i8* %30)
cond_7_case_0: ; preds = %20
%25 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 0
%26 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 1
%27 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %25, i8* %26)
call void @abort()
br label %cond_exit_7

cond_7_case_1: ; preds = %26
cond_7_case_1: ; preds = %22
br label %cond_exit_7

cond_exit_7: ; preds = %cond_7_case_1, %cond_7_case_0
%"0.0" = phi double [ 0.000000e+00, %cond_7_case_0 ], [ %28, %cond_7_case_1 ]
%32 = fadd double %0, %"0.0"
%"0.0" = phi double [ 0.000000e+00, %cond_7_case_0 ], [ %24, %cond_7_case_1 ]
%28 = fadd double %0, %"0.0"
ret void
}

declare i32 @printf(i8*, ...)

declare void @abort()

; Function Attrs: nofree nosync nounwind readnone speculatable willreturn
declare double @llvm.floor.f64(double) #0

attributes #0 = { nofree nosync nounwind readnone speculatable willreturn }
Loading

0 comments on commit c73f36b

Please sign in to comment.