From d902bd79df21abf5b3c6f157eb18f2faa61a0e6e Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Sun, 11 Aug 2024 14:50:02 -0500 Subject: [PATCH] Add an actual subtyping relation --- src/hir/definitions.rs | 13 +-- src/hir/monomorphisation.rs | 28 ++++-- src/lifetimes/mod.rs | 133 ---------------------------- src/nameresolution/mod.rs | 24 +++-- src/parser/ast.rs | 1 + src/types/mod.rs | 57 ++++++++++-- src/types/typechecker.rs | 171 +++++++++++++++++++++--------------- src/types/typeprinter.rs | 45 ++++++---- 8 files changed, 223 insertions(+), 249 deletions(-) diff --git a/src/hir/definitions.rs b/src/hir/definitions.rs index 199ea6bf..1625a025 100644 --- a/src/hir/definitions.rs +++ b/src/hir/definitions.rs @@ -63,9 +63,9 @@ impl std::hash::Hash for DefinitionType { types::Type::TypeVariable(_) => (), // Do nothing types::Type::Function(_) => (), types::Type::TypeApplication(_, _) => (), - types::Type::Ref(shared, mutable, _) => { - shared.hash(state); - mutable.hash(state); + types::Type::Ref { sharedness, mutability, lifetime: _ } => { + sharedness.hash(state); + mutability.hash(state); }, types::Type::Struct(field_names, _) => { for name in field_names { @@ -77,6 +77,7 @@ impl std::hash::Hash for DefinitionType { id.hash(state); } }, + types::Type::Tag(tag) => tag.hash(state), } }) } @@ -98,8 +99,9 @@ fn definition_type_eq(a: &types::Type, b: &types::Type) -> bool { // This will monomorphize separate definitions for polymorphically-owned references // which is undesired. Defaulting them to shared/owned though can change behavior // if traits are involved. - (Type::Ref(shared1, mutable1, _), Type::Ref(shared2, mutable2, _)) => { - shared1 == shared2 && mutable1 == mutable2 + (Type::Ref { sharedness: shared1, mutability: mutable1, .. }, + Type::Ref { sharedness: shared2, mutability: mutable2, .. }) => { + definition_type_eq(shared1, shared2) && definition_type_eq(mutable1, mutable2) }, (Type::Function(f1), Type::Function(f2)) => { if f1.parameters.len() != f2.parameters.len() { @@ -137,6 +139,7 @@ fn definition_type_eq(a: &types::Type, b: &types::Type) -> bool { id1 == id2 && args1.iter().zip(args2).all(|(t1, t2)| definition_type_eq(t1, t2)) }) }, + (Type::Tag(tag1), Type::Tag(tag2)) => tag1 == tag2, (othera, otherb) => { assert_ne!(std::mem::discriminant(othera), std::mem::discriminant(otherb), "ICE: Missing match case"); false diff --git a/src/hir/monomorphisation.rs b/src/hir/monomorphisation.rs index fd2f9874..92fdfc8b 100644 --- a/src/hir/monomorphisation.rs +++ b/src/hir/monomorphisation.rs @@ -142,12 +142,12 @@ impl<'c> Context<'c> { let fuel = fuel - 1; match &self.cache.type_bindings[id.0] { - Bound(TypeVariable(id2) | Ref(_, _, id2)) => self.find_binding(*id2, fuel), + Bound(TypeVariable(id2)) => self.find_binding(*id2, fuel), Bound(binding) => Ok(binding), Unbound(..) => { for bindings in self.monomorphisation_bindings.iter().rev() { match bindings.get(&id) { - Some(TypeVariable(id2) | Ref(_, _, id2)) => return self.find_binding(*id2, fuel), + Some(TypeVariable(id2)) => return self.find_binding(*id2, fuel), Some(binding) => return Ok(binding), None => (), } @@ -204,7 +204,12 @@ impl<'c> Context<'c> { let args = fmap(args, |arg| self.follow_all_bindings_inner(arg, fuel)); TypeApplication(Box::new(con), args) }, - Ref(..) => typ.clone(), + Ref { mutability, sharedness, lifetime } => { + let mutability = Box::new(self.follow_all_bindings_inner(mutability, fuel)); + let sharedness = Box::new(self.follow_all_bindings_inner(sharedness, fuel)); + let lifetime = Box::new(self.follow_all_bindings_inner(lifetime, fuel)); + Ref { mutability, sharedness, lifetime } + }, Struct(fields, id) => match self.find_binding(*id, fuel) { Ok(binding) => self.follow_all_bindings_inner(binding, fuel), Err(_) => { @@ -217,6 +222,7 @@ impl<'c> Context<'c> { }, }, Effects(effects) => self.follow_all_effect_bindings_inner(effects, fuel), + Tag(tag) => Tag(*tag), } } @@ -315,6 +321,9 @@ impl<'c> Context<'c> { Primitive(FloatType) => { unreachable!("'Float' type constructor without arguments found during size_of_type") }, + Tag(tag) => { + unreachable!("'{}' found during size_of_type", tag) + } Function(..) => Self::ptr_size(), @@ -346,7 +355,7 @@ impl<'c> Context<'c> { _ => unreachable!("Kind error inside size_of_type"), }, - Ref(..) => Self::ptr_size(), + Ref { .. } => Self::ptr_size(), Struct(fields, rest) => { if let Ok(binding) = self.find_binding(*rest, RECURSION_LIMIT) { let binding = binding.clone(); @@ -519,7 +528,7 @@ impl<'c> Context<'c> { let typ = self.follow_bindings_shallow(typ); match typ { - Ok(Primitive(PrimitiveType::Ptr) | Ref(..)) => Type::Primitive(hir::PrimitiveType::Pointer), + Ok(Primitive(PrimitiveType::Ptr) | Ref { .. }) => Type::Primitive(hir::PrimitiveType::Pointer), Ok(Primitive(PrimitiveType::IntegerType)) => { if self.is_type_variable(&args[0]) { // Default to i32 @@ -553,11 +562,14 @@ impl<'c> Context<'c> { } }, - Ref(..) => { + Ref { .. } => { unreachable!( "Kind error during monomorphisation. Attempted to translate a `ref` without a type argument" ) }, + Tag(tag) => { + unreachable!("Kind error during monomorphisation. Attempted to translate a `{}` as a type", tag) + } Struct(fields, rest) => { if let Ok(binding) = self.find_binding(*rest, fuel) { let binding = binding.clone(); @@ -1517,7 +1529,7 @@ impl<'c> Context<'c> { TypeApplication(typ, args) => { match typ.as_ref() { // Pass through ref types transparently - types::Type::Ref(..) => self.get_field_index(field_name, &args[0]), + types::Type::Ref { .. } => self.get_field_index(field_name, &args[0]), // These last 2 cases are the same. They're duplicated to avoid another follow_bindings_shallow call. typ => self.get_field_index(field_name, typ), } @@ -1551,7 +1563,7 @@ impl<'c> Context<'c> { let ref_type = match lhs_type { types::Type::TypeApplication(constructor, args) => match self.follow_bindings_shallow(constructor.as_ref()) { - Ok(types::Type::Ref(..)) => Some(self.convert_type(&args[0])), + Ok(types::Type::Ref { .. }) => Some(self.convert_type(&args[0])), _ => None, }, _ => None, diff --git a/src/lifetimes/mod.rs b/src/lifetimes/mod.rs index 69d98307..a5b54c2c 100644 --- a/src/lifetimes/mod.rs +++ b/src/lifetimes/mod.rs @@ -1,138 +1,5 @@ use crate::cache::ModuleCache; use crate::parser::ast::Ast; -use crate::types::TypeVariableId; - -/// A lifetime variable is represented simply as a type variable for ease of unification -/// during the type inference pass. -pub type LifetimeVariableId = TypeVariableId; - -// struct LifetimeAnalyzer { -// pub level: StackFrameIndex, -// -// /// Map from RegionVariableId -> StackFrame -// /// Contains the stack frame index each region should be allocated in -// pub lifetimes: Vec, -// } -// -// struct StackFrameIndex(usize); #[allow(unused)] pub fn infer<'c>(_ast: &mut Ast<'c>, _cache: &mut ModuleCache<'c>) {} - -// trait InferableLifetime { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>); -// } -// -// impl<'ast> InferableLifetime for Ast<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// dispatch_on_expr!(self, InferableLifetime::infer_lifetime, analyzer, cache) -// } -// } -// -// impl<'ast> InferableLifetime for ast::Literal<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// // Do nothing: literals cannot contain a ref type -// } -// } -// -// impl<'ast> InferableLifetime for ast::Variable<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Lambda<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::FunctionCall<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Definition<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::If<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Match<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::TypeDefinition<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::TypeAnnotation<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Import<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::TraitDefinition<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::TraitImpl<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Return<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Sequence<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Extern<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::MemberAccess<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Tuple<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } -// -// impl<'ast> InferableLifetime for ast::Assignment<'ast> { -// fn infer_lifetime<'c>(&mut self, analyzer: &mut LifetimeAnalyzer, cache: &mut ModuleCache<'c>) { -// -// } -// } diff --git a/src/nameresolution/mod.rs b/src/nameresolution/mod.rs index ac67ec9d..ea9466d1 100644 --- a/src/nameresolution/mod.rs +++ b/src/nameresolution/mod.rs @@ -50,7 +50,7 @@ use crate::types::traits::ConstraintSignature; use crate::types::typed::Typed; use crate::types::{ Field, FunctionType, GeneralizedType, LetBindingLevel, PrimitiveType, Type, TypeConstructor, TypeInfoBody, - TypeInfoId, TypeVariableId, INITIAL_LEVEL, STRING_TYPE, + TypeInfoId, TypeVariableId, INITIAL_LEVEL, STRING_TYPE, TypeTag, }; use crate::util::{fmap, timing, trustme}; @@ -589,9 +589,10 @@ impl<'c> NameResolver { Type::TypeVariable(_) => 0, Type::UserDefined(id) => cache[*id].args.len(), Type::TypeApplication(_, _) => 0, - Type::Ref(..) => 1, + Type::Ref { .. } => 1, Type::Struct(_, _) => 0, Type::Effects(_) => 0, + Type::Tag(_) => 0, } } @@ -739,14 +740,27 @@ impl<'c> NameResolver { Type::TypeApplication(Box::new(pair), args) }, - ast::Type::Reference(sharednes, mutability, _) => { + ast::Type::Reference(sharedness, mutability, _) => { // When translating ref types, all have a hidden lifetime variable that is unified // under the hood by the compiler to determine the reference's stack lifetime. // This is never able to be manually specified by the programmer, so we use // next_type_variable_id on the cache rather than the NameResolver's version which // would add a name into scope. - let lifetime_variable = cache.next_type_variable_id(self.let_binding_level); - Type::Ref(*sharednes, *mutability, lifetime_variable) + let lifetime = Box::new(cache.next_type_variable(self.let_binding_level)); + + let sharedness = Box::new(match sharedness { + ast::Sharedness::Polymorphic => cache.next_type_variable(self.let_binding_level), + ast::Sharedness::Shared => Type::Tag(TypeTag::Shared), + ast::Sharedness::Owned => Type::Tag(TypeTag::Owned), + }); + + let mutability = Box::new(match mutability { + ast::Mutability::Polymorphic => cache.next_type_variable(self.let_binding_level), + ast::Mutability::Immutable => Type::Tag(TypeTag::Immutable), + ast::Mutability::Mutable => Type::Tag(TypeTag::Mutable), + }); + + Type::Ref { sharedness, mutability, lifetime } }, } } diff --git a/src/parser/ast.rs b/src/parser/ast.rs index 2aa8e402..72dae648 100644 --- a/src/parser/ast.rs +++ b/src/parser/ast.rs @@ -223,6 +223,7 @@ pub enum Sharedness { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Mutability { + #[allow(unused)] Polymorphic, Immutable, Mutable, diff --git a/src/types/mod.rs b/src/types/mod.rs index 05f7422c..c0e0fe35 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -9,9 +9,8 @@ use std::collections::BTreeMap; use crate::cache::{DefinitionInfoId, ModuleCache}; use crate::error::location::{Locatable, Location}; use crate::lexer::token::{FloatKind, IntegerKind}; -use crate::parser::ast::{Mutability, Sharedness}; use crate::util::fmap; -use crate::{lifetimes, util}; +use crate::util; use self::typeprinter::TypePrinter; use crate::types::effects::EffectSet; @@ -103,7 +102,7 @@ pub enum Type { /// A region-allocated reference to some data. /// Contains a region variable that is unified with other refs during type /// inference. All these refs will be allocated in the same region. - Ref(Sharedness, Mutability, lifetimes::LifetimeVariableId), + Ref { mutability: Box, sharedness: Box, lifetime: Box }, /// A (row-polymorphic) struct type. Unlike normal rho variables, /// the type variable used here replaces the entire type if bound. @@ -115,6 +114,22 @@ pub enum Type { /// are included in it since they are still valid in a type position /// most notably when substituting type variables for effects. Effects(EffectSet), + + /// Tags are any type which isn't a valid type by itself but may be inside + /// a larger type. For example, `shared` is not a type, but a polymorphic + /// reference's type variable may resolve to a shared reference. + Tag(TypeTag), +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub enum TypeTag { + // References can be polymorphic in their ownership or mutability. + // When they are, they hold type variables which later resolve to one + // of these variants. + Owned, + Shared, + Mutable, + Immutable, } #[derive(Debug, Clone)] @@ -191,13 +206,14 @@ impl Type { use Type::*; match self { Primitive(_) => None, - Ref(..) => None, + Ref { .. } => None, Function(function) => function.return_type.union_constructor_variants(cache), TypeApplication(typ, _) => typ.union_constructor_variants(cache), UserDefined(id) => cache.type_infos[id.0].union_variants(), TypeVariable(_) => unreachable!("Constructors should always have concrete types"), Struct(_, _) => None, Effects(_) => None, + Tag(_) => None, } } @@ -224,6 +240,7 @@ impl Type { match self { Type::Primitive(_) => (), Type::UserDefined(_) => (), + Type::Tag(_) => (), Type::Function(function) => { for parameter in &function.parameters { @@ -232,10 +249,15 @@ impl Type { function.environment.traverse_rec(cache, f); function.return_type.traverse_rec(cache, f); }, - Type::TypeVariable(id) | Type::Ref(_, _, id) => match &cache.type_bindings[id.0] { + Type::TypeVariable(id) => match &cache.type_bindings[id.0] { TypeBinding::Bound(binding) => binding.traverse_rec(cache, f), TypeBinding::Unbound(_, _) => (), }, + Type::Ref { sharedness, mutability, lifetime } => { + sharedness.traverse_rec(cache, f); + mutability.traverse_rec(cache, f); + lifetime.traverse_rec(cache, f); + } Type::TypeApplication(constructor, args) => { constructor.traverse_rec(cache, f); for arg in args { @@ -274,7 +296,7 @@ impl Type { Type::Primitive(_) => (), Type::UserDefined(_) => (), Type::TypeVariable(_) => (), - Type::Ref(..) => (), + Type::Tag(_) => (), Type::Function(function) => { for parameter in &function.parameters { @@ -301,6 +323,10 @@ impl Type { typ.traverse_no_follow_rec(f); } }, + Type::Ref { sharedness, mutability, lifetime: _ } => { + sharedness.traverse_no_follow_rec(f); + mutability.traverse_no_follow_rec(f); + }, } } @@ -325,7 +351,12 @@ impl Type { let args = fmap(args, |arg| arg.approx_to_string()); format!("({} {})", constructor, args.join(" ")) }, - Type::Ref(shared, mutable, id) => format!("&'{} {}{}", id.0, shared, mutable), + Type::Ref { sharedness, mutability, lifetime } => { + let shared = sharedness.approx_to_string(); + let mutable = mutability.approx_to_string(); + let lifetime = lifetime.approx_to_string(); + format!("{}{} '{}", mutable, shared, lifetime) + } Type::Struct(fields, id) => { let fields = fmap(fields, |(name, typ)| format!("{}: {}", name, typ.approx_to_string())); format!("{{ {}, ..tv{} }}", fields.join(", "), id.0) @@ -341,6 +372,18 @@ impl Type { format!("can {}, ..tv{}", effects.join(", "), set.replacement.0) } }, + Type::Tag(tag) => tag.to_string(), + } + } +} + +impl std::fmt::Display for TypeTag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeTag::Owned => write!(f, "owned"), + TypeTag::Shared => write!(f, "shared"), + TypeTag::Mutable => write!(f, "!"), + TypeTag::Immutable => write!(f, "&"), } } } diff --git a/src/types/typechecker.rs b/src/types/typechecker.rs index 6d6853e1..67f51e5a 100644 --- a/src/types/typechecker.rs +++ b/src/types/typechecker.rs @@ -17,7 +17,7 @@ //! Most of this file is translated from: https://github.com/jfecher/algorithm-j //! That repository may be a good starting place for those new to type inference. //! For those already familiar with type inference or more interested in ante's -//! internals, the reccomended starting place while reading this file is the +//! internals, the recommended starting place while reading this file is the //! `Inferable` trait and its impls for each node. From there, you can see what //! type inference does for each node type and inspect any helpers that are used. //! @@ -30,7 +30,7 @@ use crate::cache::{DefinitionInfoId, DefinitionKind, EffectInfoId, ModuleCache, use crate::cache::{ImplScopeId, VariableId}; use crate::error::location::{Locatable, Location}; use crate::error::{Diagnostic, DiagnosticKind as D, TypeErrorKind, TypeErrorKind as TE}; -use crate::parser::ast::{self, ClosureEnvironment, Mutability, Sharedness}; +use crate::parser::ast::{self, ClosureEnvironment}; use crate::types::traits::{RequiredTrait, TraitConstraint, TraitConstraints}; use crate::types::typed::Typed; use crate::types::EffectSet; @@ -46,7 +46,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use super::mutual_recursion::{definition_is_mutually_recursive, try_generalize_definition}; use super::traits::{Callsite, ConstraintSignature, TraitConstraintId}; -use super::{GeneralizedType, TypeInfoBody}; +use super::{GeneralizedType, TypeInfoBody, TypeTag}; /// The current LetBindingLevel we are at. /// This increases by 1 whenever we enter the rhs of a `ast::Definition` and decreases @@ -153,8 +153,10 @@ pub fn type_application_bindings(info: &TypeInfo<'_>, typeargs: &[Type], cache: /// Given `a` returns `ref a` fn ref_of(typ: Type, cache: &mut ModuleCache) -> Type { - let new_var = next_type_variable_id(cache); - let constructor = Box::new(Type::Ref(Sharedness::Polymorphic, Mutability::Polymorphic, new_var)); + let sharedness = Box::new(next_type_variable(cache)); + let mutability = Box::new(next_type_variable(cache)); + let lifetime = Box::new(next_type_variable(cache)); + let constructor = Box::new(Type::Ref { sharedness, mutability, lifetime }); TypeApplication(constructor, vec![typ]) } @@ -188,8 +190,9 @@ pub fn replace_all_typevars_with_bindings( ) -> Type { match typ { Primitive(p) => Primitive(*p), + Tag(tag) => Tag(*tag), - TypeVariable(id) => replace_typevar_with_binding(*id, new_bindings, TypeVariable, cache), + TypeVariable(id) => replace_typevar_with_binding(*id, new_bindings, cache), Function(function) => { let parameters = fmap(&function.parameters, |parameter| { @@ -203,14 +206,11 @@ pub fn replace_all_typevars_with_bindings( }, UserDefined(id) => UserDefined(*id), - // We must recurse on the lifetime variable since they are unified as normal type variables - Ref(sharedness, mutability, lifetime) => { - let make_ref = |new_lifetime| Ref(*sharedness, *mutability, new_lifetime); - match replace_typevar_with_binding(*lifetime, new_bindings, make_ref, cache) { - TypeVariable(new_lifetime) => make_ref(new_lifetime), - new_ref @ Ref(..) => new_ref, - _ => unreachable!("Bound Ref lifetime to non-lifetime type"), - } + Ref { mutability, sharedness, lifetime } => { + let mutability = Box::new(replace_all_typevars_with_bindings(mutability, new_bindings, cache)); + let sharedness = Box::new(replace_all_typevars_with_bindings(sharedness, new_bindings, cache)); + let lifetime = Box::new(replace_all_typevars_with_bindings(lifetime, new_bindings, cache)); + Ref { sharedness, mutability, lifetime } }, TypeApplication(typ, args) => { @@ -239,11 +239,8 @@ pub fn replace_all_typevars_with_bindings( /// If the given TypeVariableId is unbound then return the matching binding in new_bindings. /// If there is no binding found, instantiate a new type variable and use that. -/// -/// `default` should be either TypeVariable or Ref and controls which kind of type gets -/// created that wraps the newly-instantiated TypeVariableId if one is made. fn replace_typevar_with_binding( - id: TypeVariableId, new_bindings: &mut TypeBindings, default: impl FnOnce(TypeVariableId) -> Type, + id: TypeVariableId, new_bindings: &mut TypeBindings, cache: &mut ModuleCache<'_>, ) -> Type { if let Bound(typ) = &cache.type_bindings[id.0] { @@ -252,7 +249,7 @@ fn replace_typevar_with_binding( var.clone() } else { let new_typevar = next_type_variable_id(cache); - let typ = default(new_typevar); + let typ = Type::TypeVariable(new_typevar); new_bindings.insert(id, typ.clone()); typ } @@ -266,8 +263,9 @@ fn replace_typevar_with_binding( pub fn bind_typevars(typ: &Type, type_bindings: &TypeBindings, cache: &ModuleCache<'_>) -> Type { match typ { Primitive(p) => Primitive(*p), + Tag(tag) => Tag(*tag), - TypeVariable(id) => bind_typevar(*id, type_bindings, TypeVariable, cache), + TypeVariable(id) => bind_typevar(*id, type_bindings, cache), Function(function) => { let parameters = fmap(&function.parameters, |parameter| bind_typevars(parameter, type_bindings, cache)); @@ -279,13 +277,11 @@ pub fn bind_typevars(typ: &Type, type_bindings: &TypeBindings, cache: &ModuleCac }, UserDefined(id) => UserDefined(*id), - Ref(sharedness, mutability, lifetime) => { - let make_ref = |lifetime| Ref(*sharedness, *mutability, lifetime); - match bind_typevar(*lifetime, type_bindings, make_ref, cache) { - TypeVariable(new_lifetime) => make_ref(new_lifetime), - new_ref @ Ref(..) => new_ref, - _ => unreachable!("Bound Ref lifetime to non-lifetime type"), - } + Ref { mutability, sharedness, lifetime } => { + let mutability = Box::new(bind_typevars(mutability, type_bindings, cache)); + let sharedness = Box::new(bind_typevars(sharedness, type_bindings, cache)); + let lifetime = Box::new(bind_typevars(lifetime, type_bindings, cache)); + Ref { sharedness, mutability, lifetime } }, TypeApplication(typ, args) => { @@ -327,9 +323,9 @@ pub fn bind_typevars(typ: &Type, type_bindings: &TypeBindings, cache: &ModuleCac /// Helper for bind_typevars which binds a single TypeVariableId if it is Unbound /// and it is found in the type_bindings. If a type_binding wasn't found, a -/// default TypeVariable or Ref is constructed by passing the relevant constructor to `default`. +/// default TypeVariable is constructed. fn bind_typevar( - id: TypeVariableId, type_bindings: &TypeBindings, default: impl FnOnce(TypeVariableId) -> Type, + id: TypeVariableId, type_bindings: &TypeBindings, cache: &ModuleCache<'_>, ) -> Type { // TODO: This ordering of checking type_bindings first is important. @@ -342,7 +338,7 @@ fn bind_typevar( if let Bound(typ) = &cache.type_bindings[id.0] { bind_typevars(&typ.clone(), type_bindings, cache) } else { - default(id) + Type::TypeVariable(id) } }, } @@ -354,6 +350,7 @@ pub fn contains_any_typevars_from_list(typ: &Type, list: &[TypeVariableId], cach match typ { Primitive(_) => false, UserDefined(_) => false, + Tag(_) => false, TypeVariable(id) => type_variable_contains_any_typevars_from_list(*id, list, cache), @@ -364,7 +361,11 @@ pub fn contains_any_typevars_from_list(typ: &Type, list: &[TypeVariableId], cach || contains_any_typevars_from_list(&function.effects, list, cache) }, - Ref(_, _, lifetime) => type_variable_contains_any_typevars_from_list(*lifetime, list, cache), + Ref { mutability, sharedness, lifetime } => { + contains_any_typevars_from_list(mutability, list, cache) + || contains_any_typevars_from_list(sharedness, list, cache) + || contains_any_typevars_from_list(lifetime, list, cache) + } TypeApplication(typ, args) => { contains_any_typevars_from_list(typ, list, cache) @@ -555,6 +556,7 @@ pub(super) fn occurs( match typ { Primitive(_) => OccursResult::does_not_occur(), UserDefined(_) => OccursResult::does_not_occur(), + Tag(_) => OccursResult::does_not_occur(), TypeVariable(var_id) => typevars_match(id, level, *var_id, bindings, fuel, cache), Function(function) => occurs(id, level, &function.return_type, bindings, fuel, cache) @@ -563,7 +565,10 @@ pub(super) fn occurs( .then_all(&function.parameters, |param| occurs(id, level, param, bindings, fuel, cache)), TypeApplication(typ, args) => occurs(id, level, typ, bindings, fuel, cache) .then_all(args, |arg| occurs(id, level, arg, bindings, fuel, cache)), - Ref(_, _, lifetime) => typevars_match(id, level, *lifetime, bindings, fuel, cache), + Ref { mutability, sharedness, lifetime } => + occurs(id, level, mutability, bindings, fuel, cache) + .then(|| occurs(id, level, sharedness, bindings, fuel, cache)) + .then(|| occurs(id, level, lifetime, bindings, fuel, cache)), Struct(fields, var_id) => typevars_match(id, level, *var_id, bindings, fuel, cache) .then_all(fields.iter().map(|(_, typ)| typ), |field| occurs(id, level, field, bindings, fuel, cache)), Effects(effects) => effects.occurs(id, level, bindings, fuel, cache), @@ -590,7 +595,7 @@ pub(super) fn typevars_match( /// Returns what a given type is bound to, following all typevar links until it reaches an Unbound one. pub fn follow_bindings_in_cache_and_map(typ: &Type, bindings: &UnificationBindings, cache: &ModuleCache<'_>) -> Type { match typ { - TypeVariable(id) | Ref(_, _, id) => match find_binding(*id, bindings, cache) { + TypeVariable(id) => match find_binding(*id, bindings, cache) { Bound(typ) => follow_bindings_in_cache_and_map(&typ, bindings, cache), Unbound(..) => typ.clone(), }, @@ -600,7 +605,7 @@ pub fn follow_bindings_in_cache_and_map(typ: &Type, bindings: &UnificationBindin pub fn follow_bindings_in_cache(typ: &Type, cache: &ModuleCache<'_>) -> Type { match typ { - TypeVariable(id) | Ref(_, _, id) => match &cache.type_bindings[id.0] { + TypeVariable(id) => match &cache.type_bindings[id.0] { Bound(typ) => follow_bindings_in_cache(typ, cache), Unbound(..) => typ.clone(), }, @@ -619,9 +624,9 @@ pub fn follow_bindings_in_cache(typ: &Type, cache: &ModuleCache<'_>) -> Type { /// This function performs the bulk of the work for the various unification functions. #[allow(clippy::nonminimal_bool)] pub fn try_unify_with_bindings_inner<'b>( - t1: &Type, t2: &Type, bindings: &mut UnificationBindings, location: Location<'b>, cache: &mut ModuleCache<'b>, + actual: &Type, expected: &Type, bindings: &mut UnificationBindings, location: Location<'b>, cache: &mut ModuleCache<'b>, ) -> Result<(), ()> { - match (t1, t2) { + match (actual, expected) { (Primitive(p1), Primitive(p2)) if p1 == p2 => Ok(()), (UserDefined(id1), UserDefined(id2)) if id1 == id2 => Ok(()), @@ -632,9 +637,9 @@ pub fn try_unify_with_bindings_inner<'b>( // it to the minimum scope of type variables in b. This happens within the occurs check. // The unification of the LetBindingLevel here is a form of lifetime inference for the // typevar and is used during generalization to determine which variables to generalize. - (TypeVariable(id), _) => try_unify_type_variable_with_bindings(*id, t1, t2, bindings, location, cache), + (TypeVariable(id), _) => try_unify_type_variable_with_bindings(*id, actual, expected, true, bindings, location, cache), - (_, TypeVariable(id)) => try_unify_type_variable_with_bindings(*id, t2, t1, bindings, location, cache), + (_, TypeVariable(id)) => try_unify_type_variable_with_bindings(*id, expected, actual, false, bindings, location, cache), (Function(function1), Function(function2)) => { if function1.parameters.len() != function2.parameters.len() { @@ -651,7 +656,9 @@ pub fn try_unify_with_bindings_inner<'b>( try_unify_with_bindings_inner(a_arg, b_arg, bindings, location, cache)? } - try_unify_with_bindings_inner(&function1.return_type, &function2.return_type, bindings, location, cache)?; + // Reverse the arguments when checking return types to preserve + // some subtyping relations with mutable & immutable references. + try_unify_with_bindings_inner(&function2.return_type, &function1.return_type, bindings, location, cache)?; try_unify_with_bindings_inner(&function1.environment, &function2.environment, bindings, location, cache)?; try_unify_with_bindings_inner(&function1.effects, &function2.effects, bindings, location, cache) }, @@ -672,16 +679,11 @@ pub fn try_unify_with_bindings_inner<'b>( }, // Refs have a hidden lifetime variable we need to unify here - (Ref(shared1, mut1, a_lifetime), Ref(shared2, mut2, _)) => { - if shared1 != shared2 || mut1 != mut2 { - if *shared1 != Sharedness::Polymorphic && *shared2 != Sharedness::Polymorphic { - if *mut1 != Mutability::Polymorphic && *mut2 != Mutability::Polymorphic { - return Err(()); - } - } - } - - try_unify_type_variable_with_bindings(*a_lifetime, t1, t2, bindings, location, cache) + (Ref { sharedness: a_shared, mutability: a_mut, lifetime: a_lifetime }, + Ref { sharedness: b_shared, mutability: b_mut, lifetime: b_lifetime }) => { + try_unify_with_bindings_inner(a_shared, b_shared, bindings, location, cache)?; + try_unify_with_bindings_inner(a_mut, b_mut, bindings, location, cache)?; + try_unify_with_bindings_inner(a_lifetime, b_lifetime, bindings, location, cache) }, // Follow any bindings here for convenience so we don't have to check if a or b @@ -709,6 +711,11 @@ pub fn try_unify_with_bindings_inner<'b>( Ok(()) }, + (Tag(tag1), Tag(tag2)) if tag1 == tag2 => Ok(()), + + // ! <: & + (Tag(TypeTag::Mutable), Tag(TypeTag::Immutable)) => Ok(()), + _ => Err(()), } } @@ -731,6 +738,7 @@ fn bind_struct_fields<'c>( rest1, &TypeVariable(rest1), &TypeVariable(rest2), + true, bindings, location, cache, @@ -742,11 +750,11 @@ fn bind_struct_fields<'c>( } else if new_fields.len() != fields1.len() { // Set 1 := 2 let struct2 = Struct(new_fields, rest2); - try_unify_type_variable_with_bindings(rest1, &TypeVariable(rest1), &struct2, bindings, location, cache)?; + try_unify_type_variable_with_bindings(rest1, &TypeVariable(rest1), &struct2, true, bindings, location, cache)?; } else if new_fields.len() != fields2.len() { // Set 2 := 1 let struct1 = Struct(new_fields, rest1); - try_unify_type_variable_with_bindings(rest2, &TypeVariable(rest2), &struct1, bindings, location, cache)?; + try_unify_type_variable_with_bindings(rest2, &TypeVariable(rest2), &struct1, false, bindings, location, cache)?; } Ok(()) @@ -832,7 +840,7 @@ fn get_fields( } }, TypeApplication(constructor, args) => match follow_bindings_in_cache_and_map(constructor, bindings, cache) { - Ref(..) => get_fields(&args[0], &[], bindings, cache), + Ref { .. } => get_fields(&args[0], &[], bindings, cache), other => get_fields(&other, args, bindings, cache), }, Struct(fields, rest) => match &cache.type_bindings[rest.0] { @@ -850,11 +858,19 @@ fn get_fields( /// Unify a single type variable (id arising from the type a) with an expected type b. /// Follows the given TypeBindings in bindings and the cache if a is Bound. fn try_unify_type_variable_with_bindings<'c>( - id: TypeVariableId, a: &Type, b: &Type, bindings: &mut UnificationBindings, location: Location<'c>, + id: TypeVariableId, a: &Type, b: &Type, + typevar_on_lhs: bool, + bindings: &mut UnificationBindings, location: Location<'c>, cache: &mut ModuleCache<'c>, ) -> Result<(), ()> { match find_binding(id, bindings, cache) { - Bound(a) => try_unify_with_bindings_inner(&a, b, bindings, location, cache), + Bound(a) => { + if typevar_on_lhs { + try_unify_with_bindings_inner(&a, b, bindings, location, cache) + } else { + try_unify_with_bindings_inner(b, &a, bindings, location, cache) + } + } Unbound(a_level, _a_kind) => { // Create binding for boundTy that is currently empty. // Ensure not to create recursive bindings to the same variable @@ -893,37 +909,37 @@ pub fn try_unify_with_bindings<'b>( /// set of type bindings, and returning all the newly-created bindings on success, /// or the unification error message on error. pub fn try_unify<'c>( - t1: &Type, t2: &Type, location: Location<'c>, cache: &mut ModuleCache<'c>, error_kind: TypeErrorKind, + actual: &Type, expected: &Type, location: Location<'c>, cache: &mut ModuleCache<'c>, error_kind: TypeErrorKind, ) -> UnificationResult<'c> { let mut bindings = UnificationBindings::empty(); - try_unify_with_bindings(t1, t2, &mut bindings, location, cache, error_kind).map(|()| bindings) + try_unify_with_bindings(actual, expected, &mut bindings, location, cache, error_kind).map(|()| bindings) } /// Try to unify all the given type, with the given bindings in scope. /// Will add new bindings to the given TypeBindings and return them all on success. pub fn try_unify_all_with_bindings<'c>( - vec1: &[Type], vec2: &[Type], mut bindings: UnificationBindings, location: Location<'c>, + actual: &[Type], expected: &[Type], mut bindings: UnificationBindings, location: Location<'c>, cache: &mut ModuleCache<'c>, error_kind: TypeErrorKind, ) -> UnificationResult<'c> { - if vec1.len() != vec2.len() { + if actual.len() != expected.len() { // This bad error message is the reason this function isn't used within // try_unify_with_bindings! We'd need access to the full type to give better // errors like the other function does. - let vec1 = fmap(vec1, |typ| typ.display(cache).to_string()); - let vec2 = fmap(vec2, |typ| typ.display(cache).to_string()); + let vec1 = fmap(actual, |typ| typ.display(cache).to_string()); + let vec2 = fmap(expected, |typ| typ.display(cache).to_string()); return Err(Diagnostic::new(location, D::TypeLengthMismatch(vec1, vec2))); } - for (t1, t2) in vec1.iter().zip(vec2.iter()) { - try_unify_with_bindings(t1, t2, &mut bindings, location, cache, error_kind.clone())?; + for (actual, expected) in actual.iter().zip(expected.iter()) { + try_unify_with_bindings(actual, expected, &mut bindings, location, cache, error_kind.clone())?; } Ok(bindings) } /// Unifies the two given types, remembering the unification results in the cache. /// If this operation fails, a user-facing error message is emitted. -pub fn unify<'c>(t1: &Type, t2: &Type, location: Location<'c>, cache: &mut ModuleCache<'c>, error_kind: TypeErrorKind) { - perform_bindings_or_push_error(try_unify(t1, t2, location, cache, error_kind), cache); +pub fn unify<'c>(actual: &Type, expected: &Type, location: Location<'c>, cache: &mut ModuleCache<'c>, error_kind: TypeErrorKind) { + perform_bindings_or_push_error(try_unify(actual, expected, location, cache, error_kind), cache); } /// Helper for committing to the results of try_unify. @@ -958,6 +974,7 @@ pub fn find_all_typevars(typ: &Type, polymorphic_only: bool, cache: &ModuleCache match typ { Primitive(_) => vec![], UserDefined(_) => vec![], + Tag(_) => vec![], TypeVariable(id) => find_typevars_in_typevar_binding(*id, polymorphic_only, cache), Function(function) => { let mut type_variables = vec![]; @@ -976,7 +993,12 @@ pub fn find_all_typevars(typ: &Type, polymorphic_only: bool, cache: &ModuleCache } type_variables }, - Ref(_, _, lifetime) => find_typevars_in_typevar_binding(*lifetime, polymorphic_only, cache), + Ref { sharedness, mutability, lifetime } => { + let mut type_variables = find_all_typevars(mutability, polymorphic_only, cache); + type_variables.append(&mut find_all_typevars(sharedness, polymorphic_only, cache)); + type_variables.append(&mut find_all_typevars(lifetime, polymorphic_only, cache)); + type_variables + } Struct(fields, id) => match &cache.type_bindings[id.0] { Bound(t) => find_all_typevars(t, polymorphic_only, cache), Unbound(..) => { @@ -1690,10 +1712,8 @@ impl<'a> Inferable<'a> for ast::Definition<'a> { // t, traits let mut result = infer(self.expr.as_mut(), cache); if self.mutable { - let lifetime = next_type_variable_id(cache); - let shared = Sharedness::Polymorphic; - let mutability = Mutability::Mutable; - result.typ = Type::TypeApplication(Box::new(Type::Ref(shared, mutability, lifetime)), vec![result.typ]); + let ref_type = mut_polymorphically_shared_ref(cache); + result.typ = Type::TypeApplication(Box::new(ref_type), vec![result.typ]); } // The rhs of a Definition must be inferred at a greater LetBindingLevel than @@ -1959,8 +1979,7 @@ impl<'a> Inferable<'a> for ast::Assignment<'a> { let mut rhs = infer(self.rhs.as_mut(), cache); result.combine(&mut rhs, cache); - let lifetime = next_type_variable_id(cache); - let mut_ref = Type::Ref(Sharedness::Polymorphic, Mutability::Mutable, lifetime); + let mut_ref = mut_polymorphically_shared_ref(cache); let mutref = Type::TypeApplication(Box::new(mut_ref), vec![rhs.typ.clone()]); match try_unify(&result.typ, &mutref, self.location, cache, TE::NeverShown) { @@ -1972,13 +1991,19 @@ impl<'a> Inferable<'a> for ast::Assignment<'a> { } } +fn mut_polymorphically_shared_ref(cache: &mut ModuleCache) -> Type { + let mutability = Box::new(Type::Tag(TypeTag::Mutable)); + let sharedness = Box::new(next_type_variable(cache)); + let lifetime = Box::new(next_type_variable(cache)); + Type::Ref { mutability, sharedness, lifetime } +} + fn issue_assignment_error<'c>( lhs: &Type, lhs_loc: Location<'c>, rhs: &Type, location: Location<'c>, cache: &mut ModuleCache<'c>, ) { // Try to offer a more specific error message - let lifetime = next_type_variable_id(cache); let var = next_type_variable(cache); - let mutref = Type::Ref(Sharedness::Polymorphic, Mutability::Mutable, lifetime); + let mutref = mut_polymorphically_shared_ref(cache); let mutref = Type::TypeApplication(Box::new(mutref), vec![var]); if let Err(msg) = try_unify(&mutref, lhs, lhs_loc, cache, TE::AssignToNonMutRef) { diff --git a/src/types/typeprinter.rs b/src/types/typeprinter.rs index 976aeafe..4f785790 100644 --- a/src/types/typeprinter.rs +++ b/src/types/typeprinter.rs @@ -4,7 +4,6 @@ //! types/traits are displayed via `type.display(cache)` rather than directly having //! a Display impl. use crate::cache::{ModuleCache, TraitInfoId}; -use crate::parser::ast::{Mutability, Sharedness}; use crate::types::traits::{ConstraintSignature, ConstraintSignaturePrinter, RequiredTrait, TraitConstraintId}; use crate::types::typechecker::find_all_typevars; use crate::types::{FunctionType, PrimitiveType, Type, TypeBinding, TypeInfoId, TypeVariableId}; @@ -17,6 +16,7 @@ use colored::*; use super::effects::EffectSet; use super::GeneralizedType; +use super::typechecker::follow_bindings_in_cache; /// Wrapper containing the information needed to print out a type pub struct TypePrinter<'a, 'b> { @@ -165,9 +165,10 @@ impl<'a, 'b> TypePrinter<'a, 'b> { Type::TypeVariable(id) => self.fmt_type_variable(*id, f), Type::UserDefined(id) => self.fmt_user_defined_type(*id, f), Type::TypeApplication(constructor, args) => self.fmt_type_application(constructor, args, f), - Type::Ref(shared, mutable, lifetime) => self.fmt_ref(*shared, *mutable, *lifetime, f), + Type::Ref { sharedness, mutability, lifetime } => self.fmt_ref(sharedness, mutability, lifetime, f), Type::Struct(fields, rest) => self.fmt_struct(fields, *rest, f), Type::Effects(effects) => self.fmt_effects(effects, f), + Type::Tag(tag) => write!(f, "{tag}"), } } @@ -278,26 +279,34 @@ impl<'a, 'b> TypePrinter<'a, 'b> { } fn fmt_ref( - &self, shared: Sharedness, mutable: Mutability, lifetime: TypeVariableId, f: &mut Formatter, + &self, shared: &Type, mutable: &Type, lifetime: &Type, f: &mut Formatter, ) -> std::fmt::Result { - match &self.cache.type_bindings[lifetime.0] { - TypeBinding::Bound(typ) => self.fmt_type(typ, f), - TypeBinding::Unbound(..) => { - let shared = shared.to_string(); - let mutable = mutable.to_string(); - let space = if shared.is_empty() { "" } else { " " }; + let mutable = follow_bindings_in_cache(mutable, self.cache); + let shared = follow_bindings_in_cache(shared, self.cache); + let parenthesize = matches!(shared, Type::Tag(_)) || self.debug; - write!(f, "{}{}{}{}", "&".blue(), shared.blue(), space, mutable.blue())?; + if parenthesize { + write!(f, "(")?; + } - if self.debug { - match self.typevar_names.get(&lifetime) { - Some(name) => write!(f, "{{{}}}", name)?, - None => write!(f, "{{?{}}}", lifetime.0)?, - } - } - Ok(()) - }, + match mutable { + Type::Tag(tag) => write!(f, "{tag}")?, + _ => write!(f, "?")?, + } + + if let Type::Tag(tag) = shared { + write!(f, "{tag}")?; + } + + if self.debug { + write!(f, " ")?; + self.fmt_type(lifetime, f)?; } + + if parenthesize { + write!(f, ")")?; + } + Ok(()) } fn fmt_forall(&self, typevars: &[TypeVariableId], typ: &Type, f: &mut Formatter) -> std::fmt::Result {