From 778fdfdf73cdb0d09f7799a275c33a9ed3a90da2 Mon Sep 17 00:00:00 2001 From: Kajetan Puchalski Date: Tue, 6 Aug 2024 23:50:16 +0100 Subject: [PATCH] llvm_ir: Unary & Binary operators --- src/llvm_ir.rs | 102 +++++++++++++++++++++++++++++++++-------------- tests/llvm.rs | 106 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 30 deletions(-) create mode 100644 tests/llvm.rs diff --git a/src/llvm_ir.rs b/src/llvm_ir.rs index f40c887..77ad986 100644 --- a/src/llvm_ir.rs +++ b/src/llvm_ir.rs @@ -1,19 +1,21 @@ use crate::ast; use llvm_sys::prelude::{LLVMBuilderRef, LLVMModuleRef, LLVMValueRef}; +use llvm_sys::LLVMValue; use std::ffi::{CStr, CString}; use llvm_sys::prelude::LLVMContextRef; -use llvm_sys::core::{LLVMAddFunction, LLVMAppendBasicBlock, LLVMBuildRet, LLVMConstInt, LLVMContextCreate, LLVMContextDispose, LLVMCreateBuilderInContext, LLVMDisposeBuilder, LLVMDisposeMessage, LLVMDisposeModule, LLVMFunctionType, LLVMInt32TypeInContext, LLVMModuleCreateWithNameInContext, LLVMPositionBuilderAtEnd, LLVMPrintModuleToString, LLVMVoidTypeInContext}; +use llvm_sys::core::{LLVMAddFunction, LLVMAppendBasicBlock, LLVMBuildAdd, LLVMBuildMul, LLVMBuildNeg, LLVMBuildNot, LLVMBuildRet, LLVMBuildSDiv, LLVMBuildSRem, LLVMBuildSub, LLVMConstInt, LLVMContextCreate, LLVMContextDispose, LLVMCreateBuilderInContext, LLVMDisposeBuilder, LLVMDisposeMessage, LLVMDisposeModule, LLVMFunctionType, LLVMInt32TypeInContext, LLVMModuleCreateWithNameInContext, LLVMPositionBuilderAtEnd, LLVMPrintModuleToString, LLVMVoidTypeInContext}; #[allow(dead_code)] -struct LLVMCodeGen { - ctx: LLVMContextRef, - module: LLVMModuleRef, +struct LLIrCtx { + ll: LLVMContextRef, + llmod: LLVMModuleRef, builder: LLVMBuilderRef, + temp_var_id: u64, } -impl LLVMCodeGen { +impl LLIrCtx { pub fn new(name: &str) -> Self { let name = CString::new(name).unwrap(); unsafe { @@ -23,21 +25,28 @@ impl LLVMCodeGen { ); let builder = LLVMCreateBuilderInContext(ctx); - LLVMCodeGen { - ctx, - module, + LLIrCtx { + ll: ctx, + llmod: module, builder, + temp_var_id: 0 } } } + + pub fn temp_var(&mut self) -> CString { + let id = self.temp_var_id; + self.temp_var_id += 1; + CString::new(format!("tmp.{}", id)).unwrap() + } } -impl Drop for LLVMCodeGen { +impl Drop for LLIrCtx { fn drop(&mut self) { unsafe { LLVMDisposeBuilder(self.builder); - LLVMDisposeModule(self.module); - LLVMContextDispose(self.ctx); + LLVMDisposeModule(self.llmod); + LLVMContextDispose(self.ll); } } } @@ -45,10 +54,10 @@ impl Drop for LLVMCodeGen { impl<'a> ast::Program { pub fn to_llvm(self, name: &str) -> String { - let codegen = LLVMCodeGen::new(name); + let mut codegen = LLIrCtx::new(name); unsafe { - self.body.to_llvm(&codegen); - let code = LLVMPrintModuleToString(codegen.module); + self.body.to_llvm(&mut codegen); + let code = LLVMPrintModuleToString(codegen.llmod); let result = CStr::from_ptr(code).to_string_lossy().into_owned(); LLVMDisposeMessage(code as *mut _); result @@ -57,15 +66,15 @@ impl<'a> ast::Program { } impl<'a> ast::Function { - unsafe fn to_llvm(self, llvm: &LLVMCodeGen) -> LLVMValueRef { + unsafe fn to_llvm(self, llvm: &mut LLIrCtx) -> LLVMValueRef { let name = CString::new(self.name).unwrap(); - let fn_type = LLVMInt32TypeInContext(llvm.ctx); - let param_types = [LLVMVoidTypeInContext(llvm.ctx)].as_mut_ptr(); + let fn_type = LLVMInt32TypeInContext(llvm.ll); + let param_types = [LLVMVoidTypeInContext(llvm.ll)].as_mut_ptr(); let fn_type = LLVMFunctionType( fn_type, param_types, 0, 0 ); let func = LLVMAddFunction( - llvm.module, name.as_ptr(), fn_type + llvm.llmod, name.as_ptr(), fn_type ); let block_name = CString::new("entry").unwrap(); let block = LLVMAppendBasicBlock(func, block_name.as_ptr()); @@ -76,7 +85,7 @@ impl<'a> ast::Function { } impl<'a> ast::Statement { - unsafe fn to_llvm(self, llvm: &LLVMCodeGen) { + unsafe fn to_llvm(self, llvm: &mut LLIrCtx) { match self { ast::Statement::Return(expr) => { let value = expr.to_llvm(llvm); @@ -87,22 +96,55 @@ impl<'a> ast::Statement { } impl<'a> ast::Expression { - unsafe fn to_llvm(self, llvm: &LLVMCodeGen) -> LLVMValueRef { + unsafe fn to_llvm(self, llvm: &mut LLIrCtx) -> LLVMValueRef { match self { ast::Expression::Constant(ref val) => { LLVMConstInt( - LLVMInt32TypeInContext(llvm.ctx), *val as u64, 0 + LLVMInt32TypeInContext(llvm.ll), *val as u64, 0 ) }, - // ast::Expression::Unary(op, expr) => { - // todo!() - // }, - // ast::Expression::Binary( - // op, left, right - // ) => { - // todo!() - // }, - _ => todo!() + ast::Expression::Unary(op, expr) => { + let op = op.to_llvm(); + let val = expr.to_llvm(llvm); + let name = CString::new("negtmp").unwrap(); + op(llvm.builder, val, name.as_ptr()) + }, + ast::Expression::Binary( + op, left, right + ) => { + let op = op.to_llvm(); + let left = left.to_llvm(llvm); + let right = right.to_llvm(llvm); + let dst = llvm.temp_var(); + op(llvm.builder, left, right, dst.as_ptr()) + }, + } + } +} + +type LLVMUnaryOpFn = unsafe extern "C" + fn(LLVMBuilderRef, LLVMValueRef, *const i8) -> *mut LLVMValue; + +impl<'a> ast::UnaryOperator { + unsafe fn to_llvm(self) -> LLVMUnaryOpFn { + match self { + ast::UnaryOperator::Negation => LLVMBuildNeg, + ast::UnaryOperator::Complement => LLVMBuildNot, + } + } +} + +type LLVMBinaryOpFn = unsafe extern "C" + fn(LLVMBuilderRef, LLVMValueRef, LLVMValueRef, *const i8) -> *mut LLVMValue; + +impl<'a> ast::BinaryOperator { + unsafe fn to_llvm(self) -> LLVMBinaryOpFn { + match self { + ast::BinaryOperator::Add => LLVMBuildAdd, + ast::BinaryOperator::Subtract => LLVMBuildSub, + ast::BinaryOperator::Multiply => LLVMBuildMul, + ast::BinaryOperator::Divide => LLVMBuildSDiv, + ast::BinaryOperator::Remainder => LLVMBuildSRem, } } } diff --git a/tests/llvm.rs b/tests/llvm.rs new file mode 100644 index 0000000..d81925b --- /dev/null +++ b/tests/llvm.rs @@ -0,0 +1,106 @@ +extern crate c_compiler; + +use c_compiler::*; +use serial_test::serial; + +static BASIC_SAMPLE: &str = "samples/basic.c"; +static UNARY_SAMPLE: &str = "samples/unary.c"; +static BINARY_SAMPLE: &str = "samples/binary.c"; +static DIV_SAMPLE: &str = "samples/div.c"; + +// ------------ basic functions --------- +// +#[test] +#[serial] +fn basic_ir() { + let driver = Driver::new(BASIC_SAMPLE); + driver.compile(CompileStage::IR, true); +} + +#[test] +#[serial] +fn basic_codegen() { + let driver = Driver::new(BASIC_SAMPLE); + driver.compile(CompileStage::Codegen, true); +} + +#[test] +#[serial] +fn basic_full() { + let driver = Driver::new(BASIC_SAMPLE); + driver.compile(CompileStage::Full, true); + driver.clean_binary().unwrap(); +} + +// ------------ unary operators --------- +// +#[test] +#[serial] +fn unary_ir() { + let driver = Driver::new(UNARY_SAMPLE); + driver.compile(CompileStage::IR, true); +} + +#[test] +#[serial] +fn unary_codegen() { + let driver = Driver::new(UNARY_SAMPLE); + driver.compile(CompileStage::Codegen, true); +} + +#[test] +#[serial] +fn unary_full() { + let driver = Driver::new(UNARY_SAMPLE); + driver.compile(CompileStage::Full, true); + driver.clean_binary().unwrap(); +} + +// ------------ binary operators --------- + +#[test] +#[serial] +fn binary_ir() { + let driver = Driver::new(BINARY_SAMPLE); + driver.compile(CompileStage::IR, true); +} + +#[test] +#[serial] +fn binary_codegen() { + let driver = Driver::new(BINARY_SAMPLE); + driver.compile(CompileStage::Codegen, true); +} + +#[test] +#[serial] +fn binary_full() { + let driver = Driver::new(BINARY_SAMPLE); + driver.compile(CompileStage::Full, true); + driver.clean_binary().unwrap(); +} + +// ------------ division operators --------- + +#[test] +#[serial] +fn div_ir() { + let driver = Driver::new(DIV_SAMPLE); + driver.compile(CompileStage::IR, true); +} + +#[test] +#[serial] +fn div_codegen() { + let driver = Driver::new(DIV_SAMPLE); + driver.compile(CompileStage::Codegen, true); +} + +#[test] +#[serial] +fn div_full() { + let driver = Driver::new(DIV_SAMPLE); + driver.compile(CompileStage::Full, true); + driver.clean_binary().unwrap(); +} +