Skip to content

Commit

Permalink
env func list emitting
Browse files Browse the repository at this point in the history
  • Loading branch information
utkn committed Dec 16, 2024
1 parent 3d72301 commit c81e159
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
48 changes: 32 additions & 16 deletions src/lean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ use crate::{

/// The stringified Lean definitions corresponding to a Noir module.
pub struct ModuleEntries {
pub impl_refs: Vec<String>,
pub impl_refs: HashSet<String>,
pub func_refs: HashSet<String>,
pub defs: Vec<String>,
}

Expand Down Expand Up @@ -108,6 +109,7 @@ impl LeanEmitter {
let mut indenter = Indenter::default();
let mut output = Vec::new();
let mut all_impl_refs = HashSet::new();
let mut all_func_refs = HashSet::new();

// Emit definitions for each of the modules in the context in an arbitrary
// iteration order
Expand All @@ -117,9 +119,14 @@ impl LeanEmitter {
.expect("Root crate was missing in compilation context")
.modules()
{
let ModuleEntries { impl_refs, defs } = self.emit_module(&mut indenter, module)?;
let ModuleEntries {
impl_refs,
func_refs,
defs,
} = self.emit_module(&mut indenter, module)?;
output.extend(defs);
all_impl_refs.extend(impl_refs);
all_func_refs.extend(func_refs);
}

// Remove all entries that are duplicated as we do not necessarily have the
Expand All @@ -129,7 +136,10 @@ impl LeanEmitter {
let no_dupes: Vec<String> = set.into_iter().collect();
let module_defs = no_dupes.join("\n");

let env_funcs = "";
let env_funcs = all_func_refs
.into_iter()
.map(|r| format!("({r}.name, {r}.fn)"))
.join(", ");
let env_traits = all_impl_refs.into_iter().join(", ");
let env_def = format!("def env := Lampe.Env.mk [{env_funcs}] [{env_traits}]");

Expand All @@ -147,7 +157,7 @@ impl LeanEmitter {
let mut accumulator = Vec::new();

// We start by emitting the trait implementations.
let mut impl_ids = Vec::new();
let mut impl_refs = HashSet::new();
for (id, trait_impl) in self
.context
.def_interner
Expand All @@ -158,9 +168,10 @@ impl LeanEmitter {
let impl_id = format!("impl_{}", id.0);
let trait_impl = self.emit_trait_impl(ind, &trait_impl.borrow(), &impl_id)?;
accumulator.push(trait_impl);
impl_ids.push(impl_id);
impl_refs.insert(impl_id);
}

let mut func_refs = HashSet::new();
// We then emit all definitions that correspond to the given module.
for typedef in module.type_definitions().chain(module.value_definitions()) {
let definition = match typedef {
Expand All @@ -169,7 +180,13 @@ impl LeanEmitter {
if self.context.function_meta(&id).trait_impl.is_some() {
continue;
}
self.emit_free_function_def(ind, id)?
let (def_name, def) = self.emit_free_function_def(ind, id)?;
// [TODO] fix
if def_name.starts_with("_") {
continue;
}
func_refs.insert(format!("«{def_name}»"));
def
}
ModuleDefId::TypeId(id) => self.emit_struct_def(ind, id)?,
ModuleDefId::GlobalId(id) => self.emit_global(ind, id)?,
Expand All @@ -184,15 +201,16 @@ impl LeanEmitter {
accumulator.push(definition.to_string());
}

let defns = accumulator
let defs = accumulator
.into_iter()
.filter(|d| !d.is_empty())
.map(|d| format!("{d}\n"))
.collect();

Ok(ModuleEntries {
impl_refs: impl_ids,
defs: defns,
impl_refs,
func_refs,
defs,
})
}

Expand Down Expand Up @@ -404,7 +422,11 @@ impl LeanEmitter {
/// # Errors
///
/// - [`Error`] if the extraction process fails for any reason.
pub fn emit_free_function_def(&self, ind: &mut Indenter, func: FuncId) -> Result<String> {
pub fn emit_free_function_def(
&self,
ind: &mut Indenter,
func: FuncId,
) -> Result<(String, String)> {
// Get the various parameters
let func_data = self.context.function_meta(&func);
let fq_path = self
Expand All @@ -430,12 +452,6 @@ impl LeanEmitter {
let fn_ident = format!("{self_type_str}{fq_path}");

// [TODO] discard the dummy trait methods
/* if result.contains(&format!(" _::")) {
// This is a dummy trait method that we don't care about, so we discard it.
Ok(String::new())
} else {
Ok(result)
} */

// Now we can actually build our function
Ok(syntax::format_free_function_def(
Expand Down
14 changes: 7 additions & 7 deletions src/lean/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,22 @@ fn normalize_ident(ident: &str) -> String {
ident.split("::").map(|p| without_generic_args(p)).join("::")
}

#[inline]
pub(super) fn format_free_function_def(
func_ident: &str,
def_generics: &str,
params: &str,
ret_type: &str,
body: &str,
) -> String {
) -> (String, String) {
let func_ident = normalize_ident(func_ident);
formatdoc! {
r"nr_def {func_ident}<{def_generics}> ({params}) -> {ret_type} {{
(
func_ident.clone(),
formatdoc! {
r"nr_def {func_ident}<{def_generics}>({params}) -> {ret_type} {{
{body}
}}"
}
},
)
}

pub(super) fn format_trait_function_def(
Expand Down Expand Up @@ -89,7 +91,6 @@ pub(super) mod expr {
format!("{struct_ident}<{struct_generic_vals}> {{ {fields_ordered} }}")
}

#[inline]
pub fn format_call(func_expr: &str, func_args: &str, out_ty: &str, is_lambda: bool) -> String {
if is_lambda {
format!("(^{func_expr}({func_args}) : {out_ty})")
Expand Down Expand Up @@ -149,7 +150,6 @@ pub(super) mod expr {
normalize_ident(ident)
}

#[inline]
pub fn format_func_ident(ident: &str, generics: &str, is_builtin: bool) -> String {
let ident = normalize_ident(ident);
if is_builtin {
Expand Down

0 comments on commit c81e159

Please sign in to comment.