Skip to content

Commit

Permalink
Unary and binary ops working :D
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanBrouwer committed Nov 14, 2023
1 parent 4dbd707 commit 00c89f0
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 128 deletions.
33 changes: 29 additions & 4 deletions compiler/src/passes/validate/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,41 @@ pub enum TypeError {
},
#[error("The program doesn't have a main function.")]
NoMain,

#[error("Types did not match.")]
MismatchedFnReturn {
expect: String,
got: String,

#[label = "Expected this function to be of type: `{expect}`"]
//TODO would like this span to be return type if present
#[label = "Expected this function to return: `{expect}`"]
span_expected: (usize, usize),

#[label = "but got this type: `{got}`"]
#[label = "But got this type: `{got}`"]
span_got: (usize, usize),
},
#[error("Types did not match.")]
OperandExpect {
expect: String,
got: String,
op: String,

//TODO would like this span to be operator
#[label = "Arguments of {op} are of type: `{expect}`"]
span_op: (usize, usize),
#[label = "But got this type: `{got}`"]
span_arg: (usize, usize),
},
#[error("Types did not match.")]
OperandEqual {
lhs: String,
rhs: String,
op: String,

//TODO would like this span to be operator
#[label = "Arguments of {op} should be of equal types."]
span_op: (usize, usize),
#[label = "Type: `{lhs}`"]
span_lhs: (usize, usize),
#[label = "Type: `{rhs}`"]
span_rhs: (usize, usize),
},
}
233 changes: 172 additions & 61 deletions compiler/src/passes/validate/generate_constraints.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use crate::passes::parse::{Lit, Meta, Span};
use crate::passes::parse::types::Type;
use crate::passes::parse::{BinaryOp, Lit, Meta, Span, UnaryOp};
use crate::passes::validate::error::TypeError;
use crate::passes::validate::error::TypeError::MismatchedFnReturn;
use crate::passes::validate::uncover_globals::{uncover_globals, Env, EnvEntry};
use crate::passes::validate::uniquify::PrgUniquified;
use crate::passes::validate::{
type_to_index, CMeta, DefConstrained, DefUniquified, ExprConstrained, ExprUniquified,
PrgConstrained,
CMeta, DefConstrained, DefUniquified, ExprConstrained, ExprUniquified, PrgConstrained,
};
use crate::utils::gen_sym::UniqueSym;
use crate::utils::union_find::{UnionFind, UnionIndex};
Expand Down Expand Up @@ -87,32 +86,37 @@ fn constrain_def<'p>(
typ,
bdy,
} => {
// Put function parameters in scope.
scope.extend(params.iter().map(|p| {
(
p.sym.inner,
EnvEntry::Type {
mutable: p.mutable,
typ: type_to_index(p.typ.clone(), uf),
typ: uf.type_to_index(p.typ.clone()),
},
)
}));

let return_index = type_to_index(typ.clone(), uf);
// Add return type to env and keep it for error handling.
let return_index = uf.type_to_index(typ.clone());
let mut env = Env {
uf,
scope,
return_type: return_index,
};

// Constrain body of function.
let bdy = constrain_expr(bdy, &mut env)?;

uf.try_union_by(return_index, bdy.meta.index, combine_partial_types)
.map_err(|_| MismatchedFnReturn {
expect: format!("{typ}"),
got: format!("{}", "bananas"),
span_expected: (0, 0),
span_got: (0, 0),
})?;
// Return error if function body a type differs from its return type.
uf.expect_equal(return_index, bdy.meta.index, |r, b| {
TypeError::MismatchedFnReturn {
expect: r,
got: b,
span_expected: sym.meta,
span_got: bdy.meta.span,
}
})?;

DefConstrained::Fn {
sym,
Expand All @@ -131,6 +135,8 @@ fn constrain_expr<'p>(
expr: Meta<Span, ExprUniquified<'p>>,
env: &mut Env<'_, 'p>,
) -> Result<Meta<CMeta, ExprConstrained<'p>>, TypeError> {
let span = expr.meta;

Ok(match expr.inner {
ExprUniquified::Lit { val } => {
let typ = match val {
Expand All @@ -140,56 +146,106 @@ fn constrain_expr<'p>(
};
let index = env.uf.add(typ);
Meta {
meta: CMeta {
span: expr.meta,
index,
},
meta: CMeta { span, index },
inner: ExprConstrained::Lit { val },
}
}
ExprUniquified::Var { sym } => {
let EnvEntry::Type { typ, .. } = env.scope[&sym.inner] else {
panic!();
};
Meta {
meta: CMeta { span, index: typ },
inner: ExprConstrained::Var { sym },
}
}
ExprUniquified::UnaryOp { op, expr } => {
let typ = match op {
UnaryOp::Neg => Type::I64,
UnaryOp::Not => Type::Bool,
};
let expr = constrain_expr(*expr, env)?;

env.uf.expect_type(expr.meta.index, typ, |got, expect| {
TypeError::OperandExpect {
expect,
got,
op: op.to_string(),
span_op: span,
span_arg: expr.meta.span,
}
})?;

Meta {
meta: CMeta {
span: expr.meta,
index: typ,
span,
index: expr.meta.index,
},
inner: ExprConstrained::UnaryOp {
op,
expr: Box::new(expr),
},
}
}
ExprUniquified::BinaryOp { op, exprs: [lhs, rhs] } => {
// input: None = Any but equal, Some = expect this
// output: None = Same as input, Some = this
let (input, output) = match op {
BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => (Some(PartialType::Int), None),
BinaryOp::LAnd | BinaryOp::LOr | BinaryOp::Xor => (Some(PartialType::Bool), None),
BinaryOp::GT | BinaryOp::GE | BinaryOp::LE | BinaryOp::LT => (Some(PartialType::Int), Some(PartialType::Bool)),
BinaryOp::EQ | BinaryOp::NE => (None, Some(PartialType::Bool)),
};

let e1 = constrain_expr(*lhs, env)?;
let e2 = constrain_expr(*rhs, env)?;

// Check inputs satisfy constraints
if let Some(input) = input {
let mut check = |expr: &Meta<CMeta, ExprConstrained<'p>>| {
env.uf.expect_partial_type(expr.meta.index, input.clone(), |got, expect| TypeError::OperandExpect {
expect,
got,
op: op.to_string(),
span_op: span,
span_arg: expr.meta.span,
})
};

check(&e1)?;
check(&e2)?;
}

// Check inputs equal
let input_index = env.uf.expect_equal(e1.meta.index, e2.meta.index, |lhs, rhs| TypeError::OperandEqual {
lhs,
rhs,
op: op.to_string(),
span_op: span,
span_lhs: e1.meta.span,
span_rhs: e2.meta.span,
})?;

// Generate output index
let output_index = match output {
None => input_index,
Some(e) => {
env.uf.add(e)
}
};

Meta {
meta: CMeta {
span,
index: output_index,
},
inner: ExprConstrained::BinaryOp {
op,
exprs: [e1, e2].map(Box::new),
},
inner: ExprConstrained::Var { sym },
}
}
ExprUniquified::UnaryOp { op, expr } => todo!(),
ExprUniquified::BinaryOp { op, exprs } => todo!(),
// ExprUniquified::Prim { op, args } if args.len() == 2 => {
// let (pt, lhs, rhs) = match op {
// Op::Plus => (PartialType::Int, PartialType::Int, PartialType::Int),
// Op::Minus => PartialType::Int,
// Op::Mul => PartialType::Int,
// Op::Div => PartialType::Int,
// Op::Mod => PartialType::Int,
// Op::LAnd => todo!(),
// Op::LOr => todo!(),
// Op::Xor => todo!(),
// Op::GT => PartialType::Int,
// Op::GE => PartialType::Int,
// Op::EQ => todo!(),
// Op::LE => PartialType::Int,
// Op::LT => PartialType::Int,
// Op::NE => todo!(),
// Op::Read | Op::Print | Op::Not => unreachable!(),
// };
//
// let index = env.uf.add(pt);
//
// Meta {
// meta: CMeta{ span: expr.meta, index },
// inner: ExprConstrained::Prim { op, args: args.into_iter().map(|arg| match arg {
//
// })},
// }
// },
ExprUniquified::Let { sym, bnd, bdy, .. } => todo!(),
ExprUniquified::Let { .. } => todo!(),
ExprUniquified::If { .. } => todo!(),
ExprUniquified::Apply { .. } => todo!(),
ExprUniquified::Loop { .. } => todo!(),
Expand All @@ -205,11 +261,10 @@ fn constrain_expr<'p>(
})
}

// uf: &mut UnionFind<PartialType<'p>>
fn combine_partial_types<'p>(
a: PartialType<'p>,
b: PartialType<'p>,
uf: &mut UnionFind<PartialType<'p>>
uf: &mut UnionFind<PartialType<'p>>,
) -> Result<PartialType<'p>, ()> {
let typ = match (a, b) {
(PartialType::I64, PartialType::I64 | PartialType::Int) => PartialType::I64,
Expand All @@ -235,22 +290,78 @@ fn combine_partial_types<'p>(
},
) => {
if params_a.len() != params_b.len() {
return Err(())
return Err(());
}

let params = params_a.into_iter().zip(params_b).map(|(param_a, param_b)| {
uf.try_union_by(param_a, param_b, combine_partial_types)
}).collect::<Result<_,_>>()?;
let params = params_a
.into_iter()
.zip(params_b)
.map(|(param_a, param_b)| uf.try_union_by(param_a, param_b, combine_partial_types))
.collect::<Result<_, _>>()?;

let typ = uf.try_union_by(typ_a, typ_b, combine_partial_types)?;

PartialType::Fn {
params,
typ,
}
PartialType::Fn { params, typ }
}
_ => return Err(()),
};

Ok(typ)
}

impl<'p> UnionFind<PartialType<'p>> {
pub fn expect_equal(
&mut self,
a: UnionIndex,
b: UnionIndex,
map_err: impl FnOnce(String, String) -> TypeError,
) -> Result<UnionIndex, TypeError> {
self.try_union_by(a, b, combine_partial_types).map_err(|_| {
let typ_a = self.get(a).clone();
let str_a = typ_a.to_string(self);
let typ_b = self.get(b).clone();
let str_b = typ_b.to_string(self);
map_err(str_a, str_b)
})
}

pub fn expect_type(
&mut self,
a: UnionIndex,
t: Type<Meta<Span, UniqueSym<'p>>>,
map_err: impl FnOnce(String, String) -> TypeError,
) -> Result<UnionIndex, TypeError> {
let t_index = self.type_to_index(t);
self.expect_equal(a, t_index, map_err)
}

pub fn expect_partial_type(
&mut self,
a: UnionIndex,
t: PartialType<'p>,
map_err: impl FnOnce(String, String) -> TypeError,
) -> Result<UnionIndex, TypeError> {
let t_index = self.add(t);
self.expect_equal(a, t_index, map_err)
}

pub fn type_to_index(&mut self, t: Type<Meta<Span, UniqueSym<'p>>>) -> UnionIndex {
let pt = match t {
Type::I64 => PartialType::I64,
Type::U64 => PartialType::U64,
Type::Bool => PartialType::Bool,
Type::Unit => PartialType::Unit,
Type::Never => PartialType::Never,
Type::Fn { params, typ } => PartialType::Fn {
params: params
.into_iter()
.map(|param| self.type_to_index(param))
.collect(),
typ: self.type_to_index(*typ),
},
Type::Var { sym } => PartialType::Var { sym: sym.inner },
};

self.add(pt)
}
}
23 changes: 0 additions & 23 deletions compiler/src/passes/validate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,26 +108,3 @@ impl FromStr for TLit {
})
}
}

pub fn type_to_index<'p>(
t: Type<Meta<Span, UniqueSym<'p>>>,
uf: &mut UnionFind<PartialType<'p>>,
) -> UnionIndex {
let pt = match t {
Type::I64 => PartialType::I64,
Type::U64 => PartialType::U64,
Type::Bool => PartialType::Bool,
Type::Unit => PartialType::Unit,
Type::Never => PartialType::Never,
Type::Fn { params, typ } => PartialType::Fn {
params: params
.into_iter()
.map(|param| type_to_index(param, uf))
.collect(),
typ: type_to_index(*typ, uf),
},
Type::Var { sym } => PartialType::Var { sym: sym.inner },
};

uf.add(pt)
}
Loading

0 comments on commit 00c89f0

Please sign in to comment.