Skip to content

Commit

Permalink
test: Add tests for exact float <-> int roundtrips
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Oct 20, 2024
1 parent 176d98f commit d239a64
Showing 1 changed file with 65 additions and 47 deletions.
112 changes: 65 additions & 47 deletions src/extension/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, ensure, Result};

use hugr::{
extension::{
prelude::{sum_with_error, ConstError, BOOL_T},
simple_op::MakeExtensionOp,
},
ops::{constant::Value, custom::ExtensionOp},
ops::{constant::Value, custom::ExtensionOp, DataflowOpTrait as _},
std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES},
types::{TypeArg, TypeEnum},
types::{TypeArg, TypeEnum, TypeRow},
HugrView,
};

use inkwell::{values::BasicValue, FloatPredicate, IntPredicate};
use inkwell::{
types::IntType,
values::BasicValue,
FloatPredicate, IntPredicate,
};

use crate::{
custom::{CodegenExtension, CodegenExtsBuilder},
Expand All @@ -21,6 +25,7 @@ use crate::{
EmitOpArgs,
},
sum::LLVMSumValue,
types::HugrType,
};

fn build_trunc_op<'c, H: HugrView>(
Expand All @@ -29,38 +34,45 @@ fn build_trunc_op<'c, H: HugrView>(
log_width: u64,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
// Note: This logic is copied from `llvm_type` in the IntTypes
// extension. We need to have a common source of truth for this.
let (width, (int_min_value_s, int_max_value_s), int_max_value_u) = match log_width {
0..=3 => (8, (i8::MIN as i64, i8::MAX as i64), u8::MAX as u64),
4 => (16, (i16::MIN as i64, i16::MAX as i64), u16::MAX as u64),
5 => (32, (i32::MIN as i64, i32::MAX as i64), u32::MAX as u64),
6 => (64, (i64::MIN, i64::MAX), u64::MAX),
m => return Err(anyhow!("ConversionEmitter: unsupported log_width: {}", m)),
};

let hugr_int_ty = INT_TYPES[log_width as usize].clone();
let int_ty = context
.typing_session()
.llvm_type(&hugr_int_ty)?
.into_int_type();
let hugr_sum_ty = sum_with_error(vec![hugr_int_ty.clone()]);
// TODO: it would be nice to get this info out of `ops.node()`, this would
// require adding appropriate methods to `ConvertOpDef`. In the meantime, we
// assert that the output types are as we expect.
debug_assert_eq!(
TypeRow::from(vec![HugrType::from(hugr_sum_ty.clone())]),
args.node().signature().output
);

let Some(int_ty) = IntType::try_from(context.llvm_type(&hugr_int_ty)?).ok() else {
bail!("Expected `arithmetic.int` to lower to an llvm integer")
};

let hugr_sum_ty = sum_with_error(vec![hugr_int_ty]);
let sum_ty = context.typing_session().llvm_sum_type(hugr_sum_ty)?;
let sum_ty = context.llvm_sum_type(hugr_sum_ty)?;

let (width, int_min_value_s, int_max_value_s, int_max_value_u) = {
ensure!(
log_width <= 6,
"Expected log_width of output to be <= 6, found: {log_width}"
);
let width = 1 << log_width;
(
width,
i64::MIN >> (64 - width),
i64::MAX >> (64 - width),
u64::MAX >> (64 - width),
)
};

emit_custom_unary_op(context, args, |ctx, arg, _| {
// We have to check if the conversion will work, so we
// make the maximum int and convert to a float, then compare
// with the function input.
let flt_max = if signed {
ctx.iw_context()
.f64_type()
.const_float(int_max_value_s as f64)
let flt_max = ctx.iw_context().f64_type().const_float(if signed {
int_max_value_s as f64
} else {
ctx.iw_context()
.f64_type()
.const_float(int_max_value_u as f64)
};
int_max_value_u as f64
});

let within_upper_bound = ctx.builder().build_float_compare(
FloatPredicate::OLT,
Expand All @@ -69,13 +81,11 @@ fn build_trunc_op<'c, H: HugrView>(
"within_upper_bound",
)?;

let flt_min = if signed {
ctx.iw_context()
.f64_type()
.const_float(int_min_value_s as f64)
let flt_min = ctx.iw_context().f64_type().const_float(if signed {
int_min_value_s as f64
} else {
ctx.iw_context().f64_type().const_float(0.0)
};
0.0
});

let within_lower_bound = ctx.builder().build_float_compare(
FloatPredicate::OLE,
Expand Down Expand Up @@ -414,26 +424,20 @@ mod test {
.outputs_arr();
let [flt] = {
let op = if signed {
ConvertOpDef::convert_s.with_log_width(6)
ConvertOpDef::convert_s.with_log_width(6)
} else {
ConvertOpDef::convert_u.with_log_width(6)
ConvertOpDef::convert_u.with_log_width(6)
};
builder
.add_dataflow_op(op, [int])
.unwrap()
.outputs_arr()
builder.add_dataflow_op(op, [int]).unwrap().outputs_arr()
};

let [int_or_err] = {
let op = if signed {
ConvertOpDef::trunc_s.with_log_width(6)
ConvertOpDef::trunc_s.with_log_width(6)
} else {
ConvertOpDef::trunc_u.with_log_width(6)
ConvertOpDef::trunc_u.with_log_width(6)
};
builder
.add_dataflow_op(op, [flt])
.unwrap()
.outputs_arr()
builder.add_dataflow_op(op, [flt]).unwrap().outputs_arr()
};
let sum_ty = sum_with_error(int64.clone());
let variants = (0..sum_ty.num_variants())
Expand Down Expand Up @@ -482,12 +486,26 @@ mod test {
#[case(4294967295)]
#[case(42)]
#[case(18_000_000_000_000_000_000)]
fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: u64) {
fn roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
add_extensions(&mut exec_ctx);
let hugr = roundtrip_hugr(val, false);
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
}

#[rstest]
// Exact roundtrip conversion is defined on values up to 2**53 for f64.
#[case(0)]
#[case(3)]
#[case(255)]
#[case(4294967295)]
#[case(42)]
#[case(-9_000_000_000_000_000_000)]
fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
add_extensions(&mut exec_ctx);
let hugr = roundtrip_hugr(val as u64, true);
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main") as i64);
}

// For unisgined ints larger than (1 << 54) - 1, f64s do not have enough
// precision to exactly roundtrip the int.
// The exact behaviour of the round-trip is is platform-dependent.
Expand Down

0 comments on commit d239a64

Please sign in to comment.