From 6c56cc0dafc05a31eb192c341d658302a495c1c5 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 11 Apr 2024 18:21:44 +0800 Subject: [PATCH 1/2] refactor: builtins --- src/imports.rs | 5 ---- src/name_resolution/expr.rs | 5 ++-- src/parser/types.rs | 4 +-- src/stdlib/crypto.rs | 32 ++++++++++++++++++-- src/stdlib/mod.rs | 58 ++++++++++++++++++++++++++++--------- src/type_checker/mod.rs | 19 +++++------- 6 files changed, 85 insertions(+), 38 deletions(-) diff --git a/src/imports.rs b/src/imports.rs index f82dc909d..7ac98352f 100644 --- a/src/imports.rs +++ b/src/imports.rs @@ -1,6 +1,5 @@ use std::{collections::HashMap, fmt}; -use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use crate::{ @@ -8,7 +7,6 @@ use crate::{ constants::Span, error::Result, parser::types::{FnSig, FunctionDef}, - stdlib::{parse_fn_sigs, BUILTIN_FNS_DEFS}, type_checker::{FnInfo, TypeChecker}, var::Var, }; @@ -73,6 +71,3 @@ impl fmt::Debug for FnKind { } } -// static of built-in functions -pub static BUILTIN_FNS: Lazy> = - Lazy::new(|| parse_fn_sigs(&BUILTIN_FNS_DEFS)); diff --git a/src/name_resolution/expr.rs b/src/name_resolution/expr.rs index 7318b5434..d45ea2873 100644 --- a/src/name_resolution/expr.rs +++ b/src/name_resolution/expr.rs @@ -1,9 +1,8 @@ use crate::{ cli::packages::UserRepo, error::Result, - imports::BUILTIN_FNS, parser::{types::ModulePath, CustomType, Expr, ExprKind}, - stdlib::QUALIFIED_BUILTINS, + stdlib::{BUILTIN_FN_NAMES, QUALIFIED_BUILTINS}, }; use super::context::NameResCtx; @@ -22,7 +21,7 @@ impl NameResCtx { fn_name, args, } => { - if matches!(module, ModulePath::Local) && BUILTIN_FNS.get(&fn_name.value).is_some() + if matches!(module, ModulePath::Local) && BUILTIN_FN_NAMES.contains(&fn_name.value) { // if it's a builtin, use `std::builtin` *module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS)); diff --git a/src/parser/types.rs b/src/parser/types.rs index dfd5ab221..d916ea66b 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -6,8 +6,8 @@ use crate::{ cli::packages::UserRepo, constants::{Field, Span}, error::{ErrorKind, Result}, - imports::BUILTIN_FNS, lexer::{Keyword, Token, TokenKind, Tokens}, + stdlib::BUILTIN_FN_NAMES, syntax::is_type, }; @@ -785,7 +785,7 @@ impl FunctionDef { let sig = FnSig::parse(ctx, tokens)?; // make sure that it doesn't shadow a builtin - if BUILTIN_FNS.get(&sig.name.value).is_some() { + if BUILTIN_FN_NAMES.contains(&sig.name.value) { return Err(ctx.error( ErrorKind::ShadowingBuiltIn(sig.name.value.clone()), sig.name.span, diff --git a/src/stdlib/crypto.rs b/src/stdlib/crypto.rs index e244f651d..45f51c97b 100644 --- a/src/stdlib/crypto.rs +++ b/src/stdlib/crypto.rs @@ -3,18 +3,46 @@ use kimchi::circuits::polynomials::poseidon::{POS_ROWS_PER_HASH, ROUNDS_PER_ROW} use kimchi::mina_poseidon::constants::{PlonkSpongeConstantsKimchi, SpongeConstants}; use kimchi::mina_poseidon::permutation::full_round; +use crate::imports::FnKind; +use crate::lexer::Token; +use crate::parser::types::FnSig; +use crate::parser::ParserCtx; +use crate::type_checker::FnInfo; use crate::{ circuit_writer::{CircuitWriter, GateKind, VarInfo}, constants::{self, Field, Span}, error::{ErrorKind, Result}, - imports::FnHandle, parser::types::TyKind, var::{ConstOrCell, Value, Var}, }; const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]"; -pub const CRYPTO_FNS: [(&str, FnHandle); 1] = [(POSEIDON_FN, poseidon)]; +pub const CRYPTO_SIGS: &[&str] = &[POSEIDON_FN]; + +pub fn get_crypto_fn(name: &str) -> Option { + let ctx = &mut ParserCtx::default(); + let mut tokens = Token::parse(0, name).unwrap(); + let sig = FnSig::parse(ctx, &mut tokens).unwrap(); + + let fn_handle = match name { + POSEIDON_FN => poseidon, + _ => return None, + }; + + Some(FnInfo { + kind: FnKind::BuiltIn(sig, fn_handle), + span: Span::default(), + }) +} + +/// a function returns crypto functions +pub fn crypto_fns() -> Vec { + CRYPTO_SIGS + .iter() + .map(|sig| get_crypto_fn(sig).unwrap()) + .collect() +} pub fn poseidon(compiler: &mut CircuitWriter, vars: &[VarInfo], span: Span) -> Result> { // diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 9710a2335..e36476f62 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, ops::Neg as _}; +use std::{collections::{HashMap, HashSet}, ops::Neg as _}; use ark_ff::{One as _, Zero}; use once_cell::sync::Lazy; @@ -7,7 +7,7 @@ use crate::{ circuit_writer::{CircuitWriter, VarInfo}, constants::{Field, Span}, error::{Error, ErrorKind, Result}, - imports::{BuiltinModule, FnHandle, FnKind}, + imports::{FnHandle, FnKind}, lexer::Token, parser::{ types::{FnSig, TyKind}, @@ -17,21 +17,13 @@ use crate::{ var::{ConstOrCell, Var}, }; -use self::crypto::CRYPTO_FNS; +use self::crypto::get_crypto_fn; pub mod crypto; -pub static CRYPTO_MODULE: Lazy = Lazy::new(|| { - let functions = parse_fn_sigs(&CRYPTO_FNS); - BuiltinModule { functions } -}); - pub fn get_std_fn(submodule: &str, fn_name: &str, span: Span) -> Result { match submodule { - "crypto" => CRYPTO_MODULE - .functions - .get(fn_name) - .cloned() + "crypto" => get_crypto_fn(fn_name) .ok_or_else(|| { Error::new( "type-checker", @@ -81,8 +73,46 @@ pub const QUALIFIED_BUILTINS: &str = "std/builtins"; const ASSERT_FN: &str = "assert(condition: Bool)"; const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)"; -pub const BUILTIN_FNS_DEFS: [(&str, FnHandle); 2] = - [(ASSERT_EQ_FN, assert_eq), (ASSERT_FN, assert)]; +/// List of builtin function signatures. +pub const BUILTIN_SIGS: &[&str] = &[ASSERT_FN, ASSERT_EQ_FN]; + +// Unique set of builtin function names, derived from function signatures. +pub static BUILTIN_FN_NAMES: Lazy> = Lazy::new(|| { + BUILTIN_SIGS + .iter() + .map(|s| { + let ctx = &mut ParserCtx::default(); + let mut tokens = Token::parse(0, s).unwrap(); + let sig = FnSig::parse(ctx, &mut tokens).unwrap(); + sig.name.value + }) + .collect() +}); + +pub fn get_builtin_fn(name: &str) -> Option { + let ctx = &mut ParserCtx::default(); + let mut tokens = Token::parse(0, name).unwrap(); + let sig = FnSig::parse(ctx, &mut tokens).unwrap(); + + let fn_handle = match name { + ASSERT_FN => assert, + ASSERT_EQ_FN => assert_eq, + _ => return None, + }; + + Some(FnInfo { + kind: FnKind::BuiltIn(sig, fn_handle), + span: Span::default(), + }) +} + +/// a function returns builtin functions +pub fn builtin_fns() -> Vec { + BUILTIN_SIGS + .iter() + .map(|sig| get_builtin_fn(sig).unwrap()) + .collect() +} /// Asserts that two vars are equal. fn assert_eq(compiler: &mut CircuitWriter, vars: &[VarInfo], span: Span) -> Result> { diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 90d2ef033..7e46378f9 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -4,13 +4,13 @@ use crate::{ cli::packages::UserRepo, constants::{Field, Span}, error::{Error, ErrorKind, Result}, - imports::{FnKind, BUILTIN_FNS}, + imports::FnKind, name_resolution::NAST, parser::{ types::{FuncOrMethod, FunctionDef, ModulePath, RootKind, Ty, TyKind}, CustomType, Expr, StructDef, }, - stdlib::{CRYPTO_MODULE, QUALIFIED_BUILTINS}, + stdlib::{builtin_fns, crypto::crypto_fns, QUALIFIED_BUILTINS}, }; pub use checker::{FnInfo, StructInfo}; @@ -91,12 +91,7 @@ impl TypeChecker { } pub(crate) fn fn_info(&self, qualified: &FullyQualified) -> Option<&FnInfo> { - if qualified.module == Some(UserRepo::new("std/builtins")) { - // if it's a built-in: get it from a global - BUILTIN_FNS.get(&qualified.name) - } else { - self.functions.get(qualified) - } + self.functions.get(qualified) } pub(crate) fn const_info(&self, qualified: &FullyQualified) -> Option<&ConstInfo> { @@ -141,8 +136,8 @@ impl TypeChecker { // initialize it with the builtins let builtin_module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS)); - for (fn_name, fn_info) in BUILTIN_FNS.iter() { - let qualified = FullyQualified::new(&builtin_module, fn_name); + for fn_info in builtin_fns().iter() { + let qualified = FullyQualified::new(&builtin_module, &fn_info.sig().name.value); if type_checker .functions .insert(qualified, fn_info.clone()) @@ -154,8 +149,8 @@ impl TypeChecker { // initialize it with the standard library let crypto_module = ModulePath::Absolute(UserRepo::new("std/crypto")); - for (fn_name, fn_info) in CRYPTO_MODULE.functions.iter() { - let qualified = FullyQualified::new(&crypto_module, fn_name); + for fn_info in crypto_fns().iter() { + let qualified = FullyQualified::new(&crypto_module, &fn_info.sig().name.value); if type_checker .functions .insert(qualified, fn_info.clone()) From fc65b02433bd3c4a83fd4127b3ee5477e80cfa43 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 12 Apr 2024 09:38:26 +0800 Subject: [PATCH 2/2] remove dead code --- src/stdlib/mod.rs | 42 ----------------------------------------- src/type_checker/mod.rs | 4 ++-- 2 files changed, 2 insertions(+), 44 deletions(-) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index e36476f62..9a46a08f6 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -21,48 +21,6 @@ use self::crypto::get_crypto_fn; pub mod crypto; -pub fn get_std_fn(submodule: &str, fn_name: &str, span: Span) -> Result { - match submodule { - "crypto" => get_crypto_fn(fn_name) - .ok_or_else(|| { - Error::new( - "type-checker", - ErrorKind::UnknownExternalFn(submodule.to_string(), fn_name.to_string()), - span, - ) - }), - _ => Err(Error::new( - "type-checker", - ErrorKind::StdImport(submodule.to_string()), - span, - )), - } -} - -/// Takes a list of function signatures (as strings) and their associated function pointer, -/// returns the same list but with the parsed functions (as [FunctionSig]). -pub fn parse_fn_sigs(fn_sigs: &[(&str, FnHandle)]) -> HashMap { - let mut functions = HashMap::new(); - let ctx = &mut ParserCtx::default(); - - for (sig, fn_ptr) in fn_sigs { - // filename_id 0 is for builtins - let mut tokens = Token::parse(0, sig).unwrap(); - - let sig = FnSig::parse(ctx, &mut tokens).unwrap(); - - functions.insert( - sig.name.value.clone(), - FnInfo { - kind: FnKind::BuiltIn(sig, *fn_ptr), - span: Span::default(), - }, - ); - } - - functions -} - // // Builtins or utils (imported by default) // TODO: give a name that's useful for the user, diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 7e46378f9..04607f15e 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -136,7 +136,7 @@ impl TypeChecker { // initialize it with the builtins let builtin_module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS)); - for fn_info in builtin_fns().iter() { + for fn_info in builtin_fns() { let qualified = FullyQualified::new(&builtin_module, &fn_info.sig().name.value); if type_checker .functions @@ -149,7 +149,7 @@ impl TypeChecker { // initialize it with the standard library let crypto_module = ModulePath::Absolute(UserRepo::new("std/crypto")); - for fn_info in crypto_fns().iter() { + for fn_info in crypto_fns() { let qualified = FullyQualified::new(&crypto_module, &fn_info.sig().name.value); if type_checker .functions