Skip to content

Commit

Permalink
Merge pull request #33 from katat/refactor/builtins
Browse files Browse the repository at this point in the history
refactor: builtins
  • Loading branch information
katat authored Apr 12, 2024
2 parents 9b0b773 + fc65b02 commit 3d1c92a
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 79 deletions.
5 changes: 0 additions & 5 deletions src/imports.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use std::{collections::HashMap, fmt};

use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};

use crate::{
circuit_writer::{CircuitWriter, VarInfo},
constants::Span,
error::Result,
parser::types::{FnSig, FunctionDef},
stdlib::{parse_fn_sigs, BUILTIN_FNS_DEFS},
type_checker::{FnInfo, TypeChecker},
var::Var,
};
Expand Down Expand Up @@ -73,6 +71,3 @@ impl fmt::Debug for FnKind {
}
}

// static of built-in functions
pub static BUILTIN_FNS: Lazy<HashMap<String, FnInfo>> =
Lazy::new(|| parse_fn_sigs(&BUILTIN_FNS_DEFS));
5 changes: 2 additions & 3 deletions src/name_resolution/expr.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::{
cli::packages::UserRepo,
error::Result,
imports::BUILTIN_FNS,
parser::{types::ModulePath, CustomType, Expr, ExprKind},
stdlib::QUALIFIED_BUILTINS,
stdlib::{BUILTIN_FN_NAMES, QUALIFIED_BUILTINS},
};

use super::context::NameResCtx;
Expand All @@ -22,7 +21,7 @@ impl NameResCtx {
fn_name,
args,
} => {
if matches!(module, ModulePath::Local) && BUILTIN_FNS.get(&fn_name.value).is_some()
if matches!(module, ModulePath::Local) && BUILTIN_FN_NAMES.contains(&fn_name.value)
{
// if it's a builtin, use `std::builtin`
*module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS));
Expand Down
4 changes: 2 additions & 2 deletions src/parser/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::{
cli::packages::UserRepo,
constants::{Field, Span},
error::{ErrorKind, Result},
imports::BUILTIN_FNS,
lexer::{Keyword, Token, TokenKind, Tokens},
stdlib::BUILTIN_FN_NAMES,
syntax::is_type,
};

Expand Down Expand Up @@ -785,7 +785,7 @@ impl FunctionDef {
let sig = FnSig::parse(ctx, tokens)?;

// make sure that it doesn't shadow a builtin
if BUILTIN_FNS.get(&sig.name.value).is_some() {
if BUILTIN_FN_NAMES.contains(&sig.name.value) {
return Err(ctx.error(
ErrorKind::ShadowingBuiltIn(sig.name.value.clone()),
sig.name.span,
Expand Down
32 changes: 30 additions & 2 deletions src/stdlib/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,46 @@ use kimchi::circuits::polynomials::poseidon::{POS_ROWS_PER_HASH, ROUNDS_PER_ROW}
use kimchi::mina_poseidon::constants::{PlonkSpongeConstantsKimchi, SpongeConstants};
use kimchi::mina_poseidon::permutation::full_round;

use crate::imports::FnKind;
use crate::lexer::Token;
use crate::parser::types::FnSig;
use crate::parser::ParserCtx;
use crate::type_checker::FnInfo;
use crate::{
circuit_writer::{CircuitWriter, GateKind, VarInfo},
constants::{self, Field, Span},
error::{ErrorKind, Result},
imports::FnHandle,
parser::types::TyKind,
var::{ConstOrCell, Value, Var},
};

const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]";

pub const CRYPTO_FNS: [(&str, FnHandle); 1] = [(POSEIDON_FN, poseidon)];
pub const CRYPTO_SIGS: &[&str] = &[POSEIDON_FN];

pub fn get_crypto_fn(name: &str) -> Option<FnInfo> {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, name).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

let fn_handle = match name {
POSEIDON_FN => poseidon,
_ => return None,
};

Some(FnInfo {
kind: FnKind::BuiltIn(sig, fn_handle),
span: Span::default(),
})
}

/// a function returns crypto functions
pub fn crypto_fns() -> Vec<FnInfo> {
CRYPTO_SIGS
.iter()
.map(|sig| get_crypto_fn(sig).unwrap())
.collect()
}

pub fn poseidon(compiler: &mut CircuitWriter, vars: &[VarInfo], span: Span) -> Result<Option<Var>> {
//
Expand Down
98 changes: 43 additions & 55 deletions src/stdlib/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, ops::Neg as _};
use std::{collections::{HashMap, HashSet}, ops::Neg as _};

use ark_ff::{One as _, Zero};
use once_cell::sync::Lazy;
Expand All @@ -7,7 +7,7 @@ use crate::{
circuit_writer::{CircuitWriter, VarInfo},
constants::{Field, Span},
error::{Error, ErrorKind, Result},
imports::{BuiltinModule, FnHandle, FnKind},
imports::{FnHandle, FnKind},
lexer::Token,
parser::{
types::{FnSig, TyKind},
Expand All @@ -17,60 +17,10 @@ use crate::{
var::{ConstOrCell, Var},
};

use self::crypto::CRYPTO_FNS;
use self::crypto::get_crypto_fn;

pub mod crypto;

pub static CRYPTO_MODULE: Lazy<BuiltinModule> = Lazy::new(|| {
let functions = parse_fn_sigs(&CRYPTO_FNS);
BuiltinModule { functions }
});

pub fn get_std_fn(submodule: &str, fn_name: &str, span: Span) -> Result<FnInfo> {
match submodule {
"crypto" => CRYPTO_MODULE
.functions
.get(fn_name)
.cloned()
.ok_or_else(|| {
Error::new(
"type-checker",
ErrorKind::UnknownExternalFn(submodule.to_string(), fn_name.to_string()),
span,
)
}),
_ => Err(Error::new(
"type-checker",
ErrorKind::StdImport(submodule.to_string()),
span,
)),
}
}

/// Takes a list of function signatures (as strings) and their associated function pointer,
/// returns the same list but with the parsed functions (as [FunctionSig]).
pub fn parse_fn_sigs(fn_sigs: &[(&str, FnHandle)]) -> HashMap<String, FnInfo> {
let mut functions = HashMap::new();
let ctx = &mut ParserCtx::default();

for (sig, fn_ptr) in fn_sigs {
// filename_id 0 is for builtins
let mut tokens = Token::parse(0, sig).unwrap();

let sig = FnSig::parse(ctx, &mut tokens).unwrap();

functions.insert(
sig.name.value.clone(),
FnInfo {
kind: FnKind::BuiltIn(sig, *fn_ptr),
span: Span::default(),
},
);
}

functions
}

//
// Builtins or utils (imported by default)
// TODO: give a name that's useful for the user,
Expand All @@ -81,8 +31,46 @@ pub const QUALIFIED_BUILTINS: &str = "std/builtins";
const ASSERT_FN: &str = "assert(condition: Bool)";
const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)";

pub const BUILTIN_FNS_DEFS: [(&str, FnHandle); 2] =
[(ASSERT_EQ_FN, assert_eq), (ASSERT_FN, assert)];
/// List of builtin function signatures.
pub const BUILTIN_SIGS: &[&str] = &[ASSERT_FN, ASSERT_EQ_FN];

// Unique set of builtin function names, derived from function signatures.
pub static BUILTIN_FN_NAMES: Lazy<HashSet<String>> = Lazy::new(|| {
BUILTIN_SIGS
.iter()
.map(|s| {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, s).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();
sig.name.value
})
.collect()
});

pub fn get_builtin_fn(name: &str) -> Option<FnInfo> {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, name).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

let fn_handle = match name {
ASSERT_FN => assert,
ASSERT_EQ_FN => assert_eq,
_ => return None,
};

Some(FnInfo {
kind: FnKind::BuiltIn(sig, fn_handle),
span: Span::default(),
})
}

/// a function returns builtin functions
pub fn builtin_fns() -> Vec<FnInfo> {
BUILTIN_SIGS
.iter()
.map(|sig| get_builtin_fn(sig).unwrap())
.collect()
}

/// Asserts that two vars are equal.
fn assert_eq(compiler: &mut CircuitWriter, vars: &[VarInfo], span: Span) -> Result<Option<Var>> {
Expand Down
19 changes: 7 additions & 12 deletions src/type_checker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use crate::{
cli::packages::UserRepo,
constants::{Field, Span},
error::{Error, ErrorKind, Result},
imports::{FnKind, BUILTIN_FNS},
imports::FnKind,
name_resolution::NAST,
parser::{
types::{FuncOrMethod, FunctionDef, ModulePath, RootKind, Ty, TyKind},
CustomType, Expr, StructDef,
},
stdlib::{CRYPTO_MODULE, QUALIFIED_BUILTINS},
stdlib::{builtin_fns, crypto::crypto_fns, QUALIFIED_BUILTINS},
};

pub use checker::{FnInfo, StructInfo};
Expand Down Expand Up @@ -91,12 +91,7 @@ impl TypeChecker {
}

pub(crate) fn fn_info(&self, qualified: &FullyQualified) -> Option<&FnInfo> {
if qualified.module == Some(UserRepo::new("std/builtins")) {
// if it's a built-in: get it from a global
BUILTIN_FNS.get(&qualified.name)
} else {
self.functions.get(qualified)
}
self.functions.get(qualified)
}

pub(crate) fn const_info(&self, qualified: &FullyQualified) -> Option<&ConstInfo> {
Expand Down Expand Up @@ -141,8 +136,8 @@ impl TypeChecker {

// initialize it with the builtins
let builtin_module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS));
for (fn_name, fn_info) in BUILTIN_FNS.iter() {
let qualified = FullyQualified::new(&builtin_module, fn_name);
for fn_info in builtin_fns() {
let qualified = FullyQualified::new(&builtin_module, &fn_info.sig().name.value);
if type_checker
.functions
.insert(qualified, fn_info.clone())
Expand All @@ -154,8 +149,8 @@ impl TypeChecker {

// initialize it with the standard library
let crypto_module = ModulePath::Absolute(UserRepo::new("std/crypto"));
for (fn_name, fn_info) in CRYPTO_MODULE.functions.iter() {
let qualified = FullyQualified::new(&crypto_module, fn_name);
for fn_info in crypto_fns() {
let qualified = FullyQualified::new(&crypto_module, &fn_info.sig().name.value);
if type_checker
.functions
.insert(qualified, fn_info.clone())
Expand Down

0 comments on commit 3d1c92a

Please sign in to comment.