Skip to content

Commit

Permalink
Add GADT symbols when typing typing-ahead lambda bodies (scala#19644)
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand authored Feb 22, 2024
2 parents 119bc33 + 2c81588 commit 98efdab
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 9 deletions.
23 changes: 14 additions & 9 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1738,8 +1738,9 @@ class Namer { typer: Typer =>
val tpe = (paramss: @unchecked) match
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
val rhsCtx = prepareRhsCtx(ctx.fresh, paramss)
if (isFullyDefined(tpe, ForceDegree.none)) tpe
else typedAheadExpr(mdef.rhs, tpe).tpe
else typedAheadExpr(mdef.rhs, tpe)(using rhsCtx).tpe

case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
mdef match {
Expand Down Expand Up @@ -1937,14 +1938,7 @@ class Namer { typer: Typer =>
var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType)
if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod)
val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten
if (typeParams.nonEmpty) {
// we'll be typing an expression from a polymorphic definition's body,
// so we must allow constraining its type parameters
// compare with typedDefDef, see tests/pos/gadt-inference.scala
rhsCtx.setFreshGADTBounds
rhsCtx.gadtState.addToConstraint(typeParams)
}
rhsCtx = prepareRhsCtx(rhsCtx, paramss)

def typedAheadRhs(pt: Type) =
PrepareInlineable.dropInlineIfError(sym,
Expand Down Expand Up @@ -1989,4 +1983,15 @@ class Namer { typer: Typer =>
lhsType orElse WildcardType
}
end inferredResultType

/** Prepare a GADT-aware context used to type the RHS of a ValOrDefDef. */
def prepareRhsCtx(rhsCtx: FreshContext, paramss: List[List[Symbol]])(using Context): FreshContext =
val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten
if typeParams.nonEmpty then
// we'll be typing an expression from a polymorphic definition's body,
// so we must allow constraining its type parameters
// compare with typedDefDef, see tests/pos/gadt-inference.scala
rhsCtx.setFreshGADTBounds
rhsCtx.gadtState.addToConstraint(typeParams)
rhsCtx
}
23 changes: 23 additions & 0 deletions tests/pos/i19570.min1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
enum Op[A]:
case Dup[T]() extends Op[(T, T)]

def foo[R](f: [A] => Op[A] => R): R = ???

def test =
foo([A] => (o: Op[A]) => o match
case o: Op.Dup[u] =>
summon[A =:= (u, u)] // Error: Cannot prove that A =:= (u, u)
()
)
foo[Unit]([A] => (o: Op[A]) => o match
case o: Op.Dup[u] =>
summon[A =:= (u, u)] // Ok
()
)
foo({
val f1 = [B] => (o: Op[B]) => o match
case o: Op.Dup[u] =>
summon[B =:= (u, u)] // Also ok
()
f1
})
24 changes: 24 additions & 0 deletions tests/pos/i19570.min2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
sealed trait Op[A, B] { def giveA: A; def giveB: B }
final case class Dup[T](x: T) extends Op[T, (T, T)] { def giveA: T = x; def giveB: (T, T) = (x, x) }

class Test:
def foo[R](f: [A, B] => (o: Op[A, B]) => R): R = ???

def m1: Unit =
foo([A, B] => (o: Op[A, B]) => o match
case o: Dup[t] =>
var a1: t = o.giveA
var a2: A = o.giveA
a1 = a2
a2 = a1

var b1: (t, t) = o.giveB
var b2: B = o.giveB
b1 = b2
b2 = b1

summon[A =:= t] // ERROR: Cannot prove that A =:= t.
summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t).

()
)
14 changes: 14 additions & 0 deletions tests/pos/i19570.orig.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
enum Op[A, B]:
case Dup[T]() extends Op[T, (T, T)]

def foo[R](f: [A, B] => (o: Op[A, B]) => R): R =
f(Op.Dup())

def test =
foo([A, B] => (o: Op[A, B]) => {
o match
case o: Op.Dup[t] =>
summon[A =:= t] // ERROR: Cannot prove that A =:= t.
summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t).
42
})

0 comments on commit 98efdab

Please sign in to comment.