Skip to content

Commit

Permalink
feat!: Allow extension callbacks to have non-'static lifetimes (#128)
Browse files Browse the repository at this point in the history
Closes #127.

BREAKING CHANGE: Add a lifetime argument to `TypingSession`,
`EmitModuleContext`, `EmitHugr`, `EmitFuncContext`.
  • Loading branch information
doug-q authored Oct 14, 2024
1 parent ef55558 commit f7c86d5
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 101 deletions.
72 changes: 71 additions & 1 deletion src/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
mut self,
handler: impl 'a
+ for<'c> Fn(
&mut EmitFuncContext<'c, H>,
&mut EmitFuncContext<'c, 'a, H>,
EmitOpArgs<'c, '_, ExtensionOp, H>,
Op,
) -> Result<()>,
Expand Down Expand Up @@ -152,3 +152,73 @@ pub struct CodegenExtsMap<'a, H> {
pub extension_op_handlers: Rc<ExtensionOpMap<'a, H>>,
pub type_converter: Rc<TypeConverter<'a>>,
}

#[cfg(test)]
mod test {
use hugr::{
extension::prelude::{ConstString, PRELUDE_ID, PRINT_OP_ID, STRING_TYPE, STRING_TYPE_NAME},
Hugr,
};
use inkwell::{
context::Context,
types::BasicType,
values::{BasicMetadataValueEnum, BasicValue},
};
use itertools::Itertools as _;

use crate::{emit::libc::emit_libc_printf, CodegenExtsBuilder};

#[test]
fn types_with_lifetimes() {
let n = "name_with_lifetime".to_string();

let cem = CodegenExtsBuilder::<Hugr>::default()
.custom_type((PRELUDE_ID, STRING_TYPE_NAME), |session, _| {
let ctx = session.iw_context();
Ok(ctx
.get_struct_type(n.as_ref())
.unwrap_or_else(|| ctx.opaque_struct_type(n.as_ref()))
.as_basic_type_enum())
})
.finish();

let ctx = Context::create();

let ty = cem
.type_converter
.session(&ctx)
.llvm_type(&STRING_TYPE)
.unwrap()
.into_struct_type();
let ty_n = ty.get_name().unwrap().to_str().unwrap();
assert_eq!(ty_n, n);
}

#[test]
fn custom_const_lifetime_of_context() {
let ctx = Context::create();

let _ = CodegenExtsBuilder::<Hugr>::default()
.custom_const::<ConstString>(|_, konst| {
Ok(ctx
.const_string(konst.value().as_bytes(), true)
.as_basic_value_enum())
})
.finish();
}

#[test]
fn extension_op_lifetime() {
let ctx = Context::create();

let _ = CodegenExtsBuilder::<Hugr>::default()
.extension_op(PRELUDE_ID, PRINT_OP_ID, |context, args| {
let mut print_args: Vec<BasicMetadataValueEnum> =
vec![ctx.const_string("%s".as_bytes(), true).into()];
print_args.extend(args.inputs.into_iter().map_into::<BasicMetadataValueEnum>());
emit_libc_printf(context, &print_args)?;
args.outputs.finish(context.builder(), [])
})
.finish();
}
}
8 changes: 4 additions & 4 deletions src/custom/extension_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ use crate::emit::{EmitFuncContext, EmitOpArgs};
///
/// Callbacks may hold references with lifetimes older than `'a`.
pub trait ExtensionOpFn<'a, H>:
for<'c> Fn(&mut EmitFuncContext<'c, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
{
}

impl<
'a,
H,
F: for<'c> Fn(
&mut EmitFuncContext<'c, H>,
&mut EmitFuncContext<'c, 'a, H>,
EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()>
+ ?Sized
Expand Down Expand Up @@ -76,7 +76,7 @@ impl<'a, H: HugrView> ExtensionOpMap<'a, H> {
&mut self,
handler: impl 'a
+ for<'c> Fn(
&mut EmitFuncContext<'c, H>,
&mut EmitFuncContext<'c, 'a, H>,
EmitOpArgs<'c, '_, ExtensionOp, H>,
Op,
) -> Result<()>,
Expand All @@ -96,7 +96,7 @@ impl<'a, H: HugrView> ExtensionOpMap<'a, H> {
/// If no handler is registered for the op an error will be returned.
pub fn emit_extension_op<'c>(
&self,
context: &mut EmitFuncContext<'c, H>,
context: &mut EmitFuncContext<'c, 'a, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
let node = args.node();
Expand Down
8 changes: 5 additions & 3 deletions src/custom/load_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ use crate::emit::EmitFuncContext;
///
/// Callbacks may hold references with lifetimes older than `'a`.
pub trait LoadConstantFn<'a, H: ?Sized, CC: CustomConst + ?Sized>:
for<'c> Fn(&mut EmitFuncContext<'c, H>, &CC) -> Result<BasicValueEnum<'c>> + 'a
for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, &CC) -> Result<BasicValueEnum<'c>> + 'a
{
}

impl<
'a,
H: ?Sized,
CC: ?Sized + CustomConst,
F: 'a + ?Sized + for<'c> Fn(&mut EmitFuncContext<'c, H>, &CC) -> Result<BasicValueEnum<'c>>,
F: 'a
+ ?Sized
+ for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, &CC) -> Result<BasicValueEnum<'c>>,
> LoadConstantFn<'a, H, CC> for F
{
}
Expand Down Expand Up @@ -59,7 +61,7 @@ impl<'a, H: HugrView> LoadConstantsMap<'a, H> {
/// appropriate inner callbacks.
pub fn emit_load_constant<'c>(
&self,
context: &mut EmitFuncContext<'c, H>,
context: &mut EmitFuncContext<'c, 'a, H>,
konst: &dyn CustomConst,
) -> Result<BasicValueEnum<'c>> {
let type_id = konst.type_id();
Expand Down
16 changes: 9 additions & 7 deletions src/custom/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::marker::PhantomData;

use itertools::Itertools as _;

use hugr::types::CustomType;
Expand All @@ -14,22 +16,22 @@ use crate::{
};

pub trait LLVMCustomTypeFn<'a>:
for<'c> Fn(TypingSession<'c>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a
for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a
{
}

impl<
'a,
F: for<'c> Fn(TypingSession<'c>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a + ?Sized,
F: for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a + ?Sized,
> LLVMCustomTypeFn<'a> for F
{
}

#[derive(Default, Clone)]
pub struct LLVMTypeMapping;
pub struct LLVMTypeMapping<'a>(PhantomData<&'a ()>);

impl TypeMapping for LLVMTypeMapping {
type InV<'c> = TypingSession<'c>;
impl<'a> TypeMapping for LLVMTypeMapping<'a> {
type InV<'c> = TypingSession<'c, 'a>;

type OutV<'c> = BasicTypeEnum<'c>;

Expand All @@ -48,7 +50,7 @@ impl TypeMapping for LLVMTypeMapping {
fn map_sum_type<'c>(
&self,
sum_type: &HugrSumType,
context: TypingSession<'c>,
context: TypingSession<'c, 'a>,
variants: impl IntoIterator<Item = Vec<Self::OutV<'c>>>,
) -> Result<Self::SumOutV<'c>> {
LLVMSumType::try_new2(
Expand All @@ -61,7 +63,7 @@ impl TypeMapping for LLVMTypeMapping {
fn map_function_type<'c>(
&self,
_: &HugrFuncType,
context: TypingSession<'c>,
context: TypingSession<'c, 'a>,
inputs: impl IntoIterator<Item = Self::OutV<'c>>,
outputs: impl IntoIterator<Item = Self::OutV<'c>>,
) -> Result<Self::FuncOutV<'c>> {
Expand Down
30 changes: 18 additions & 12 deletions src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,19 @@ pub use ops::emit_value;
/// This includes the module itself, a set of extensions for lowering custom
/// elements, and policy for naming various HUGR elements.
///
/// `'c` names the lifetime of the LLVM context.
// TODO add another lifetime parameter for `extensions` below.
pub struct EmitModuleContext<'c, H> {
/// `'c` names the lifetime of the LLVM context, while `'a` names the lifetime
/// of other internal references.
pub struct EmitModuleContext<'c, 'a, H>
where
'a: 'c,
{
iw_context: &'c Context,
module: Module<'c>,
extensions: Rc<CodegenExtsMap<'static, H>>,
extensions: Rc<CodegenExtsMap<'a, H>>,
namer: Rc<Namer>,
}

impl<'c, H> EmitModuleContext<'c, H> {
impl<'c, 'a, H> EmitModuleContext<'c, 'a, H> {
delegate! {
to self.typing_session() {
/// Convert a [HugrType] into an LLVM [Type](BasicTypeEnum).
Expand All @@ -70,7 +73,7 @@ impl<'c, H> EmitModuleContext<'c, H> {
iw_context: &'c Context,
module: Module<'c>,
namer: Rc<Namer>,
extensions: Rc<CodegenExtsMap<'static, H>>,
extensions: Rc<CodegenExtsMap<'a, H>>,
) -> Self {
Self {
iw_context,
Expand All @@ -88,12 +91,12 @@ impl<'c, H> EmitModuleContext<'c, H> {
}

/// Returns a reference to the inner [CodegenExtsMap].
pub fn extensions(&self) -> Rc<CodegenExtsMap<'static, H>> {
pub fn extensions(&self) -> Rc<CodegenExtsMap<'a, H>> {
self.extensions.clone()
}

/// Returns a [TypingSession] constructed from it's members.
pub fn typing_session(&self) -> TypingSession<'c> {
pub fn typing_session(&self) -> TypingSession<'c, 'a> {
self.extensions
.type_converter
.clone()
Expand Down Expand Up @@ -235,12 +238,15 @@ impl<'c, H> EmitModuleContext<'c, H> {
type EmissionSet = HashSet<Node>;

/// Emits [HugrView]s into an LLVM [Module].
pub struct EmitHugr<'c, H> {
pub struct EmitHugr<'c, 'a, H>
where
'a: 'c,
{
emitted: EmissionSet,
module_context: EmitModuleContext<'c, H>,
module_context: EmitModuleContext<'c, 'a, H>,
}

impl<'c, H: HugrView> EmitHugr<'c, H> {
impl<'c, 'a, H: HugrView> EmitHugr<'c, 'a, H> {
delegate! {
to self.module_context {
/// Returns a reference to the inner [Context].
Expand All @@ -257,7 +263,7 @@ impl<'c, H: HugrView> EmitHugr<'c, H> {
iw_context: &'c Context,
module: Module<'c>,
namer: Rc<Namer>,
extensions: Rc<CodegenExtsMap<'static, H>>,
extensions: Rc<CodegenExtsMap<'a, H>>,
) -> Self {
assert_eq!(iw_context, &module.get_context());
Self {
Expand Down
25 changes: 13 additions & 12 deletions src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ pub use mailbox::{RowMailBox, RowPromise};
/// [MailBox](RowMailBox)es are stack allocations that are `alloca`ed in the
/// first basic block of the function, read from to get the input values of each
/// node, and written to with the output values of each node.
///
// TODO add another lifetime parameter which `emit_context` will need.
pub struct EmitFuncContext<'c, H> {
emit_context: EmitModuleContext<'c, H>,
pub struct EmitFuncContext<'c, 'a, H>
where
'a: 'c,
{
emit_context: EmitModuleContext<'c, 'a, H>,
todo: EmissionSet,
func: FunctionValue<'c>,
env: HashMap<Wire, ValueMailBox<'c>>,
Expand All @@ -55,15 +56,15 @@ pub struct EmitFuncContext<'c, H> {
launch_bb: BasicBlock<'c>,
}

impl<'c, H: HugrView> EmitFuncContext<'c, H> {
impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> {
delegate! {
to self.emit_context {
/// Returns the inkwell [Context].
pub fn iw_context(&self) -> &'c Context;
/// Returns the internal [CodegenExtsMap] .
pub fn extensions(&self) -> Rc<CodegenExtsMap<'static,H>>;
pub fn extensions(&self) -> Rc<CodegenExtsMap<'a,H>>;
/// Returns a new [TypingSession].
pub fn typing_session(&self) -> TypingSession<'c>;
pub fn typing_session(&self) -> TypingSession<'c, 'a>;
/// Convert hugr [HugrType] into an LLVM [Type](BasicTypeEnum).
pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
/// Convert a [HugrFuncType] into an LLVM [FunctionType].
Expand Down Expand Up @@ -143,9 +144,9 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
///
/// TODO on failure return `emit_context`
pub fn new(
emit_context: EmitModuleContext<'c, H>,
emit_context: EmitModuleContext<'c, 'a, H>,
func: FunctionValue<'c>,
) -> Result<EmitFuncContext<'c, H>> {
) -> Result<EmitFuncContext<'c, 'a, H>> {
if func.get_first_basic_block().is_some() {
Err(anyhow!(
"EmitContext::new: Function already has a basic block: {:?}",
Expand Down Expand Up @@ -180,9 +181,9 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
/// Create a new anonymous [RowMailBox]. This mailbox is not mapped to any
/// [Wire]s, and so will not interact with any mailboxes returned from
/// [EmitFuncContext::node_ins_rmb] or [EmitFuncContext::node_outs_rmb].
pub fn new_row_mail_box<'a>(
pub fn new_row_mail_box<'t>(
&mut self,
ts: impl IntoIterator<Item = &'a Type>,
ts: impl IntoIterator<Item = &'t Type>,
name: impl AsRef<str>,
) -> Result<RowMailBox<'c>> {
Ok(RowMailBox::new(
Expand Down Expand Up @@ -307,7 +308,7 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {

/// Consumes the `EmitFuncContext` and returns both the inner
/// [EmitModuleContext] and the scoped [FuncDefn]s that were encountered.
pub fn finish(self) -> Result<(EmitModuleContext<'c, H>, EmissionSet)> {
pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
self.builder.position_at_end(self.prologue_bb);
self.builder.build_unconditional_branch(self.launch_bb)?;
Ok((self.emit_context, self.todo))
Expand Down
Loading

0 comments on commit f7c86d5

Please sign in to comment.