Skip to content

Commit

Permalink
Merge branch 'doug/cfg' of gh:CQCL/hugr-llvm into doug/init-prelude
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Jun 18, 2024
2 parents 2cdac37 + 74deefc commit f6f028b
Show file tree
Hide file tree
Showing 12 changed files with 647 additions and 46 deletions.
60 changes: 45 additions & 15 deletions src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use anyhow::{anyhow, Result};
use delegate::delegate;
use hugr::{
ops::{FuncDecl, FuncDefn, NamedOp as _, OpType},
types::PolyFuncType,
HugrView, Node, NodeIndex,
};
use inkwell::{
context::Context,
module::Module,
module::{Linkage, Module},
types::{AnyType, BasicType, BasicTypeEnum, FunctionType},
values::{BasicValueEnum, FunctionValue, GlobalValue},
};
Expand Down Expand Up @@ -204,39 +205,68 @@ impl<'c, H: HugrView> EmitModuleContext<'c, H> {
fn get_func_impl(
&self,
name: impl AsRef<str>,
node: Node,
func_ty: &hugr::types::PolyFuncType,
func_ty: FunctionType<'c>,
linkage: Option<Linkage>,
) -> Result<FunctionValue<'c>> {
let sig = (func_ty.params().is_empty())
.then_some(func_ty.body())
.ok_or(anyhow!("function has type params"))?;
let llvm_func_ty = self.llvm_func_type(sig)?;
let name = self.name_func(name, node);
let func = self
.module()
.get_function(&name)
.unwrap_or_else(|| self.module.add_function(&name, llvm_func_ty, None));
if func.get_type() != llvm_func_ty {
.get_function(name.as_ref())
.unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
if func.get_type() != func_ty {
Err(anyhow!(
"Function '{name}' has wrong type: hugr: {func_ty} expected: {llvm_func_ty} actual: {}",
"Function '{}' has wrong type: expected: {func_ty} actual: {}",
name.as_ref(),
func.get_type()
))?
}
Ok(func)
}

fn get_hugr_func_impl(
&self,
name: impl AsRef<str>,
node: Node,
func_ty: &PolyFuncType,
) -> Result<FunctionValue<'c>> {
let func_ty = (func_ty.params().is_empty())
.then_some(func_ty.body())
.ok_or(anyhow!("function has type params"))?;
let llvm_func_ty = self.llvm_func_type(func_ty)?;
let name = self.name_func(name, node);
self.get_func_impl(name, llvm_func_ty, None)
}

/// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDefn].
///
/// The name of the result is mangled by [EmitModuleContext::name_func].
pub fn get_func_defn(&self, node: FatNode<'c, FuncDefn, H>) -> Result<FunctionValue<'c>> {
self.get_func_impl(&node.name, node.node(), &node.signature)
self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
}

/// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDecl].
///
/// The name of the result is mangled by [EmitModuleContext::name_func].
pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result<FunctionValue<'c>> {
self.get_func_impl(&node.name, node.node(), &node.signature)
self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
}

/// Adds or get the [FunctionValue] in the [Module] with the given symbol
/// and function type.
///
/// The name undergoes no mangling. The [FunctionValue] will have
/// [Linkage::External].
///
/// If this function is called multiple times with the same arguments it
/// will return the same [FunctionValue].
///
/// If a function with the given name exists but the type does not match
/// then an Error is returned.
pub fn get_extern_func(
&self,
symbol: impl AsRef<str>,
typ: FunctionType<'c>,
) -> Result<FunctionValue<'c>> {
self.get_func_impl(symbol, typ, Some(Linkage::External))
}

/// Adds or gets the [GlobalValue] in the [Module] corresponding to the
Expand Down Expand Up @@ -394,7 +424,7 @@ impl<'c, H: HugrView> EmitHugr<'c, H> {
/// It is safe to emit the same node multiple times, it will be detected and
/// omitted.
///
/// If any LLVM IR declaration which is to emitted already exists in the
/// If any LLVM IR declaration which is to be emitted already exists in the
/// [Module] and it differs from what would be emitted, then we fail.
pub fn emit_global(mut self, node: impl Into<Emission<'c, H>>) -> Result<Self> {
let mut worklist: EmissionSet<'c, H> = [node.into()].into_iter().collect();
Expand Down
17 changes: 14 additions & 3 deletions src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use super::{Emission, EmissionSet, EmitModuleContext};
mod mailbox;
pub use mailbox::{RowMailBox, RowPromise};

/// A context for the emitting an LLVM function.
/// A context for emitting an LLVM function.
///
/// One of the primary interfaces that impls of
/// [crate::custom::CodegenExtension] and [super::EmitOp] will interface with,
Expand Down Expand Up @@ -59,6 +59,7 @@ pub struct EmitFuncContext<'c, H: HugrView> {
impl<'c, H: HugrView> EmitFuncContext<'c, H> {
delegate! {
to self.emit_context {
/// Returns the inkwell [Context].
fn iw_context(&self) -> &'c Context;
/// Returns the internal [CodegenExtsMap] .
pub fn extensions(&self) -> Rc<CodegenExtsMap<'c,H>>;
Expand All @@ -78,6 +79,18 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
///
/// The name of the result may have been mangled.
pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result<FunctionValue<'c>>;
/// Adds or get the [FunctionValue] in the [inkwell::module::Module] with the given symbol
/// and function type.
///
/// The name undergoes no mangling. The [FunctionValue] will have
/// [inkwell::module::Linkage::External].
///
/// If this function is called multiple times with the same arguments it
/// will return the same [FunctionValue].
///
/// If a function with the given name exists but the type does not match
/// then an Error is returned.
pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
/// Adds or gets the [GlobalValue] in the [inkwell::module::Module] corresponding to the
/// given symbol and LLVM type.
///
Expand Down Expand Up @@ -110,8 +123,6 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
/// Create a new basic block. When `before` is `Some` the block will be
/// created immediately before that block, otherwise at the end of the
/// function.
///
/// TODO I think this will be needed for emitting CFGs.
pub(crate) fn new_basic_block(
&mut self,
name: impl AsRef<str>,
Expand Down
156 changes: 139 additions & 17 deletions src/emit/ops/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'c, 'd, H: HugrView> CfgEmitter<'c, 'd, H> {
}
})
.collect::<Result<HashMap<_, _>>>()?;
let (entry_node, exit_node) = node.get_entry_exit().unwrap();
let (entry_node, exit_node) = node.get_entry_exit();
Ok(CfgEmitter {
context,
bbs,
Expand Down Expand Up @@ -107,21 +107,11 @@ impl<'c, 'd, H: HugrView> CfgEmitter<'c, 'd, H> {
// emit each child by delegating to the `impl EmitOp<_>` of self.
for c in self.node.children() {
let (inputs, outputs) = (vec![], RowMailBox::new_empty().promise());
if let Some(node) = c.try_into_ot::<DataflowBlock>() {
self.emit(EmitOpArgs {
node,
inputs,
outputs,
})?;
} else if let Some(node) = c.try_into_ot::<ExitBlock>() {
self.emit(EmitOpArgs {
node,
inputs,
outputs,
})?;
} else {
Err(anyhow!("unknown optype: {c}"))?;
}
self.emit(EmitOpArgs {
node: c.clone(),
inputs,
outputs,
})?;
}

// move the builder to the end of the exit block
Expand Down Expand Up @@ -227,10 +217,14 @@ impl<'c, H: HugrView> EmitOp<'c, ExitBlock, H> for CfgEmitter<'c, '_, H> {
#[cfg(test)]
mod test {
use hugr::builder::{Dataflow, DataflowSubContainer, SubContainer};
use hugr::extension::prelude::BOOL_T;
use hugr::extension::{ExtensionRegistry, ExtensionSet};
use hugr::ops::Value;
use hugr::std_extensions::arithmetic::int_types::{self, INT_TYPES};
use hugr::type_row;

use hugr::types::Type;
use itertools::Itertools as _;
use rstest::rstest;

use crate::custom::int::add_int_extensions;
Expand All @@ -240,7 +234,7 @@ mod test {
use crate::check_emission;

#[rstest]
fn emit_cfg(mut llvm_ctx: TestContext) {
fn diverse_outputs(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
let t1 = INT_TYPES[0].clone();
let t2 = INT_TYPES[1].clone();
Expand Down Expand Up @@ -296,4 +290,132 @@ mod test {
});
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn nested(llvm_ctx: TestContext) {
let t1 = Type::new_unit_sum(3);
let es = ExtensionSet::default();
let hugr = SimpleHugrConfig::new()
.with_ins(vec![t1.clone(), BOOL_T])
.with_outs(BOOL_T)
.finish(|mut builder| {
let [in1, in2] = builder.input_wires_arr();
let unit_val = builder.add_load_value(Value::unit());
let [outer_cfg_out] = {
let mut outer_cfg_builder = builder
.cfg_builder(
[(t1.clone(), in1), (BOOL_T, in2)],
None,
BOOL_T.into(),
es.clone(),
)
.unwrap();

let outer_entry_block = {
let mut outer_entry_builder = outer_cfg_builder
.entry_builder([type_row![], type_row![]], type_row![], es.clone())
.unwrap();
let [outer_entry_in1, outer_entry_in2] =
outer_entry_builder.input_wires_arr();
let [outer_entry_out] = {
let mut inner_cfg_builder = outer_entry_builder
.cfg_builder([], None, BOOL_T.into(), es.clone())
.unwrap();
let inner_exit_block = inner_cfg_builder.exit_block();
let inner_entry_block = {
let inner_entry_builder = inner_cfg_builder
.entry_builder(
[type_row![], type_row![], type_row![]],
type_row![],
es.clone(),
)
.unwrap();
// non-local edge
inner_entry_builder
.finish_with_outputs(outer_entry_in1, [])
.unwrap()
};
let [b1, b2, b3] = (0..3)
.map(|i| {
let mut b_builder = inner_cfg_builder
.block_builder(
type_row![],
vec![type_row![]],
es.clone(),
BOOL_T.into(),
)
.unwrap();
let output = match i {
0 => b_builder.add_load_value(Value::true_val()),
1 => b_builder.add_load_value(Value::false_val()),
2 => outer_entry_in2,
_ => unreachable!(),
};
b_builder.finish_with_outputs(unit_val, [output]).unwrap()
})
.collect_vec()
.try_into()
.unwrap();
inner_cfg_builder
.branch(&inner_entry_block, 0, &b1)
.unwrap();
inner_cfg_builder
.branch(&inner_entry_block, 1, &b2)
.unwrap();
inner_cfg_builder
.branch(&inner_entry_block, 2, &b3)
.unwrap();
inner_cfg_builder.branch(&b1, 0, &inner_exit_block).unwrap();
inner_cfg_builder.branch(&b2, 0, &inner_exit_block).unwrap();
inner_cfg_builder.branch(&b3, 0, &inner_exit_block).unwrap();
inner_cfg_builder
.finish_sub_container()
.unwrap()
.outputs_arr()
};

outer_entry_builder
.finish_with_outputs(outer_entry_out, [])
.unwrap()
};

let [b1, b2] = (0..2)
.map(|i| {
let mut b_builder = outer_cfg_builder
.block_builder(
type_row![],
vec![type_row![]],
es.clone(),
BOOL_T.into(),
)
.unwrap();
let output = match i {
0 => b_builder.add_load_value(Value::true_val()),
1 => b_builder.add_load_value(Value::false_val()),
_ => unreachable!(),
};
b_builder.finish_with_outputs(unit_val, [output]).unwrap()
})
.collect_vec()
.try_into()
.unwrap();

let exit_block = outer_cfg_builder.exit_block();
outer_cfg_builder
.branch(&outer_entry_block, 0, &b1)
.unwrap();
outer_cfg_builder
.branch(&outer_entry_block, 1, &b2)
.unwrap();
outer_cfg_builder.branch(&b1, 0, &exit_block).unwrap();
outer_cfg_builder.branch(&b2, 0, &exit_block).unwrap();
outer_cfg_builder
.finish_sub_container()
.unwrap()
.outputs_arr()
};
builder.finish_with_outputs([outer_cfg_out]).unwrap()
});
check_emission!(hugr, llvm_ctx);
}
}
49 changes: 49 additions & 0 deletions src/emit/ops/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
---
source: src/emit/ops/cfg.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
br label %6

2: ; preds = %6
%3 = extractvalue { { i8, i8 } } %9, 0
%4 = extractvalue { i8, i8 } %3, 0
%5 = extractvalue { i8, i8 } %3, 1
br label %15

6: ; preds = %10, %entry_block
%"7_0.0" = phi i8 [ %0, %entry_block ], [ %12, %10 ]
%"7_1.0" = phi i8 [ %1, %entry_block ], [ %5, %10 ]
%7 = insertvalue { i8, i8 } undef, i8 %"7_0.0", 0
%8 = insertvalue { i8, i8 } %7, i8 %"7_1.0", 1
%9 = insertvalue { { i8, i8 } } poison, { i8, i8 } %8, 0
switch i32 0, label %2 [
]

10: ; preds = %15
%11 = extractvalue { i32, { i8 }, {} } %17, 1
%12 = extractvalue { i8 } %11, 0
br label %6

13: ; preds = %15
%14 = extractvalue { i32, { i8 }, {} } %17, 2
br label %19

15: ; preds = %2
%16 = insertvalue { i8 } undef, i8 %4, 0
%17 = insertvalue { i32, { i8 }, {} } { i32 0, { i8 } poison, {} poison }, { i8 } %16, 1
%18 = extractvalue { i32, { i8 }, {} } %17, 0
switch i32 %18, label %10 [
i32 1, label %13
]

19: ; preds = %13
ret i8 %5
}
Loading

0 comments on commit f6f028b

Please sign in to comment.