Skip to content

Commit

Permalink
feat: support recursive class definition
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Aug 25, 2023
1 parent 418f31e commit f3b188e
Show file tree
Hide file tree
Showing 13 changed files with 285 additions and 93 deletions.
4 changes: 4 additions & 0 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,11 @@ impl Context {
// ({I: Int | True} :> Int) == true
// {N: Nat | ...} :> Int) == false
// ({I: Int | I >= 0} :> Int) == false
// {U(: Type)} :> { .x = {Int} }(== {{ .x = Int }}) == true
(Refinement(l), r) => {
if let Some(r) = r.to_singleton() {
return self.structural_supertype_of(lhs, &Type::Refinement(r));
}
if l.pred.mentions(&l.var) {
match l.pred.can_be_false() {
Some(true) => {
Expand Down
72 changes: 51 additions & 21 deletions crates/erg_compiler/context/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,13 @@ impl Context {
call.loc(),
self.caused_by(),
))),
_ => unreachable!(),
other => Err(EvalErrors::from(EvalError::feature_error(
self.cfg.input.clone(),
line!() as usize,
other.loc(),
&format!("const call: {other}"),
self.caused_by(),
))),
}
} else {
Err(EvalErrors::from(EvalError::not_const_expr(
Expand Down Expand Up @@ -1031,26 +1037,8 @@ impl Context {
line!(),
))
}),
Or | BitOr => match (lhs, rhs) {
(ValueObj::Bool(l), ValueObj::Bool(r)) => Ok(ValueObj::Bool(l || r)),
(ValueObj::Int(l), ValueObj::Int(r)) => Ok(ValueObj::Int(l | r)),
(ValueObj::Type(lhs), ValueObj::Type(rhs)) => Ok(self.eval_or_type(lhs, rhs)),
_ => Err(EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(),
fn_name!(),
line!(),
))),
},
And | BitAnd => match (lhs, rhs) {
(ValueObj::Bool(l), ValueObj::Bool(r)) => Ok(ValueObj::Bool(l && r)),
(ValueObj::Int(l), ValueObj::Int(r)) => Ok(ValueObj::Int(l & r)),
(ValueObj::Type(lhs), ValueObj::Type(rhs)) => Ok(self.eval_and_type(lhs, rhs)),
_ => Err(EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(),
fn_name!(),
line!(),
))),
},
Or | BitOr => self.eval_or(lhs, rhs),
And | BitAnd => self.eval_and(lhs, rhs),
BitXor => match (lhs, rhs) {
(ValueObj::Bool(l), ValueObj::Bool(r)) => Ok(ValueObj::Bool(l ^ r)),
(ValueObj::Int(l), ValueObj::Int(r)) => Ok(ValueObj::Int(l ^ r)),
Expand All @@ -1068,6 +1056,27 @@ impl Context {
}
}

fn eval_or(&self, lhs: ValueObj, rhs: ValueObj) -> EvalResult<ValueObj> {
match (lhs, rhs) {
(ValueObj::Bool(l), ValueObj::Bool(r)) => Ok(ValueObj::Bool(l || r)),
(ValueObj::Int(l), ValueObj::Int(r)) => Ok(ValueObj::Int(l | r)),
(ValueObj::Type(lhs), ValueObj::Type(rhs)) => Ok(self.eval_or_type(lhs, rhs)),
(lhs, rhs) => {
let lhs = self.convert_value_into_type(lhs).ok();
let rhs = self.convert_value_into_type(rhs).ok();
if let Some((l, r)) = lhs.zip(rhs) {
self.eval_or(ValueObj::builtin_type(l), ValueObj::builtin_type(r))
} else {
Err(EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(),
fn_name!(),
line!(),
)))
}
}
}
}

fn eval_or_type(&self, lhs: TypeObj, rhs: TypeObj) -> ValueObj {
match (lhs, rhs) {
(
Expand Down Expand Up @@ -1101,6 +1110,27 @@ impl Context {
}
}

fn eval_and(&self, lhs: ValueObj, rhs: ValueObj) -> EvalResult<ValueObj> {
match (lhs, rhs) {
(ValueObj::Bool(l), ValueObj::Bool(r)) => Ok(ValueObj::Bool(l && r)),
(ValueObj::Int(l), ValueObj::Int(r)) => Ok(ValueObj::Int(l & r)),
(ValueObj::Type(lhs), ValueObj::Type(rhs)) => Ok(self.eval_and_type(lhs, rhs)),
(lhs, rhs) => {
let lhs = self.convert_value_into_type(lhs).ok();
let rhs = self.convert_value_into_type(rhs).ok();
if let Some((l, r)) = lhs.zip(rhs) {
self.eval_and(ValueObj::builtin_type(l), ValueObj::builtin_type(r))
} else {
Err(EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(),
fn_name!(),
line!(),
)))
}
}
}
}

fn eval_and_type(&self, lhs: TypeObj, rhs: TypeObj) -> ValueObj {
match (lhs, rhs) {
(
Expand Down
6 changes: 3 additions & 3 deletions crates/erg_compiler/context/initialize/const_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ pub(crate) fn class_func(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<
match base {
Some(value) => {
if let Some(base) = value.as_type(ctx) {
Ok(ValueObj::gen_t(GenTypeObj::class(t, Some(base), impls)).into())
Ok(ValueObj::gen_t(GenTypeObj::class(t, Some(base), impls, true)).into())
} else {
Err(type_mismatch("type", value, "Base"))
}
}
None => Ok(ValueObj::gen_t(GenTypeObj::class(t, None, impls)).into()),
None => Ok(ValueObj::gen_t(GenTypeObj::class(t, None, impls, true)).into()),
}
}

Expand Down Expand Up @@ -133,7 +133,7 @@ pub(crate) fn trait_func(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<
let impls = args.remove_left_or_key("Impl");
let impls = impls.map(|v| v.as_type(ctx).unwrap());
let t = mono(ctx.name.clone());
Ok(ValueObj::gen_t(GenTypeObj::trait_(t, req, impls)).into())
Ok(ValueObj::gen_t(GenTypeObj::trait_(t, req, impls, true)).into())
}

/// Base: Type, Impl := Type -> Patch
Expand Down
138 changes: 96 additions & 42 deletions crates/erg_compiler/context/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -907,13 +907,46 @@ impl Context {
Ok(())
}

// To allow forward references and recursive definitions
pub(crate) fn preregister(&mut self, block: &ast::Block) -> TyCheckResult<()> {
pub(crate) fn preregister_const(&mut self, block: &ast::Block) -> TyCheckResult<()> {
let mut total_errs = TyCheckErrors::empty();
for expr in block.iter() {
match expr {
ast::Expr::Def(def) => {
if let Err(errs) = self.preregister_def(def) {
if let Err(errs) = self.preregister_const_def(def) {
total_errs.extend(errs);
}
}
ast::Expr::ClassDef(class_def) => {
if let Err(errs) = self.preregister_const_def(&class_def.def) {
total_errs.extend(errs);
}
}
ast::Expr::PatchDef(patch_def) => {
if let Err(errs) = self.preregister_const_def(&patch_def.def) {
total_errs.extend(errs);
}
}
ast::Expr::Dummy(dummy) => {
if let Err(errs) = self.preregister_const(&dummy.exprs) {
total_errs.extend(errs);
}
}
_ => {}
}
}
if total_errs.is_empty() {
Ok(())
} else {
Err(total_errs)
}
}

pub(crate) fn register_const(&mut self, block: &ast::Block) -> TyCheckResult<()> {
let mut total_errs = TyCheckErrors::empty();
for expr in block.iter() {
match expr {
ast::Expr::Def(def) => {
if let Err(errs) = self.register_const_def(def) {
total_errs.extend(errs);
}
if def.def_kind().is_import() {
Expand All @@ -923,17 +956,17 @@ impl Context {
}
}
ast::Expr::ClassDef(class_def) => {
if let Err(errs) = self.preregister_def(&class_def.def) {
if let Err(errs) = self.register_const_def(&class_def.def) {
total_errs.extend(errs);
}
}
ast::Expr::PatchDef(patch_def) => {
if let Err(errs) = self.preregister_def(&patch_def.def) {
if let Err(errs) = self.register_const_def(&patch_def.def) {
total_errs.extend(errs);
}
}
ast::Expr::Dummy(dummy) => {
if let Err(errs) = self.preregister(&dummy.exprs) {
if let Err(errs) = self.register_const(&dummy.exprs) {
total_errs.extend(errs);
}
}
Expand Down Expand Up @@ -988,7 +1021,43 @@ impl Context {
res
}

pub(crate) fn preregister_def(&mut self, def: &ast::Def) -> TyCheckResult<()> {
fn preregister_const_def(&mut self, def: &ast::Def) -> TyCheckResult<()> {
match &def.sig {
ast::Signature::Var(var) if var.is_const() => {
let Some(ast::Expr::Call(call)) = def.body.block.first() else {
return Ok(());
};
self.preregister_type(var, call)
}
_ => Ok(()),
}
}

fn preregister_type(&mut self, var: &ast::VarSignature, call: &ast::Call) -> TyCheckResult<()> {
match call.obj.as_ref() {
ast::Expr::Accessor(ast::Accessor::Ident(ident)) => match &ident.inspect()[..] {
"Class" => {
let ident = var.ident().unwrap();
let t = Type::Mono(format!("{}{ident}", self.name).into());
let class = GenTypeObj::class(t, None, None, false);
let class = ValueObj::Type(TypeObj::Generated(class));
self.register_gen_const(ident, class, false)
}
"Trait" => {
let ident = var.ident().unwrap();
let t = Type::Mono(format!("{}{ident}", self.name).into());
let trait_ =
GenTypeObj::trait_(t, TypeObj::builtin_type(Type::Failure), None, false);
let trait_ = ValueObj::Type(TypeObj::Generated(trait_));
self.register_gen_const(ident, trait_, false)
}
_ => Ok(()),
},
_ => Ok(()),
}
}

pub(crate) fn register_const_def(&mut self, def: &ast::Def) -> TyCheckResult<()> {
let id = Some(def.body.id);
let __name__ = def.sig.ident().map(|i| i.inspect()).unwrap_or(UBAR);
match &def.sig {
Expand Down Expand Up @@ -1278,7 +1347,10 @@ impl Context {
alias: bool,
) -> CompileResult<()> {
let vis = self.instantiate_vis_modifier(&ident.vis)?;
if self.rec_get_const_obj(ident.inspect()).is_some() && vis.is_private() {
let inited = self
.rec_get_const_obj(ident.inspect())
.is_some_and(|v| v.is_inited());
if inited && vis.is_private() {
Err(CompileErrors::from(CompileError::reassign_error(
self.cfg.input.clone(),
line!() as usize,
Expand Down Expand Up @@ -1457,14 +1529,13 @@ impl Context {
2,
self.level,
);
let Some(TypeObj::Builtin {
if let Some(TypeObj::Builtin {
t: Type::Record(req),
..
}) = gen.base_or_sup()
else {
todo!("{gen}")
};
self.register_instance_attrs(&mut ctx, req)?;
{
self.register_instance_attrs(&mut ctx, req)?;
}
self.register_gen_mono_type(ident, gen, ctx, Const)
} else {
feature_error!(
Expand Down Expand Up @@ -1635,15 +1706,10 @@ impl Context {
meta_t: Type,
) -> CompileResult<()> {
let vis = self.instantiate_vis_modifier(&ident.vis)?;
if self.mono_types.contains_key(ident.inspect()) {
Err(CompileErrors::from(CompileError::reassign_error(
self.cfg.input.clone(),
line!() as usize,
ident.loc(),
self.caused_by(),
ident.inspect(),
)))
} else if self.rec_get_const_obj(ident.inspect()).is_some() && vis.is_private() {
let inited = self
.rec_get_const_obj(ident.inspect())
.is_some_and(|v| v.is_inited());
if inited && vis.is_private() {
// TODO: display where defined
Err(CompileErrors::from(CompileError::reassign_error(
self.cfg.input.clone(),
Expand Down Expand Up @@ -1682,16 +1748,10 @@ impl Context {
muty: Mutability,
) -> CompileResult<()> {
let vis = self.instantiate_vis_modifier(&ident.vis)?;
// FIXME: recursive search
if self.mono_types.contains_key(ident.inspect()) {
Err(CompileErrors::from(CompileError::reassign_error(
self.cfg.input.clone(),
line!() as usize,
ident.loc(),
self.caused_by(),
ident.inspect(),
)))
} else if self.rec_get_const_obj(ident.inspect()).is_some() && vis.is_private() {
let inited = self
.rec_get_const_obj(ident.inspect())
.is_some_and(|v| v.is_inited());
if inited && vis.is_private() {
Err(CompileErrors::from(CompileError::reassign_error(
self.cfg.input.clone(),
line!() as usize,
Expand Down Expand Up @@ -1732,16 +1792,10 @@ impl Context {
muty: Mutability,
) -> CompileResult<()> {
let vis = self.instantiate_vis_modifier(&ident.vis)?;
// FIXME: recursive search
if self.poly_types.contains_key(ident.inspect()) {
Err(CompileErrors::from(CompileError::reassign_error(
self.cfg.input.clone(),
line!() as usize,
ident.loc(),
self.caused_by(),
ident.inspect(),
)))
} else if self.rec_get_const_obj(ident.inspect()).is_some() && vis.is_private() {
let inited = self
.rec_get_const_obj(ident.inspect())
.is_some_and(|v| v.is_inited());
if inited && vis.is_private() {
Err(CompileErrors::from(CompileError::reassign_error(
self.cfg.input.clone(),
line!() as usize,
Expand Down
8 changes: 6 additions & 2 deletions crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1323,8 +1323,12 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
self.sub_unify(maybe_sub, &Type::Refinement(sup))?;
}
(sub, Refinement(_)) => {
let sub = sub.clone().into_refinement();
self.sub_unify(&Type::Refinement(sub), maybe_sup)?;
if let Some(sub) = sub.to_singleton() {
self.sub_unify(&Type::Refinement(sub), maybe_sup)?;
} else {
let sub = sub.clone().into_refinement();
self.sub_unify(&Type::Refinement(sub), maybe_sup)?;
}
}
(Subr(_) | Record(_), Type) => {}
// REVIEW: correct?
Expand Down
14 changes: 9 additions & 5 deletions crates/erg_compiler/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -772,14 +772,18 @@ impl ASTLowerer {
let (t, ty_obj) = match t {
Type::ClassType => {
let t = mono(format!("{}{ident}", self.module.context.path()));
let ty_obj = GenTypeObj::class(t.clone(), None, None);
let ty_obj = GenTypeObj::class(t.clone(), None, None, true);
let t = v_enum(set! { ValueObj::builtin_class(t) });
(t, Some(ty_obj))
}
Type::TraitType => {
let t = mono(format!("{}{ident}", self.module.context.path()));
let ty_obj =
GenTypeObj::trait_(t.clone(), TypeObj::builtin_type(Type::Uninited), None);
let ty_obj = GenTypeObj::trait_(
t.clone(),
TypeObj::builtin_type(Type::Uninited),
None,
true,
);
let t = v_enum(set! { ValueObj::builtin_trait(t) });
(t, Some(ty_obj))
}
Expand All @@ -793,7 +797,7 @@ impl ASTLowerer {
})
.collect();
let t = poly(format!("{}{ident}", self.module.context.path()), params);
let ty_obj = GenTypeObj::class(t.clone(), None, None);
let ty_obj = GenTypeObj::class(t.clone(), None, None, true);
let t = v_enum(set! { ValueObj::builtin_class(t) });
(t, Some(ty_obj))
}
Expand Down Expand Up @@ -875,7 +879,7 @@ impl ASTLowerer {

pub(crate) fn declare_module(&mut self, ast: AST) -> HIR {
let mut module = hir::Module::with_capacity(ast.module.len());
let _ = self.module.context.preregister(ast.module.block());
let _ = self.module.context.register_const(ast.module.block());
for chunk in ast.module.into_iter() {
match self.declare_chunk(chunk, false) {
Ok(chunk) => {
Expand Down
Loading

0 comments on commit f3b188e

Please sign in to comment.