From a829f65007e346fefb3becd53ab1d1c733e683c6 Mon Sep 17 00:00:00 2001 From: Kajetan Puchalski Date: Tue, 27 Aug 2024 00:46:56 +0100 Subject: [PATCH] ast: Parsing function declarations, definitions & calls --- samples/functions.c | 12 ++ src/ast.rs | 333 +++++++++++++++++++++++++++++++------------- src/grammar.txt | 21 ++- src/lexer.rs | 4 + 4 files changed, 267 insertions(+), 103 deletions(-) create mode 100644 samples/functions.c diff --git a/samples/functions.c b/samples/functions.c new file mode 100644 index 0000000..62679fa --- /dev/null +++ b/samples/functions.c @@ -0,0 +1,12 @@ +int sum(int first, int second); + +int sum(int a, int b) { + return a + b; +} + +int main(void) { + int one = 5; + int two = 6; + int result = sum(one, two); + return result; +} diff --git a/src/ast.rs b/src/ast.rs index c7d739d..e449a67 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -5,112 +5,140 @@ use std::mem::discriminant; use std::{error::Error, fmt}; use strum_macros::{Display, EnumIs}; -#[inline(always)] -fn log_trace(msg: &str, tokens: &mut VecDeque) { - log::trace!( - "{} {:?}", - msg, - tokens.iter().take(4).collect::>() - ); -} - -#[derive(Debug)] -pub enum ParserError { - UnexpectedToken, - NoTokens, - MalformedExpression, - IdentifierParsingError, -} - -impl fmt::Display for ParserError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?}", self) - } -} - -impl Error for ParserError {} - -#[inline(always)] -fn expect_token_silent( - expected: TokenKind, - tokens: &mut VecDeque, -) -> Result { - let exp = discriminant(&expected); - let actual = discriminant(&tokens[0]); - - if actual != exp { - Err(ParserError::UnexpectedToken) - } else { - Ok(tokens.pop_front().unwrap()) - } -} - -#[inline(always)] -fn expect_token( - expected: TokenKind, - tokens: &mut VecDeque, -) -> Result { - let result = expect_token_silent(expected.clone(), tokens); - if let Err(_) = result { - log::error!( - "Syntax Error: Expected {:?}, got {:?}", - &expected, - &tokens[0] - ); - } - result -} - -#[derive(Debug, PartialEq, DisplayTree)] +#[derive(Debug, PartialEq)] #[allow(dead_code)] pub struct Program { - #[tree] - pub body: Function, + pub body: Vec, } impl Program { pub fn parse(tokens: Vec) -> Program { let mut tokens = VecDeque::from(tokens); + let mut body = vec![]; - // TODO: better error handling here - let program = Program { - body: Function::parse(&mut tokens).unwrap(), - }; - - if !tokens.is_empty() { - panic!("Syntax Error: Unexpected token {:?}", tokens[0]); + while !tokens.is_empty() { + match FunctionDeclaration::parse(&mut tokens) { + Ok(func) => body.push(func), + Err(err) => { + log::error!("Error reason: {}", err); + panic!("Could not parse AST"); + } + } } - program + Program { body } } } -#[derive(Debug, PartialEq, Clone, DisplayTree)] +impl DisplayTree for Program { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>, style: display_tree::Style) -> std::fmt::Result { + for function in &self.body { + writeln!( + f, + "{}{} {}", + style.char_set.connector, + std::iter::repeat(style.char_set.horizontal) + .take(style.indentation as usize) + .collect::(), + display_tree::format_tree!(*function) + )?; + } + Ok(()) + } +} + +#[derive(Debug, PartialEq, Clone)] #[allow(dead_code)] -pub struct Function { +pub struct FunctionDeclaration { pub name: Identifier, + pub params: Vec, pub return_type: String, - #[tree] - pub body: Block, + pub body: Option, } -impl Function { - fn parse(tokens: &mut VecDeque) -> Result { +impl FunctionDeclaration { + fn parse(tokens: &mut VecDeque) -> Result { + log_trace("parsing function from", tokens); let return_type = expect_token(TokenKind::Int, tokens)?; let name = Identifier::parse(expect_token(TokenKind::Identifier("".to_owned()), tokens)?)?; expect_token(TokenKind::ParenOpen, tokens)?; - expect_token(TokenKind::Void, tokens)?; + + let mut params = vec![]; + + // Parse parameters if any are present + while !tokens.front().unwrap().is_paren_close() && !tokens.front().unwrap().is_void() { + if tokens.front().unwrap().is_comma() { + expect_token(TokenKind::Comma, tokens)?; + } + let mut param_result = || -> Result { + expect_token(TokenKind::Int, tokens)?; + Identifier::parse(expect_token(TokenKind::Identifier("".to_owned()), tokens)?) + }; + + if let Ok(ident) = param_result() { + params.push(ident); + } else { + return Err(ParserError::TrailingComma); + } + } + + if params.len() == 0 { + expect_token(TokenKind::Void, tokens)?; + } expect_token(TokenKind::ParenClose, tokens)?; - let body = Block::parse(tokens)?; + let body = if tokens.front().unwrap().is_brace_open() { + Some(Block::parse(tokens)?) + } else { + expect_token(TokenKind::Semicolon, tokens)?; + None + }; - Ok(Function { + let func = FunctionDeclaration { name, + params, return_type: return_type.to_string(), body, - }) + }; + + log::trace!("--- Parsed function declaration: {}", func); + Ok(func) + } +} + +impl fmt::Display for FunctionDeclaration { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let params = self + .params + .iter() + .map(|p| p.to_string()) + .collect::>() + .join(","); + writeln!(f, "{} {}({})", self.return_type, self.name, params)?; + if let Some(body) = &self.body { + writeln!(f, "\t{}", body)?; + } + Ok(()) + } +} + +impl DisplayTree for FunctionDeclaration { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>, style: display_tree::Style) -> std::fmt::Result { + writeln!(f, "{} {} ({:?})", self.return_type, self.name, self.params)?; + if let Some(body) = &self.body { + writeln!( + f, + "{}{} {}", + style.char_set.connector, + std::iter::repeat(style.char_set.horizontal) + .take(style.indentation as usize) + .collect::(), + display_tree::format_tree!(*body) + )?; + } + Ok(()) } } @@ -175,7 +203,6 @@ impl BlockItem { let token = tokens.front().unwrap().to_owned(); Ok(if token.is_int() { let decl = Declaration::parse(tokens)?; - expect_token(TokenKind::Semicolon, tokens)?; BlockItem::Decl(decl) } else { BlockItem::Stmt(Statement::parse(tokens)?) @@ -192,14 +219,39 @@ impl fmt::Display for BlockItem { } } +#[derive(Debug, PartialEq, Clone, DisplayTree)] +pub enum Declaration { + FunDecl(#[tree] FunctionDeclaration), + VarDecl(#[tree] VariableDeclaration), +} + +impl Declaration { + fn parse(tokens: &mut VecDeque) -> Result { + // If the 3rd token is a '(', we're looking at a function declaration + if tokens.get(2).unwrap().is_paren_open() { + Ok(Self::FunDecl(FunctionDeclaration::parse(tokens)?)) + } else { + Ok(Self::VarDecl(VariableDeclaration::parse(tokens)?)) + } + } +} + +impl fmt::Display for Declaration { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Declaration::FunDecl(decl) => write!(f, "{:?}", decl), + Declaration::VarDecl(decl) => write!(f, "{}", decl), + } + } +} + #[derive(Debug, PartialEq, Clone)] -#[allow(dead_code)] -pub struct Declaration { +pub struct VariableDeclaration { pub name: Identifier, pub init: Option, } -impl Declaration { +impl VariableDeclaration { fn parse(tokens: &mut VecDeque) -> Result { log_trace("Parsing declaration from", tokens); // Silent expect here because we can use this failing to check @@ -215,6 +267,7 @@ impl Declaration { expect_token(TokenKind::Assignment, tokens)?; Some(Expression::parse(tokens, 0)?) }; + expect_token(TokenKind::Semicolon, tokens)?; let result = Self { name: ident, init }; log::trace!("-- Parsed declaration: {}", result); @@ -222,7 +275,7 @@ impl Declaration { } } -impl fmt::Display for Declaration { +impl fmt::Display for VariableDeclaration { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "int {}", self.name)?; if let Some(init) = &self.init { @@ -234,7 +287,7 @@ impl fmt::Display for Declaration { } } -impl DisplayTree for Declaration { +impl DisplayTree for VariableDeclaration { fn fmt(&self, f: &mut std::fmt::Formatter<'_>, style: display_tree::Style) -> std::fmt::Result { writeln!(f, "{}", self.name)?; if let Some(init) = &self.init { @@ -514,7 +567,6 @@ impl Statement { expect_token(TokenKind::For, tokens)?; expect_token(TokenKind::ParenOpen, tokens)?; let init = ForInit::parse(tokens)?; - expect_token(TokenKind::Semicolon, tokens)?; let cond = Expression::parse_optional(tokens)?; expect_token(TokenKind::Semicolon, tokens)?; let post = Expression::parse_optional(tokens)?; @@ -547,20 +599,23 @@ impl Statement { #[derive(Debug, PartialEq, Clone, DisplayTree)] #[allow(dead_code)] pub enum ForInit { - InitDecl(#[tree] Declaration), + InitDecl(#[tree] VariableDeclaration), InitExp(#[tree] Expression), InitNull, } impl ForInit { fn parse(tokens: &mut VecDeque) -> Result { - Ok(if let Ok(decl) = Declaration::parse(tokens) { + let result = if let Ok(decl) = VariableDeclaration::parse(tokens) { Self::InitDecl(decl) } else if let Ok(exp) = Expression::parse(tokens, 0) { + expect_token(TokenKind::Semicolon, tokens)?; Self::InitExp(exp) } else { + expect_token(TokenKind::Semicolon, tokens)?; Self::InitNull - }) + }; + Ok(result) } } @@ -591,6 +646,7 @@ pub enum Expression { #[tree] Box, #[tree] Box, ), + FunctionCall(Identifier, #[ignore_field] Vec), } impl Expression { @@ -642,7 +698,29 @@ impl Expression { let token = tokens.front().unwrap().to_owned(); if token.is_identifier() { - return Ok(Self::Var(Identifier::parse(tokens.pop_front().unwrap())?)); + // If we have a '(' after an identifier, it's a function call + return Ok(if tokens.get(1).unwrap().is_paren_open() { + let name = Identifier::parse(tokens.pop_front().unwrap())?; + expect_token(TokenKind::ParenOpen, tokens)?; + let mut args = vec![]; + + // Parse arguments if any are present + while !tokens.front().unwrap().is_paren_close() { + if tokens.front().unwrap().is_comma() { + expect_token(TokenKind::Comma, tokens)?; + } + if let Ok(exp) = Expression::parse(tokens, 0) { + args.push(exp); + } else { + return Err(ParserError::TrailingComma); + } + } + expect_token(TokenKind::ParenClose, tokens)?; + + Self::FunctionCall(name, args) + } else { + Self::Var(Identifier::parse(tokens.pop_front().unwrap())?) + }); } if token.is_constant() { @@ -687,6 +765,7 @@ impl fmt::Display for Expression { Expression::Conditional(cond, a, b) => write!(f, "{} ? {} : {}", cond, a, b), Expression::Var(ident) => write!(f, "Var({})", ident.name), Expression::Constant(val) => write!(f, "Constant({})", val), + Expression::FunctionCall(name, args) => write!(f, "{}({:?})", name, args), } } } @@ -821,6 +900,63 @@ fn option_ident_to_string(ident: &Option) -> String { } } +#[inline(always)] +fn log_trace(msg: &str, tokens: &mut VecDeque) { + log::trace!( + "{} {:?}", + msg, + tokens.iter().take(4).collect::>() + ); +} + +#[derive(Debug)] +pub enum ParserError { + UnexpectedToken, + NoTokens, + MalformedExpression, + IdentifierParsingError, + TrailingComma, +} + +impl fmt::Display for ParserError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl Error for ParserError {} + +#[inline(always)] +fn expect_token_silent( + expected: TokenKind, + tokens: &mut VecDeque, +) -> Result { + let exp = discriminant(&expected); + let actual = discriminant(&tokens[0]); + + if actual != exp { + Err(ParserError::UnexpectedToken) + } else { + Ok(tokens.pop_front().unwrap()) + } +} + +#[inline(always)] +fn expect_token( + expected: TokenKind, + tokens: &mut VecDeque, +) -> Result { + let result = expect_token_silent(expected.clone(), tokens); + if let Err(_) = result { + log::error!( + "Syntax Error: Expected {:?}, got {:?}", + &expected, + &tokens[0] + ); + } + result +} + #[cfg(test)] mod tests { use super::*; @@ -840,16 +976,17 @@ mod tests { TokenKind::BraceClose, ]; - let function_expected = Function { + let function_expected = FunctionDeclaration { name: Identifier::new("main"), + params: vec![], return_type: "Int".to_owned(), - body: Block { + body: Some(Block { body: vec![BlockItem::Stmt(Statement::Return(Expression::Constant(7)))], - }, + }), }; let program_expected = Program { - body: function_expected, + body: vec![function_expected], }; assert_eq!(Program::parse(tokens), program_expected); @@ -870,15 +1007,19 @@ mod tests { TokenKind::BraceClose, ]); - let function_expected = Function { + let function_expected = FunctionDeclaration { name: Identifier::new("main"), + params: vec![], return_type: "Int".to_owned(), - body: Block { + body: Some(Block { body: vec![BlockItem::Stmt(Statement::Return(Expression::Constant(6)))], - }, + }), }; - assert_eq!(Function::parse(&mut tokens).unwrap(), function_expected); + assert_eq!( + FunctionDeclaration::parse(&mut tokens).unwrap(), + function_expected + ); assert!(tokens.is_empty()); } diff --git a/src/grammar.txt b/src/grammar.txt index 724855d..c93c5ce 100644 --- a/src/grammar.txt +++ b/src/grammar.txt @@ -2,11 +2,12 @@ AST Definition ========================== -program = Program(function_definition) -function_definition = Function(identifier name, block body) +program = Program(function_declaration*) +function_declaration = (identifier name, identifier* params, block? body) +variable_declaration = (identifier name, exp? init) +declaration = FunDecl(function_declaration) | VarDecl(variable_declaration) block = Block(block_item*) block_item = S(statement) | D(declaration) -declaration = Declaration(identifier name, exp? init) statement = Return(exp) | Expression(exp) | If(exp condition, statement then, statement? else) @@ -17,13 +18,14 @@ statement = Return(exp) | DoWhile(statement body, exp condition, identifier label) | For(for_init init, exp? condition, exp? post, statement body, identifier label) | Null -for_init = InitDecl(declaration) | InitExp(exp?) +for_init = InitDecl(variable_declaration) | InitExp(exp?) exp = Constant(int) | Var(identifier) | Unary(unary_operator, exp) | Binary(binary_operator, exp, exp) | Assignment(exp, exp) | Conditional(exp condition, exp, exp) + | FunctionCall(identifier, exp* args) unary_operator = Complement | Negation | Not binary_operator = Add | Subtract | Multiply | Divide | Remainder | And | Or | Equal | NotEqual | LessThan | LessEqualThan | GreaterThan | GreaterEqualThan @@ -32,12 +34,15 @@ binary_operator = Add | Subtract | Multiply | Divide | Remainder | And | Or Grammar ========================== - ::= + ::= {} + ::= | + ::= "int" ["=" ] ";" + ::= "int" "(" ")" ( | ";") + ::= "void" | "int" {"," "int" } ::= "int" "(" "void" ")" ::= "{" {} "}" ::= | - ::= "int" ["=" ] ";" - ::= | [] ";" + ::= | [] ";" ::= "return" ";" | ";" | "if" "(" ")" ["else" ] @@ -50,6 +55,8 @@ Grammar | ";" ::= | | "?" ":" ::= | | | "(" ")" + | "(" [] ")" + ::= {"," } ::= "-" | "~" | "!" ::= "-" | "+" | "*" | "/" | "%" | "&&" | "||" | "==" | "!=" | "<" | "<=" | ">" | ">=" diff --git a/src/lexer.rs b/src/lexer.rs index 1379c08..266fae8 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -40,6 +40,7 @@ static IF_RE: &str = r"if\b"; static ELSE_RE: &str = r"else\b"; static QUESTION_RE: &str = r"\?"; static COLON_RE: &str = r":"; +static COMMA_RE: &str = r"\,"; // NOTE: The tokenizer will try tokens in-order based on this list // It *must* be ordered longest-match first @@ -79,6 +80,7 @@ pub enum TokenKind { Else, Question, Colon, + Comma, Identifier(String), Constant(i64), } @@ -136,6 +138,7 @@ impl TokenKind { input if TokenKind::is_full_match(input, ELSE_RE) => Some(Self::Else), input if TokenKind::is_full_match(input, QUESTION_RE) => Some(Self::Question), input if TokenKind::is_full_match(input, COLON_RE) => Some(Self::Colon), + input if TokenKind::is_full_match(input, COMMA_RE) => Some(Self::Comma), input if TokenKind::is_full_match(input, CONSTANT_RE) => { Some(Self::Constant(input.parse::().unwrap())) } @@ -184,6 +187,7 @@ impl TokenKind { Self::Else => Regex::new(ELSE_RE), Self::Question => Regex::new(QUESTION_RE), Self::Colon => Regex::new(COLON_RE), + Self::Comma => Regex::new(COMMA_RE), } .unwrap() }