Skip to content

Commit

Permalink
llvm_ir: Unary & Binary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
mrkajetanp committed Aug 6, 2024
1 parent 55f9e19 commit 778fdfd
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 30 deletions.
102 changes: 72 additions & 30 deletions src/llvm_ir.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use crate::ast;

use llvm_sys::prelude::{LLVMBuilderRef, LLVMModuleRef, LLVMValueRef};
use llvm_sys::LLVMValue;
use std::ffi::{CStr, CString};

use llvm_sys::prelude::LLVMContextRef;
use llvm_sys::core::{LLVMAddFunction, LLVMAppendBasicBlock, LLVMBuildRet, LLVMConstInt, LLVMContextCreate, LLVMContextDispose, LLVMCreateBuilderInContext, LLVMDisposeBuilder, LLVMDisposeMessage, LLVMDisposeModule, LLVMFunctionType, LLVMInt32TypeInContext, LLVMModuleCreateWithNameInContext, LLVMPositionBuilderAtEnd, LLVMPrintModuleToString, LLVMVoidTypeInContext};
use llvm_sys::core::{LLVMAddFunction, LLVMAppendBasicBlock, LLVMBuildAdd, LLVMBuildMul, LLVMBuildNeg, LLVMBuildNot, LLVMBuildRet, LLVMBuildSDiv, LLVMBuildSRem, LLVMBuildSub, LLVMConstInt, LLVMContextCreate, LLVMContextDispose, LLVMCreateBuilderInContext, LLVMDisposeBuilder, LLVMDisposeMessage, LLVMDisposeModule, LLVMFunctionType, LLVMInt32TypeInContext, LLVMModuleCreateWithNameInContext, LLVMPositionBuilderAtEnd, LLVMPrintModuleToString, LLVMVoidTypeInContext};

#[allow(dead_code)]
struct LLVMCodeGen {
ctx: LLVMContextRef,
module: LLVMModuleRef,
struct LLIrCtx {
ll: LLVMContextRef,
llmod: LLVMModuleRef,
builder: LLVMBuilderRef,
temp_var_id: u64,
}

impl LLVMCodeGen {
impl LLIrCtx {
pub fn new(name: &str) -> Self {
let name = CString::new(name).unwrap();
unsafe {
Expand All @@ -23,32 +25,39 @@ impl LLVMCodeGen {
);
let builder = LLVMCreateBuilderInContext(ctx);

LLVMCodeGen {
ctx,
module,
LLIrCtx {
ll: ctx,
llmod: module,
builder,
temp_var_id: 0
}
}
}

pub fn temp_var(&mut self) -> CString {
let id = self.temp_var_id;
self.temp_var_id += 1;
CString::new(format!("tmp.{}", id)).unwrap()
}
}

impl Drop for LLVMCodeGen {
impl Drop for LLIrCtx {
fn drop(&mut self) {
unsafe {
LLVMDisposeBuilder(self.builder);
LLVMDisposeModule(self.module);
LLVMContextDispose(self.ctx);
LLVMDisposeModule(self.llmod);
LLVMContextDispose(self.ll);
}
}
}


impl<'a> ast::Program {
pub fn to_llvm(self, name: &str) -> String {
let codegen = LLVMCodeGen::new(name);
let mut codegen = LLIrCtx::new(name);
unsafe {
self.body.to_llvm(&codegen);
let code = LLVMPrintModuleToString(codegen.module);
self.body.to_llvm(&mut codegen);
let code = LLVMPrintModuleToString(codegen.llmod);
let result = CStr::from_ptr(code).to_string_lossy().into_owned();
LLVMDisposeMessage(code as *mut _);
result
Expand All @@ -57,15 +66,15 @@ impl<'a> ast::Program {
}

impl<'a> ast::Function {
unsafe fn to_llvm(self, llvm: &LLVMCodeGen) -> LLVMValueRef {
unsafe fn to_llvm(self, llvm: &mut LLIrCtx) -> LLVMValueRef {
let name = CString::new(self.name).unwrap();
let fn_type = LLVMInt32TypeInContext(llvm.ctx);
let param_types = [LLVMVoidTypeInContext(llvm.ctx)].as_mut_ptr();
let fn_type = LLVMInt32TypeInContext(llvm.ll);
let param_types = [LLVMVoidTypeInContext(llvm.ll)].as_mut_ptr();
let fn_type = LLVMFunctionType(
fn_type, param_types, 0, 0
);
let func = LLVMAddFunction(
llvm.module, name.as_ptr(), fn_type
llvm.llmod, name.as_ptr(), fn_type
);
let block_name = CString::new("entry").unwrap();
let block = LLVMAppendBasicBlock(func, block_name.as_ptr());
Expand All @@ -76,7 +85,7 @@ impl<'a> ast::Function {
}

impl<'a> ast::Statement {
unsafe fn to_llvm(self, llvm: &LLVMCodeGen) {
unsafe fn to_llvm(self, llvm: &mut LLIrCtx) {
match self {
ast::Statement::Return(expr) => {
let value = expr.to_llvm(llvm);
Expand All @@ -87,22 +96,55 @@ impl<'a> ast::Statement {
}

impl<'a> ast::Expression {
unsafe fn to_llvm(self, llvm: &LLVMCodeGen) -> LLVMValueRef {
unsafe fn to_llvm(self, llvm: &mut LLIrCtx) -> LLVMValueRef {
match self {
ast::Expression::Constant(ref val) => {
LLVMConstInt(
LLVMInt32TypeInContext(llvm.ctx), *val as u64, 0
LLVMInt32TypeInContext(llvm.ll), *val as u64, 0
)
},
// ast::Expression::Unary(op, expr) => {
// todo!()
// },
// ast::Expression::Binary(
// op, left, right
// ) => {
// todo!()
// },
_ => todo!()
ast::Expression::Unary(op, expr) => {
let op = op.to_llvm();
let val = expr.to_llvm(llvm);
let name = CString::new("negtmp").unwrap();
op(llvm.builder, val, name.as_ptr())
},
ast::Expression::Binary(
op, left, right
) => {
let op = op.to_llvm();
let left = left.to_llvm(llvm);
let right = right.to_llvm(llvm);
let dst = llvm.temp_var();
op(llvm.builder, left, right, dst.as_ptr())
},
}
}
}

type LLVMUnaryOpFn = unsafe extern "C"
fn(LLVMBuilderRef, LLVMValueRef, *const i8) -> *mut LLVMValue;

impl<'a> ast::UnaryOperator {
unsafe fn to_llvm(self) -> LLVMUnaryOpFn {
match self {
ast::UnaryOperator::Negation => LLVMBuildNeg,
ast::UnaryOperator::Complement => LLVMBuildNot,
}
}
}

type LLVMBinaryOpFn = unsafe extern "C"
fn(LLVMBuilderRef, LLVMValueRef, LLVMValueRef, *const i8) -> *mut LLVMValue;

impl<'a> ast::BinaryOperator {
unsafe fn to_llvm(self) -> LLVMBinaryOpFn {
match self {
ast::BinaryOperator::Add => LLVMBuildAdd,
ast::BinaryOperator::Subtract => LLVMBuildSub,
ast::BinaryOperator::Multiply => LLVMBuildMul,
ast::BinaryOperator::Divide => LLVMBuildSDiv,
ast::BinaryOperator::Remainder => LLVMBuildSRem,
}
}
}
106 changes: 106 additions & 0 deletions tests/llvm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
extern crate c_compiler;

use c_compiler::*;
use serial_test::serial;

static BASIC_SAMPLE: &str = "samples/basic.c";
static UNARY_SAMPLE: &str = "samples/unary.c";
static BINARY_SAMPLE: &str = "samples/binary.c";
static DIV_SAMPLE: &str = "samples/div.c";

// ------------ basic functions ---------
//
#[test]
#[serial]
fn basic_ir() {
let driver = Driver::new(BASIC_SAMPLE);
driver.compile(CompileStage::IR, true);
}

#[test]
#[serial]
fn basic_codegen() {
let driver = Driver::new(BASIC_SAMPLE);
driver.compile(CompileStage::Codegen, true);
}

#[test]
#[serial]
fn basic_full() {
let driver = Driver::new(BASIC_SAMPLE);
driver.compile(CompileStage::Full, true);
driver.clean_binary().unwrap();
}

// ------------ unary operators ---------
//
#[test]
#[serial]
fn unary_ir() {
let driver = Driver::new(UNARY_SAMPLE);
driver.compile(CompileStage::IR, true);
}

#[test]
#[serial]
fn unary_codegen() {
let driver = Driver::new(UNARY_SAMPLE);
driver.compile(CompileStage::Codegen, true);
}

#[test]
#[serial]
fn unary_full() {
let driver = Driver::new(UNARY_SAMPLE);
driver.compile(CompileStage::Full, true);
driver.clean_binary().unwrap();
}

// ------------ binary operators ---------

#[test]
#[serial]
fn binary_ir() {
let driver = Driver::new(BINARY_SAMPLE);
driver.compile(CompileStage::IR, true);
}

#[test]
#[serial]
fn binary_codegen() {
let driver = Driver::new(BINARY_SAMPLE);
driver.compile(CompileStage::Codegen, true);
}

#[test]
#[serial]
fn binary_full() {
let driver = Driver::new(BINARY_SAMPLE);
driver.compile(CompileStage::Full, true);
driver.clean_binary().unwrap();
}

// ------------ division operators ---------

#[test]
#[serial]
fn div_ir() {
let driver = Driver::new(DIV_SAMPLE);
driver.compile(CompileStage::IR, true);
}

#[test]
#[serial]
fn div_codegen() {
let driver = Driver::new(DIV_SAMPLE);
driver.compile(CompileStage::Codegen, true);
}

#[test]
#[serial]
fn div_full() {
let driver = Driver::new(DIV_SAMPLE);
driver.compile(CompileStage::Full, true);
driver.clean_binary().unwrap();
}

0 comments on commit 778fdfd

Please sign in to comment.