Skip to content

Commit

Permalink
feat: Implement lowerings for ieq,ilt_s,sub in int codegen extension (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q authored Jun 20, 2024
1 parent 1f1c633 commit 8d86755
Show file tree
Hide file tree
Showing 13 changed files with 358 additions and 38 deletions.
2 changes: 0 additions & 2 deletions src/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ use super::emit::EmitOp;

pub mod int;
pub mod prelude;
// pub mod float_ops;
// pub mod logic_ops;

/// The extension point for lowering HUGR Extensions to LLVM.
pub trait CodegenExtension<'c, H: HugrView> {
Expand Down
214 changes: 182 additions & 32 deletions src/custom/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{any::TypeId, collections::HashSet};

use hugr::{
extension::{simple_op::MakeExtensionOp, ExtensionId},
ops::{constant::CustomConst, CustomOp, NamedOp},
ops::{constant::CustomConst, CustomOp, NamedOp, Value},
std_extensions::arithmetic::{
int_ops::{self, ConcreteIntOp},
int_types::{self, ConstInt},
Expand All @@ -11,12 +11,14 @@ use hugr::{
HugrView,
};
use inkwell::{
builder::Builder,
types::{BasicTypeEnum, IntType},
values::BasicValueEnum,
values::{BasicValue, BasicValueEnum},
};
use itertools::{zip_eq, Itertools as _};

use crate::{
emit::{func::EmitFuncContext, EmitOp, EmitOpArgs, NullEmitLlvm},
emit::{emit_value, func::EmitFuncContext, EmitOp, EmitOpArgs, NullEmitLlvm},
types::TypingSession,
};

Expand All @@ -25,18 +27,76 @@ use anyhow::{anyhow, Result};

struct IntOpEmitter<'c, 'd, H: HugrView>(&'d mut EmitFuncContext<'c, H>);

// TODO this is probably useful enough to offer as a general utility
fn emit_custom_binary_op<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, CustomOp, H>,
go: impl FnOnce(
&Builder<'c>,
BasicValueEnum<'c>,
BasicValueEnum<'c>,
) -> Result<Vec<BasicValueEnum<'c>>>,
) -> Result<()> {
let [lhs, rhs] = TryInto::<[_; 2]>::try_into(args.inputs).map_err(|v| {
anyhow!(
"emit_custom_2_to_1_op: expected exactly 2 inputs, got {}",
v.len()
)
})?;
if lhs.get_type() != rhs.get_type() {
return Err(anyhow!(
"emit_custom_2_to_1_op: expected inputs of the same type, got {} and {}",
lhs.get_type(),
rhs.get_type()
));
}
let res = go(context.builder(), lhs, rhs)?;
if res.len() != args.outputs.len()
|| zip_eq(res.iter(), args.outputs.get_types()).any(|(a, b)| a.get_type() != b)
{
return Err(anyhow!(
"emit_custom_2_to_1_op: expected outputs of types {:?}, got {:?}",
args.outputs.get_types().collect_vec(),
res.iter().map(BasicValueEnum::get_type).collect_vec()
));
}
args.outputs.finish(context.builder(), res)
}

fn emit_icmp<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, CustomOp, H>,
pred: inkwell::IntPredicate,
) -> Result<()> {
let true_val = emit_value(context, &Value::true_val())?;
let false_val = emit_value(context, &Value::false_val())?;

emit_custom_binary_op(context, args, |builder, lhs, rhs| {
// get result as an i1
let r = builder.build_int_compare(pred, lhs.into_int_value(), rhs.into_int_value(), "")?;
// convert to whatever BOOL_T is
Ok(vec![builder.build_select(r, true_val, false_val, "")?])
})
}

impl<'c, H: HugrView> EmitOp<'c, CustomOp, H> for IntOpEmitter<'c, '_, H> {
fn emit(&mut self, args: EmitOpArgs<'c, CustomOp, H>) -> Result<()> {
let iot = ConcreteIntOp::from_optype(&args.node().generalise())
.ok_or(anyhow!("IntOpEmitter from_optype_failed"))?;
match iot.name().as_str() {
"iadd" => {
let builder = self.0.builder();
let [lhs, rhs] = TryInto::<[_; 2]>::try_into(args.inputs).unwrap();
let a = builder.build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?;
args.outputs.finish(builder, [a.into()])
}
_ => Err(anyhow!("IntOpEmitter: unknown name")),
"iadd" => emit_custom_binary_op(self.0, args, |builder, lhs, rhs| {
Ok(vec![builder
.build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
"ieq" => emit_icmp(self.0, args, inkwell::IntPredicate::EQ),
"ilt_s" => emit_icmp(self.0, args, inkwell::IntPredicate::SLT),
"isub" => emit_custom_binary_op(self.0, args, |builder, lhs, rhs| {
Ok(vec![builder
.build_int_sub(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
n => Err(anyhow!("IntOpEmitter: unimplemented op: {n}")),
}
}
}
Expand Down Expand Up @@ -69,28 +129,6 @@ impl<'c, H: HugrView> CodegenExtension<'c, H> for IntOpsCodegenExtension {
) -> Box<dyn EmitOp<'c, CustomOp, H> + 'a> {
Box::new(IntOpEmitter(context))
}

fn supported_consts(&self) -> HashSet<TypeId> {
[TypeId::of::<ConstInt>()].into_iter().collect()
}

fn load_constant(
&self,
context: &mut EmitFuncContext<'c, H>,
konst: &dyn hugr::ops::constant::CustomConst,
) -> Result<Option<BasicValueEnum<'c>>> {
let Some(k) = konst.downcast_ref::<ConstInt>() else {
return Ok(None);
};
let ty: IntType<'c> = context
.llvm_type(&k.get_type())?
.try_into()
.map_err(|_| anyhow!("Failed to get ConstInt as IntType"))?;
// k.value_u() is in two's complement representation of the exactly
// correct bit width, so we are safe to unconditionally retrieve the
// unsigned value and do no sign extension.
Ok(Some(ty.const_int(k.value_u(), false).into()))
}
}

/// A [CodegenExtension] for the [hugr::std_extensions::arithmetic::int_types]
Expand Down Expand Up @@ -147,6 +185,28 @@ impl<'c, H: HugrView> CodegenExtension<'c, H> for IntTypesCodegenExtension {
) -> Box<dyn EmitOp<'c, CustomOp, H> + 'a> {
Box::new(NullEmitLlvm)
}

fn supported_consts(&self) -> HashSet<TypeId> {
[TypeId::of::<ConstInt>()].into_iter().collect()
}

fn load_constant(
&self,
context: &mut EmitFuncContext<'c, H>,
konst: &dyn hugr::ops::constant::CustomConst,
) -> Result<Option<BasicValueEnum<'c>>> {
let Some(k) = konst.downcast_ref::<ConstInt>() else {
return Ok(None);
};
let ty: IntType<'c> = context
.llvm_type(&k.get_type())?
.try_into()
.map_err(|_| anyhow!("Failed to get ConstInt as IntType"))?;
// k.value_u() is in two's complement representation of the exactly
// correct bit width, so we are safe to unconditionally retrieve the
// unsigned value and do no sign extension.
Ok(Some(ty.const_int(k.value_u(), false).into()))
}
}

/// Populates a [CodegenExtsMap] with all extensions needed to lower int ops,
Expand All @@ -155,3 +215,93 @@ pub fn add_int_extensions<H: HugrView>(cem: CodegenExtsMap<'_, H>) -> CodegenExt
cem.add_cge(IntOpsCodegenExtension)
.add_cge(IntTypesCodegenExtension)
}

impl<H: HugrView> CodegenExtsMap<'_, H> {
/// Populates a [CodegenExtsMap] with all extensions needed to lower int ops,
/// types, and constants.
pub fn add_int_extensions(self) -> Self {
add_int_extensions(self)
}
}

#[cfg(test)]
mod test {
use hugr::{
builder::{Dataflow, DataflowSubContainer},
extension::prelude::BOOL_T,
std_extensions::arithmetic::{int_ops, int_types::INT_TYPES},
types::TypeRow,
Hugr,
};
use rstest::rstest;

use crate::{
check_emission,
custom::int::add_int_extensions,
emit::test::SimpleHugrConfig,
test::{llvm_ctx, TestContext},
};

fn test_binary_int_op(name: impl AsRef<str>, log_width: u8) -> Hugr {
let ty = &INT_TYPES[log_width as usize];
test_binary_int_op_with_results(name, log_width, vec![ty.clone()])
}

fn test_binary_icmp_op(name: impl AsRef<str>, log_width: u8) -> Hugr {
test_binary_int_op_with_results(name, log_width, vec![BOOL_T])
}
fn test_binary_int_op_with_results(
name: impl AsRef<str>,
log_width: u8,
output_types: impl Into<TypeRow>,
) -> Hugr {
let ty = &INT_TYPES[log_width as usize];
SimpleHugrConfig::new()
.with_ins(vec![ty.clone(), ty.clone()])
.with_outs(output_types.into())
.with_extensions(int_ops::INT_OPS_REGISTRY.clone())
.finish(|mut hugr_builder| {
let [in1, in2] = hugr_builder.input_wires_arr();
let ext_op = int_ops::EXTENSION
.instantiate_extension_op(
name.as_ref(),
[(log_width as u64).into()],
&int_ops::INT_OPS_REGISTRY,
)
.unwrap();
let outputs = hugr_builder
.add_dataflow_op(ext_op, [in1, in2])
.unwrap()
.outputs();
hugr_builder.finish_with_outputs(outputs).unwrap()
})
}

#[rstest]
fn iadd(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_int_op("iadd", 3);
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn isub(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_int_op("isub", 6);
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn ieq(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_icmp_op("ieq", 1);
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn ilt_s(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_icmp_op("ilt_s", 0);
check_emission!(hugr, llvm_ctx);
}
}
15 changes: 15 additions & 0 deletions src/custom/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
source: src/custom/int.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%2 = add i8 %0, %1
ret i8 %2
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
source: src/custom/int.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
%"0" = alloca i8, align 1
%"2_0" = alloca i8, align 1
%"2_1" = alloca i8, align 1
%"4_0" = alloca i8, align 1
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
store i8 %1, i8* %"2_1", align 1
%"2_01" = load i8, i8* %"2_0", align 1
%"2_12" = load i8, i8* %"2_1", align 1
%2 = add i8 %"2_01", %"2_12"
store i8 %2, i8* %"4_0", align 1
%"4_03" = load i8, i8* %"4_0", align 1
store i8 %"4_03", i8* %"0", align 1
%"04" = load i8, i8* %"0", align 1
ret i8 %"04"
}
16 changes: 16 additions & 0 deletions src/custom/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
source: src/custom/int.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define { i32, {}, {} } @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%2 = icmp eq i8 %0, %1
%3 = select i1 %2, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
ret { i32, {}, {} } %3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
---
source: src/custom/int.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define { i32, {}, {} } @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
%"0" = alloca { i32, {}, {} }, align 8
%"2_0" = alloca i8, align 1
%"2_1" = alloca i8, align 1
%"4_0" = alloca { i32, {}, {} }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
store i8 %1, i8* %"2_1", align 1
%"2_01" = load i8, i8* %"2_0", align 1
%"2_12" = load i8, i8* %"2_1", align 1
%2 = icmp eq i8 %"2_01", %"2_12"
%3 = select i1 %2, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
store { i32, {}, {} } %3, { i32, {}, {} }* %"4_0", align 4
%"4_03" = load { i32, {}, {} }, { i32, {}, {} }* %"4_0", align 4
store { i32, {}, {} } %"4_03", { i32, {}, {} }* %"0", align 4
%"04" = load { i32, {}, {} }, { i32, {}, {} }* %"0", align 4
ret { i32, {}, {} } %"04"
}
16 changes: 16 additions & 0 deletions src/custom/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
source: src/custom/int.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define { i32, {}, {} } @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%2 = icmp slt i8 %0, %1
%3 = select i1 %2, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
ret { i32, {}, {} } %3
}
Loading

0 comments on commit 8d86755

Please sign in to comment.