diff --git a/scout-interpreter/src/lib.rs b/scout-interpreter/src/lib.rs index ecb3766..814a42d 100644 --- a/scout-interpreter/src/lib.rs +++ b/scout-interpreter/src/lib.rs @@ -224,7 +224,7 @@ fn eval_statement<'a>( Ok(Arc::new(Object::Null)) } StmtKind::Func(def) => { - let lit = Object::Fn(def.args.clone(), def.body.clone()); + let lit = Object::Fn(def.params.clone(), def.body.clone()); env.lock().await.set(&def.ident, Arc::new(lit)).await; Ok(Arc::new(Object::Null)) } @@ -315,20 +315,31 @@ fn apply_call<'a>( obj_params.insert(0, obj); } + // Set var before match to avoid deadlock on env let env_res = env.lock().await.get(ident).await; match env_res { Some(obj) => match &*obj { - Object::Fn(idents, block) => { - if idents.len() != obj_params.len() { - return Err(EvalError::InvalidUsage("Non-matching arg counts".into())); - } - + Object::Fn(fn_params, block) => { let mut scope = Env::default(); scope.add_outer(env.clone()).await; - for i in 0..idents.len() { - let id = &idents[i]; - let obj = &obj_params[i]; - scope.set(id, obj.clone()).await; + for (i, fn_param) in fn_params.iter().enumerate() { + let id = &fn_param.ident; + match obj_params.get(i) { + Some(provided) => { + scope.set(id, provided.clone()).await; + } + None => match &fn_param.default { + Some(def) => { + let obj_def = + eval_expression(def, crawler, env.clone(), results.clone()) + .await?; + scope.set(id, obj_def).await; + } + None => { + return Err(EvalError::InvalidFnParams); + } + }, + } } let ev = diff --git a/scout-interpreter/src/object.rs b/scout-interpreter/src/object.rs index 5fc770a..acd6c3c 100644 --- a/scout-interpreter/src/object.rs +++ b/scout-interpreter/src/object.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, fmt::Display, sync::Arc}; -use scout_parser::ast::{Block, Identifier}; +use scout_parser::ast::{Block, FnParam, Identifier}; use serde_json::{json, Value}; #[derive(Debug)] @@ -12,7 +12,7 @@ pub enum Object { List(Vec>), Boolean(bool), Number(f64), - Fn(Vec, Block), + Fn(Vec, Block), Return(Arc), } diff --git a/scout-parser/src/ast.rs b/scout-parser/src/ast.rs index cfdeac3..70563d7 100644 --- a/scout-parser/src/ast.rs +++ b/scout-parser/src/ast.rs @@ -76,13 +76,29 @@ pub struct ElseLiteral { #[derive(Debug, PartialEq, Clone)] pub struct FuncDef { pub ident: Identifier, - pub args: Vec, + pub params: Vec, pub body: Block, } impl FuncDef { - pub fn new(ident: Identifier, args: Vec, body: Block) -> Self { - Self { ident, args, body } + pub fn new(ident: Identifier, params: Vec, body: Block) -> Self { + Self { + ident, + params, + body, + } + } +} + +#[derive(Debug, PartialEq, Clone)] +pub struct FnParam { + pub ident: Identifier, + pub default: Option, +} + +impl FnParam { + pub fn new(ident: Identifier, default: Option) -> Self { + Self { ident, default } } } diff --git a/scout-parser/src/lib.rs b/scout-parser/src/lib.rs index 0e38b1b..ffbc17a 100644 --- a/scout-parser/src/lib.rs +++ b/scout-parser/src/lib.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use ast::{ - ExprKind, ForLoop, FuncDef, HashLiteral, Identifier, IfElseLiteral, IfLiteral, Program, - StmtKind, + ExprKind, FnParam, ForLoop, FuncDef, HashLiteral, Identifier, IfElseLiteral, IfLiteral, + Program, StmtKind, }; use scout_lexer::{Lexer, Token, TokenKind}; @@ -17,6 +17,7 @@ pub enum ParseError { UnexpectedToken(TokenKind, TokenKind), InvalidToken(TokenKind), InvalidNumber, + DefaultFnParamBefore, } pub struct Parser { @@ -64,12 +65,26 @@ impl Parser { self.expect_peek(TokenKind::LParen)?; let mut args = Vec::new(); + let mut has_defaults = false; while self.peek.kind == TokenKind::Comma || self.peek.kind != TokenKind::RParen { self.next_token(); match self.curr.kind { TokenKind::Comma => {} TokenKind::Ident => { - args.push(Identifier::new(self.curr.literal.clone())); + let ident = Identifier::new(self.curr.literal.clone()); + let mut default = None; + if self.peek.kind == TokenKind::Assign { + self.next_token(); + self.next_token(); + default = Some(self.parse_expr()?); + has_defaults = true; + } else if has_defaults { + // Dont allow non-default params after default params. + // If we dont disallow this then the interpreter will have a + // hard time + return Err(ParseError::DefaultFnParamBefore); + } + args.push(FnParam::new(ident, default)); } _ => { return Err(ParseError::InvalidToken(self.curr.kind)); @@ -470,6 +485,7 @@ mod tests { ExprKind::Str("a".into()) ) )] + #[test_case(r#"null"#, StmtKind::Expr(ExprKind::Null))] #[test_case( r#"for node in $$"a" do scrape {} end"#, StmtKind::ForLoop( @@ -514,8 +530,20 @@ mod tests { FuncDef::new( Identifier::new("f".into()), vec![ - Identifier::new("a".into()), - Identifier::new("b".into()) + FnParam::new(Identifier::new("a".into()), None), + FnParam::new(Identifier::new("b".into()), None) + ], + Block::default() + ) + ) + )] + #[test_case( + r#"def f(a = null) do end"#, + StmtKind::Func( + FuncDef::new( + Identifier::new("f".into()), + vec![ + FnParam::new(Identifier::new("a".into()), Some(ExprKind::Null)) ], Block::default() )