From c744c7086a56ac5ec5cf45abc13466143abbd162 Mon Sep 17 00:00:00 2001 From: Kajetan Puchalski Date: Mon, 2 Sep 2024 23:54:04 +0100 Subject: [PATCH] codegen: Panic-free codegen error handling --- src/codegen.rs | 81 +++++++++++++++++++++++++++----------------------- src/lib.rs | 26 ++++++++++------ tests/basic.rs | 4 +-- 3 files changed, 63 insertions(+), 48 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 70525f9..4763a3e 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -9,6 +9,10 @@ use std::fmt; pub enum CodegenError { #[error("Pseudo-operand in emit stage: {0}")] PseudoOperandInEmit(String), + #[error("Unexpected operator")] + UnexpectedOperator, + #[error("Unexpected IR instruction")] + UnexpectedIrInstruction, } type CodegenResult = Result; @@ -20,14 +24,14 @@ pub struct Program { } impl Program { - pub fn codegen(program: ir::Program) -> Program { - Program { + pub fn codegen(program: ir::Program) -> CodegenResult { + Ok(Program { body: program .body .into_iter() .map(|f| Function::codegen(f)) - .collect(), - } + .collect::>>()?, + }) } pub fn emit(&self) -> CodegenResult { @@ -65,7 +69,7 @@ pub struct Function { } impl Function { - pub fn codegen(function: ir::Function) -> Self { + pub fn codegen(function: ir::Function) -> CodegenResult { let params_len = function.params.len(); let prologue_stack_size = if params_len > 6 { (params_len - 6 + 1) * 8 @@ -105,7 +109,7 @@ impl Function { function .instructions .into_iter() - .flat_map(|instr| Instruction::codegen(instr)), + .flat_map(|instr| Instruction::codegen(instr).unwrap()), ) .collect(); @@ -118,13 +122,15 @@ impl Function { .join("\n") ); - Self { + let result = Self { name: Identifier::codegen(function.name), instructions, stack_pos: -(prologue_stack_size as i64), } .replace_pseudo() - .fixup() + .fixup(); + + Ok(result) } fn replace_pseudo(mut self) -> Self { @@ -200,7 +206,7 @@ pub enum Instruction { } impl Instruction { - pub fn codegen(instruction: ir::Instruction) -> Vec { + pub fn codegen(instruction: ir::Instruction) -> CodegenResult> { let mut instructions = vec![]; match instruction { @@ -223,14 +229,14 @@ impl Instruction { let src = Operand::from_val(src); let dst = Operand::from_val(dst); instructions.push(Self::Mov(dst.clone(), src)); - instructions.push(Self::Unary(UnaryOperator::codegen(op), dst.clone())); + instructions.push(Self::Unary(UnaryOperator::codegen(op)?, dst.clone())); } ir::Instruction::Binary(op, src1, src2, dst) if op.is_relational() => { let src1 = Operand::from_val(src1); let src2 = Operand::from_val(src2); let dst = Operand::from_val(dst); - let cc = CondCode::from_op(op); + let cc = CondCode::from_op(op)?; instructions.push(Self::Cmp(src1, src2)); instructions.push(Self::Mov(dst.clone(), Operand::Immediate(0))); instructions.push(Self::SetCC(cc, dst)); @@ -241,7 +247,7 @@ impl Instruction { let src2 = Operand::from_val(src2); let dst = Operand::from_val(dst); instructions.push(Self::Mov(dst.clone(), src1)); - instructions.push(Self::Binary(BinaryOperator::codegen(op), dst, src2)); + instructions.push(Self::Binary(BinaryOperator::codegen(op)?, dst, src2)); } ir::Instruction::Binary(op, src1, src2, dst) if op.is_divide() || op.is_remainder() => { @@ -338,10 +344,10 @@ impl Instruction { let return_reg = Operand::Reg(Register::AX); instructions.push(Instruction::Mov(asm_dst, return_reg)); } - _ => panic!("Unexpected IR instruction in codegen"), + _ => return Err(CodegenError::UnexpectedIrInstruction), } - instructions + Ok(instructions) } pub fn replace_pseudo( @@ -488,12 +494,12 @@ pub enum BinaryOperator { } impl BinaryOperator { - pub fn codegen(operator: ir::BinaryOperator) -> Self { + pub fn codegen(operator: ir::BinaryOperator) -> CodegenResult { match operator { - ir::BinaryOperator::Add => Self::Add, - ir::BinaryOperator::Subtract => Self::Sub, - ir::BinaryOperator::Multiply => Self::Mult, - _ => panic!("Unsupported binary operator"), + ir::BinaryOperator::Add => Ok(Self::Add), + ir::BinaryOperator::Subtract => Ok(Self::Sub), + ir::BinaryOperator::Multiply => Ok(Self::Mult), + _ => Err(CodegenError::UnexpectedOperator), } } @@ -521,11 +527,11 @@ pub enum UnaryOperator { } impl UnaryOperator { - pub fn codegen(operator: ir::UnaryOperator) -> Self { + pub fn codegen(operator: ir::UnaryOperator) -> CodegenResult { match operator { - ir::UnaryOperator::Complement => Self::Not, - ir::UnaryOperator::Negation => Self::Neg, - _ => panic!("Codegen: Unexpected unary operator {:?}", operator), + ir::UnaryOperator::Complement => Ok(Self::Not), + ir::UnaryOperator::Negation => Ok(Self::Neg), + _ => Err(CodegenError::UnexpectedOperator), } } @@ -701,15 +707,15 @@ pub enum CondCode { } impl CondCode { - pub fn from_op(op: ir::BinaryOperator) -> Self { + pub fn from_op(op: ir::BinaryOperator) -> CodegenResult { match op { - ir::BinaryOperator::Equal => Self::E, - ir::BinaryOperator::NotEqual => Self::NE, - ir::BinaryOperator::GreaterThan => Self::G, - ir::BinaryOperator::GreaterEqualThan => Self::GE, - ir::BinaryOperator::LessThan => Self::L, - ir::BinaryOperator::LessEqualThan => Self::LE, - _ => panic!("Codegen: Unexpected Binary Operator in CondCode"), + ir::BinaryOperator::Equal => Ok(Self::E), + ir::BinaryOperator::NotEqual => Ok(Self::NE), + ir::BinaryOperator::GreaterThan => Ok(Self::G), + ir::BinaryOperator::GreaterEqualThan => Ok(Self::GE), + ir::BinaryOperator::LessThan => Ok(Self::L), + ir::BinaryOperator::LessEqualThan => Ok(Self::LE), + _ => Err(CodegenError::UnexpectedOperator), } } @@ -876,9 +882,9 @@ mod tests { }; let expected = Program { - body: vec![Function::codegen(ir_program.body[0].clone())], + body: vec![Function::codegen(ir_program.body[0].clone()).unwrap()], }; - let actual = Program::codegen(ir_program); + let actual = Program::codegen(ir_program).unwrap(); assert_eq!(actual, expected); } @@ -912,7 +918,7 @@ mod tests { stack_pos: -16, }; - assert_eq!(actual, expected); + assert_eq!(actual.unwrap(), expected); } #[test] @@ -925,6 +931,7 @@ mod tests { ir::Val::Constant(5), ir::Val::Var(ir::Identifier::new("x")), )) + .unwrap() .into_iter() .map(|instr| instr.replace_pseudo(&mut stack_pos, &mut stack_addrs)) .collect(); @@ -942,18 +949,18 @@ mod tests { Instruction::Mov(Operand::Reg(Register::AX), Operand::Immediate(5)), Instruction::Ret, ]; - assert_eq!(actual, expected); + assert_eq!(actual.unwrap(), expected); } #[test] fn unary() { assert_eq!( UnaryOperator::Neg, - UnaryOperator::codegen(ir::UnaryOperator::Negation) + UnaryOperator::codegen(ir::UnaryOperator::Negation).unwrap() ); assert_eq!( UnaryOperator::Not, - UnaryOperator::codegen(ir::UnaryOperator::Complement) + UnaryOperator::codegen(ir::UnaryOperator::Complement).unwrap() ); } diff --git a/src/lib.rs b/src/lib.rs index 2f76818..f72889b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,12 +32,16 @@ pub enum ErrorKind { SemanticError, #[error("Type Checking Failed")] TypeCheckError, + #[error("Codegen Failed")] + CodegenError, #[error("Asm Emission Failed")] AsmEmitError, #[error("IO Error")] IOError, } +type CompileResult = Result; + #[derive(PartialEq, EnumIs, Clone, Copy)] pub enum CompileStage { Lex, @@ -98,10 +102,10 @@ impl Driver { if _llvm { asm_path = self.llvm_asm_path(ast, stage); } else { - asm_path = self.asm_path(ast, stage); + asm_path = self.asm_path(ast, stage)?; } } else { - asm_path = self.asm_path(ast, stage) + asm_path = self.asm_path(ast, stage)?; } } @@ -113,22 +117,26 @@ impl Driver { Ok(()) } - pub fn asm_path(&self, ast: ast::Program, stage: CompileStage) -> Option { + pub fn asm_path( + &self, + ast: ast::Program, + stage: CompileStage, + ) -> CompileResult> { let ir = self.generate_ir(ast); log::debug!("Generated IR:\n{}\n", &ir); if stage.is_ir() { - return None; + return Ok(None); } - let code = self.codegen(ir); + let code = self.codegen(ir)?; log::trace!("Codegen:\n{}\n", &code); if stage.is_codegen() { - return None; + return Ok(None); } - Some(self.emit(code).unwrap()) + Ok(Some(self.emit(code).unwrap())) } pub fn preprocess(&self) -> String { @@ -155,8 +163,8 @@ impl Driver { ir::Program::generate(ast, &mut ir_ctx) } - fn codegen(&self, ir: ir::Program) -> codegen::Program { - codegen::Program::codegen(ir) + fn codegen(&self, ir: ir::Program) -> CompileResult { + codegen::Program::codegen(ir).map_err(|_| ErrorKind::CodegenError) } fn emit(&self, code: codegen::Program) -> Result { diff --git a/tests/basic.rs b/tests/basic.rs index 8a7d21d..2f717b6 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -29,7 +29,7 @@ fn test_basic_lex() { ]; let tokens = driver.lex(source); - assert_eq!(tokens, tokens_expected); + assert_eq!(tokens.unwrap(), tokens_expected); } #[test] @@ -86,7 +86,7 @@ fn test_unary_lex() { ]; let tokens = driver.lex(source); - assert_eq!(tokens, tokens_expected); + assert_eq!(tokens.unwrap(), tokens_expected); } #[test]