Skip to content

Commit

Permalink
Check bounds in match type case bodies
Browse files Browse the repository at this point in the history
Requires a few things.

In the tests, it requires propagating some bounds, as well as tweaking
how things are matched.

Also, requires a few changes in how type patterns add constraints, with
a fix on type constructors and another guard in widening abstract types.
  • Loading branch information
dwijnand committed May 26, 2023
1 parent 09f5e4c commit a1eb832
Show file tree
Hide file tree
Showing 19 changed files with 95 additions and 57 deletions.
39 changes: 18 additions & 21 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
* in which case the subtyping relationship "heals" the type.
*/
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($pat, $scrut)", gadts) {

def classesMayBeCompatible: Boolean = {
import Flags._
Expand Down Expand Up @@ -231,41 +231,32 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
*/
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
val debug = noPrinter
def refinementIsInvariant(tp: Type): Boolean = tp match {
case tp: SingletonType => true
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
case tp: TypeProxy => refinementIsInvariant(tp.superType)
case _ => false
}

def widenVariantParams(tp: Type) = tp match {
case tp @ AppliedType(tycon, args) =>
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg
)
tp.derivedAppliedType(tycon, args1)
case tp =>
tp
}

val patternCls = patternTp.classSymbol
val scrutineeCls = scrutineeTp.classSymbol

// NOTE: we already know that there is a derives-from relationship in either direction
val upcastPattern =
patternCls.derivesFrom(scrutineeCls)

val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
val pat = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
val scr = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp

val assumeInvariantRefinement =
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)

trace(i"constraining simple pattern type $tp >:< $pt", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") {
(tp, pt) match {
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
trace(i"constraining simple pattern type $pat >:< $scr assume=$assumeInvariantRefinement", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") {
(scr, pat) match {
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) if tyconP.frozen_=:=(tyconS) =>
val saved = state.nn.constraint
val result =
val success =
ctx.gadtState.rollbackGadtUnless {
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
Expand All @@ -277,15 +268,21 @@ trait PatternTypeConstrainer { self: TypeComparer =>
val TypeBounds(loS, hiS) = argS.bounds
val TypeBounds(loP, hiP) = argP.bounds
var res = true
if variance < 1 then res &&= isSubType(loS, hiP)
if variance > -1 then res &&= isSubType(loP, hiS)
if ctx.mode.is(Mode.Type) then
if variance > -1 then res &&= isSubType(loS, hiP).showing(i"$loS <: $hiP = $result v=$variance argS=$argS argP=$argP", debug)
if variance < 1 then res &&= isSubType(loP, hiS).showing(i"$hiS >: $loP = $result v=$variance argS=$argS argP=$argP", debug)
else
if variance < 1 then res &&= isSubType(loS, hiP).showing(i"$hiP >: $loS = $result v=$variance argP=$argP argS=$argS", debug)
if variance > -1 then res &&= isSubType(loP, hiS).showing(i"$loP <: $hiS = $result v=$variance argP=$argP argS=$argS", debug)
res
else true
}
}
if !result then
if !success then
constraint = saved
result
success
case (scr: TypeRef, _) if ctx.mode.is(Mode.Type) && ctx.gadt.contains(scr.symbol) =>
isSubType(scrutineeTp, patternTp).showing(i"$scrutineeTp <: $patternTp = $result", debug)
case _ =>
// Give up if we don't get AppliedType, e.g. if we upcasted to Any.
// Note that this doesn't mean that patternTp, scrutineeTp cannot possibly
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case info1 @ TypeBounds(lo1, hi1) =>
def compareGADT =
tp1.symbol.onGadtBounds(gbounds1 =>
isSubTypeWhenFrozen(gbounds1.hi, tp2)
(!caseLambda.exists || widenAbstractOKFor(tp2)) && isSubTypeWhenFrozen(gbounds1.hi, tp2)
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true))
&& (tp2.isAny || GADTusage(tp1.symbol))

Expand Down Expand Up @@ -3117,7 +3117,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
super.typeVarInstance(tvar)
}

def matchCases(scrut: Type, cases: List[Type])(using Context): Type = {
def matchCases(scrut: Type, cases: List[Type])(using Context): Type = trace(i"matchCases($scrut, $cases)") {
// a reference for the type parameters poisoned during matching
// for use during the reduction step
var poisoned: Set[TypeParamRef] = Set.empty
Expand Down Expand Up @@ -3169,7 +3169,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {

val defn.MatchCase(pat, body) = cas1: @unchecked

def matches(canWidenAbstract: Boolean): Boolean =
def matches(canWidenAbstract: Boolean): Boolean = trace(i"matches(canWidenAbstract=$canWidenAbstract)") {
val saved = this.canWidenAbstract
val savedPoisoned = this.poisoned
this.canWidenAbstract = canWidenAbstract
Expand All @@ -3179,8 +3179,9 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
poisoned = this.poisoned
this.poisoned = savedPoisoned
this.canWidenAbstract = saved
}

def redux(canApprox: Boolean): MatchResult =
def redux(canApprox: Boolean): MatchResult = trace(i"redux(canApprox=$canApprox)") {
caseLambda match
case caseLambda: HKTypeLambda =>
val instances = paramInstances(canApprox)(Array.fill(caseLambda.paramNames.length)(NoType), pat)
Expand All @@ -3195,6 +3196,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
MatchResult.Reduced(redux.simplified)
case _ =>
MatchResult.Reduced(body)
}

if caseLambda.exists && matches(canWidenAbstract = false) then
redux(canApprox = true)
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ object TypeEval:
case tp: TypeProxy =>
val tp1 = tp.superType
if tp1.isStable then tp1.fixForEvaluation else tp
case tp: AndType =>
// tests/pos/9890.scala
// allow `((0 : Int) & Int) * (3 : Int)` to be folded
val glb = tp.tp1 & tp.tp2
if tp ne glb then glb.fixForEvaluation else tp
case tp => tp

def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match
Expand Down
19 changes: 11 additions & 8 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -269,16 +269,16 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
if !tree.symbol.is(Package) then tree
else errorTree(tree, em"${tree.symbol} cannot be used as a type")

private def gadtCtx(tree: CaseDef)(using Context): Context =
tree.pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
case None => ctx

override def transform(tree: Tree)(using Context): Tree =
try tree match {
// TODO move CaseDef case lower: keep most probable trees first for performance
case CaseDef(pat, _, _) =>
val gadtCtx =
pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
case None =>
ctx
super.transform(tree)(using gadtCtx)
case tree: CaseDef =>
super.transform(tree)(using gadtCtx(tree))
case tree: Ident =>
if tree.isType then
checkNotPackage(tree)
Expand Down Expand Up @@ -477,7 +477,10 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
case m @ MatchTypeTree(bounds, selector, cases) =>
// Analog to the case above for match types
def transformIgnoringBoundsCheck(x: CaseDef): CaseDef =
withMode(Mode.Pattern)(super.transform(x)).asInstanceOf[CaseDef]
inContext(gadtCtx(x)) {
val pat1 = inMode(Mode.Pattern)(transform(x.pat))
cpy.CaseDef(tree)(pat1, transform(x.guard), transform(x.body))
}
cpy.MatchTypeTree(tree)(
super.transform(bounds),
super.transform(selector),
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,8 @@ class Namer { typer: Typer =>

override final def typeSig(sym: Symbol): Type =
val tparamSyms = completerTypeParams(sym)(using ictx)
given ctx: Context = nestedCtx.nn
given ctx: Context = nestedCtx.nn.fresh.setFreshGADTBounds
if tparamSyms.nonEmpty then ctx.gadtState.addToConstraint(tparamSyms)

def abstracted(tp: TypeBounds): TypeBounds =
HKTypeLambda.boundsFromParams(tparamSyms, tp)
Expand Down
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1847,7 +1847,13 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
else report.error(new DuplicateBind(b, cdef), b.srcPos)
if (!ctx.isAfterTyper) {
val bounds = ctx.gadt.fullBounds(sym)
if (bounds != null) sym.info = bounds
if (bounds != null)
val info = if ctx.mode.is(Mode.Type) then bounds match
case TypeBounds(lo, hi) if !lo.isExactlyNothing && hi.isExactlyAny => TypeBounds(defn.NothingType, lo)
case TypeAlias(_) => sym.info
case bounds => bounds
else bounds
sym.info = info
}
b
case t: UnApply if t.symbol.is(Inline) => Inlines.inlinedUnapply(t)
Expand Down Expand Up @@ -1916,6 +1922,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
val pat2 = indexPattern(cdef).transform(pat1)
var body1 = typedType(cdef.body, pt)
if ctx.gadt.isNarrowing then
pat1.putAttachment(InferredGadtConstraints, ctx.gadt)
if !body1.isType then
assert(ctx.reporter.errorsReported)
body1 = TypeTree(errorType(em"<error: not a type>", cdef.srcPos))
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/6570.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object ThisTypeVariant {
}

object ParametricVariant {
type Trick[a] = { type A <: a }
type Trick[a] = Any { type A <: a }
type M[t] = t match { case Trick[a] => N[a] }

trait Root[B] {
Expand Down
6 changes: 6 additions & 0 deletions tests/neg/i13741.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
type Init[X <: NonEmptyTuple] <: Tuple = X match
case _ *: EmptyTuple => EmptyTuple
case x *: xs =>
x *: Init[xs] // error

def a: Init[Tuple3[Int, String, Boolean]] = ???
4 changes: 2 additions & 2 deletions tests/neg/i15272.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
case class Head[+NT, +From <: NT, +To <: NT] (from: From, to: To ) extends EdgeN[NT]
case class Cons[+NT, +From <: NT, +ToE <: EdgeN[NT]](from: From, to: ToE) extends EdgeN[NT]
final type InNodesTupleOf[NT, E <: EdgeN[NT]] <: Tuple = E match
case Cons[nt,from,toE] => from *: InNodesTupleOf[nt,toE]
case Cons[nt,from,toE] => from *: InNodesTupleOf[nt,toE] // error
case Head[nt,from ,to] => from *: EmptyTuple
def inNodesTuple[NT,E <: EdgeN[NT]](edge: E): InNodesTupleOf[NT,E] = edge match
case e: Cons[nt,from,toE] => e.from *: inNodesTuple[nt,toE](e.to) // error
case e: Head[nt,from,to] => e.from *: EmptyTuple
end EdgeN
end EdgeN
3 changes: 2 additions & 1 deletion tests/pos/10867.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
object Test {
// e.g inserts[z, (a, b)] =:= ((z, a, b), (a, z, b), (a, b, z))
type inserts[a, as <: Tuple] <: Tuple =
as match
case EmptyTuple => (a *: EmptyTuple) *: EmptyTuple
case y *: ys => (a *: y *: ys) *: Tuple.Map[inserts[a, ys], [t <: Tuple] =>> y *: t]
case y *: ys => (a *: y *: ys) *: Tuple.Map[inserts[a, ys], [t <: Tuple.Union[inserts[a, ys]]] =>> y *: (t & Tuple)]

type inserts2[a] =
[as <: Tuple] =>> inserts[a, as]
Expand Down
4 changes: 2 additions & 2 deletions tests/pos/13633.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ object Sums extends App:

type Reverse[A] = ReverseLoop[A, EmptyTuple]

type PlusTri[A, B, C] = (A, B, C) match
type PlusTri[A, B, C] <: (Boolean, Boolean) = (A, B, C) match
case (false, false, false) => (false, false)
case (true, false, false) | (false, true, false) | (false, false, true) => (false, true)
case (true, true, false) | (true, false, true) | (false, true, true) => (true, false)
Expand All @@ -38,7 +38,7 @@ object Sums extends App:
case false => A
case true => Inc[A]

type PlusLoop[A <: Tuple, B <: Tuple, O] <: Tuple = (A, B) match
type PlusLoop[A <: Tuple, B <: Tuple, O <: Boolean] <: Tuple = (A, B) match
case (EmptyTuple, EmptyTuple) =>
O match
case true => (true *: EmptyTuple)
Expand Down
6 changes: 3 additions & 3 deletions tests/pos/9239.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ object ABug:
type Zero = B0 :: Nil
type One = B1 :: Nil

type --[B <: Bin] =
type --[B <: Bin] <: Bin =
B match
case B1 :: d => B0 :: d
case B0 :: B1 :: Nil => B1 :: Nil
case B0 :: d => B1 :: --[d]

type ×[N <: Bin, M <: Bin] =
type ×[N <: Bin, M <: Bin] <: Bin =
(N, M) match
case (Zero, ?) => Zero

type ![N <: Bin] =
type ![N <: Bin] <: Bin =
N match
case Zero => One
case One => One
Expand Down
3 changes: 2 additions & 1 deletion tests/pos/9890.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ object Test {

type TupleMap[Tup <: Tuple, Bound, F[_ <: Bound]] <: Tuple = Tup match {
case EmptyTuple => EmptyTuple
case h *: t => F[h] *: TupleMap[t, Bound, F]
case h *: t => h match
case Bound => F[h] *: TupleMap[t, Bound, F]
}
type TupleDedup[Tup <: Tuple, Mask] <: Tuple = Tup match {
case EmptyTuple => EmptyTuple
Expand Down
1 change: 1 addition & 0 deletions tests/pos/i15926.contra.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ type MT1[I <: Show[Nothing], N] = I match
case Int => a

val a = summon[MT1[Show[String], Int] =:= String]
def b: MT1[Show[String], Int] = ""
8 changes: 6 additions & 2 deletions tests/pos/i15926.extract.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@ final case class Succ[+N <: Nat]() extends Nat

final case class Neg[+N <: Succ[Nat]]()

type Sum[X, Y] = Y match
type Sum[X <: Nat, Y] = Y match
case Zero => X
case Succ[y] => Sum[Succ[X], y]

type IntSum[A, B] = B match
case Neg[b] => IntSumNeg[A, b]

type IntSumNeg[A, B] = A match
case Neg[a] => Neg[Sum[a, B]]
case Neg[a] => Negate[Sum[a, B]]

type Negate[A] = A match
case Zero => Zero
case Succ[_] => Neg[A]

type One = Succ[Zero]
type Two = Succ[One]
Expand Down
8 changes: 6 additions & 2 deletions tests/pos/i15926.min.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ final case class Succ[+N <: Nat]() extends Nat

final case class Neg[+N <: Succ[Nat]]()

type Sum[X, Y] = Y match
type Sum[X <: Nat, Y] <: Nat = Y match
case Zero => X
case Succ[y] => Sum[Succ[X], y]

type IntSum[A, B] = B match
case Neg[b] => A match
case Neg[a] => Neg[Sum[a, b]]
case Neg[a] => Negate[Sum[a, b]]

type Negate[A] = A match
case Zero => Zero
case Succ[_] => Neg[A]

type One = Succ[Zero]
type Two = Succ[One]
Expand Down
12 changes: 8 additions & 4 deletions tests/pos/i15926.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ type NatDif[X <: NatT, Y <: NatT] <: IntT = Y match
type Sum[X <: IntT, Y <: IntT] <: IntT = Y match
case Zero => X
case Minus[y] => X match
case Minus[x] => Minus[NatSum[x, y]]
case _ => NatDif[X, y]
case _ => X match
case Minus[x] => Negate[NatSum[x, y]]
case NatT => NatDif[X, y]
case NatT => X match
case Minus[x] => NatDif[Y, x]
case _ => NatSum[X, Y]
case NatT => NatSum[X, Y]

type Negate[A] = A match
case Zero => Zero
case Succ[_] => Neg[A]
5 changes: 3 additions & 2 deletions tests/pos/i16706.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import scala.deriving.Mirror
import scala.reflect.ClassTag

type TupleUnionLub[T <: Tuple, Lub, Acc <: Lub] <: Lub = T match {
case (h & Lub) *: t => TupleUnionLub[t, Lub, Acc | h]
case h *: t => h match
case Lub => TupleUnionLub[t, Lub, Acc | h]
case EmptyTuple => Acc
}

Expand All @@ -14,4 +15,4 @@ transparent inline given derived[A](
sealed trait Foo
case class FooA(a: Int) extends Foo

val instance = derived[Foo] // error
val instance = derived[Foo] // error
Loading

0 comments on commit a1eb832

Please sign in to comment.