Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Honour hard unions in lubbing and param replacing #18680

Merged
merged 3 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ abstract class Constraint extends Showable {
/** The same as this constraint, but with `tv` marked as hard. */
def withHard(tv: TypeVar)(using Context): This

/** Mark toplevel type vars in `tp` as hard. */
def hardenTypeVars(tp: Type)(using Context): This

/** Gives for each instantiated type var that does not yet have its `inst` field
* set, the instance value stored in the constraint. Storing instances in constraints
* is done only in a temporary way for contexts that may be retracted
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -748,9 +748,18 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
}
if isRemovable(param.binder) then current = current.remove(param.binder)
current.dropDeps(param)
replacedTypeVar match
case replacedTypeVar: TypeVar if isHard(replacedTypeVar) => current = current.hardenTypeVars(replacement)
case _ =>
current.checkWellFormed()
end replace

def hardenTypeVars(tp: Type)(using Context): OrderingConstraint = tp.dealiasKeepRefiningAnnots match
case tp: TypeVar if contains(tp.origin) => withHard(tp)
case tp: TypeParamRef if contains(tp) => hardenTypeVars(typeVarOfParam(tp))
case tp: AndOrType => hardenTypeVars(tp.tp1).hardenTypeVars(tp.tp2)
case _ => this

def remove(pt: TypeLambda)(using Context): This = {
def removeFromOrdering(po: ParamOrdering) = {
def removeFromBoundss(key: TypeLambda, bndss: Array[List[TypeParamRef]]): Array[List[TypeParamRef]] = {
Expand Down
15 changes: 2 additions & 13 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -501,17 +501,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
false
}

/** Mark toplevel type vars in `tp2` as hard in the current constraint */
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
case tvar: TypeVar if constraint.contains(tvar.origin) =>
constraint = constraint.withHard(tvar)
case tp2: TypeParamRef if constraint.contains(tp2) =>
hardenTypeVars(constraint.typeVarOfParam(tp2))
case tp2: AndOrType =>
hardenTypeVars(tp2.tp1)
hardenTypeVars(tp2.tp2)
case _ =>

val res = widenOK || joinOK
|| recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1)
Expand All @@ -534,7 +523,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
// is marked so that it converts all soft unions in its lower bound to hard unions
// before it is instantiated. The reason is that the variable's instance type will
// be a supertype of (decomposed and reconstituted) `tp1`.
hardenTypeVars(tp2)
constraint = constraint.hardenTypeVars(tp2)

res

Expand Down Expand Up @@ -2388,7 +2377,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case Atoms.Range(lo2, hi2) =>
if hi1.subsetOf(lo2) then return tp2
if hi2.subsetOf(lo1) then return tp1
if (hi1 & hi2).isEmpty then return orType(tp1, tp2)
if (hi1 & hi2).isEmpty then return orType(tp1, tp2, isSoft = isSoft)
case none =>
case none =>
val t1 = mergeIfSuper(tp1, tp2, canConstrain)
Expand Down
14 changes: 14 additions & 0 deletions tests/pos/i18626.min1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
sealed trait Animal
object Cat extends Animal
object Dog extends Animal

type Mammal = Cat.type | Dog.type

class Test:
def t1 =
val mammals: List[Mammal] = ???
val result = mammals.head
val mammal: Mammal = result // was: Type Mismatch Error:
// Found: (result : Animal)
// Required: Mammal
()
32 changes: 32 additions & 0 deletions tests/pos/i18626.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
trait Random[F1[_]]:
def element[T1](list: Seq[T1]): F1[T1] = ???

trait Monad[F2[_]]:
def map[A1, B1](fa: F2[A1])(f: A1 => B1): F2[B1]

object Monad:
extension [F3[_]: Monad, A3](fa: F3[A3])
def map[B3](f: A3 => B3): F3[B3] = ???

sealed trait Animal
object Cat extends Animal
object Dog extends Animal

type Mammal = Cat.type | Dog.type
val mammals: List[Mammal] = ???

class Work[F4[_]](random: Random[F4])(using mf: Monad[F4]):
def result1: F4[Mammal] =
mf.map(fa = random.element(mammals))(a => a)

def result2: F4[Mammal] = Monad.map(random.element(mammals))(a => a)

import Monad.*

def result3: F4[Mammal] = random
.element(mammals)
.map { a =>
a // was: Type Mismatch Error:
// Found: (a : Animal)
// Required: Cat.type | Dog.type
}