diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 7d1134f76d42..a48026063415 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1738,8 +1738,9 @@ class Namer { typer: Typer => val tpe = (paramss: @unchecked) match case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams) case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams) + val rhsCtx = prepareRhsCtx(ctx.fresh, paramss) if (isFullyDefined(tpe, ForceDegree.none)) tpe - else typedAheadExpr(mdef.rhs, tpe).tpe + else typedAheadExpr(mdef.rhs, tpe)(using rhsCtx).tpe case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) => mdef match { @@ -1937,14 +1938,7 @@ class Namer { typer: Typer => var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody) if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod) - val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten - if (typeParams.nonEmpty) { - // we'll be typing an expression from a polymorphic definition's body, - // so we must allow constraining its type parameters - // compare with typedDefDef, see tests/pos/gadt-inference.scala - rhsCtx.setFreshGADTBounds - rhsCtx.gadtState.addToConstraint(typeParams) - } + rhsCtx = prepareRhsCtx(rhsCtx, paramss) def typedAheadRhs(pt: Type) = PrepareInlineable.dropInlineIfError(sym, @@ -1989,4 +1983,15 @@ class Namer { typer: Typer => lhsType orElse WildcardType } end inferredResultType + + /** Prepare a GADT-aware context used to type the RHS of a ValOrDefDef. */ + def prepareRhsCtx(rhsCtx: FreshContext, paramss: List[List[Symbol]])(using Context): FreshContext = + val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten + if typeParams.nonEmpty then + // we'll be typing an expression from a polymorphic definition's body, + // so we must allow constraining its type parameters + // compare with typedDefDef, see tests/pos/gadt-inference.scala + rhsCtx.setFreshGADTBounds + rhsCtx.gadtState.addToConstraint(typeParams) + rhsCtx } diff --git a/tests/pos/i19570.min1.scala b/tests/pos/i19570.min1.scala new file mode 100644 index 000000000000..2cbc852641d3 --- /dev/null +++ b/tests/pos/i19570.min1.scala @@ -0,0 +1,23 @@ +enum Op[A]: + case Dup[T]() extends Op[(T, T)] + +def foo[R](f: [A] => Op[A] => R): R = ??? + +def test = + foo([A] => (o: Op[A]) => o match + case o: Op.Dup[u] => + summon[A =:= (u, u)] // Error: Cannot prove that A =:= (u, u) + () + ) + foo[Unit]([A] => (o: Op[A]) => o match + case o: Op.Dup[u] => + summon[A =:= (u, u)] // Ok + () + ) + foo({ + val f1 = [B] => (o: Op[B]) => o match + case o: Op.Dup[u] => + summon[B =:= (u, u)] // Also ok + () + f1 + }) diff --git a/tests/pos/i19570.min2.scala b/tests/pos/i19570.min2.scala new file mode 100644 index 000000000000..b1450d7e2d1a --- /dev/null +++ b/tests/pos/i19570.min2.scala @@ -0,0 +1,24 @@ +sealed trait Op[A, B] { def giveA: A; def giveB: B } +final case class Dup[T](x: T) extends Op[T, (T, T)] { def giveA: T = x; def giveB: (T, T) = (x, x) } + +class Test: + def foo[R](f: [A, B] => (o: Op[A, B]) => R): R = ??? + + def m1: Unit = + foo([A, B] => (o: Op[A, B]) => o match + case o: Dup[t] => + var a1: t = o.giveA + var a2: A = o.giveA + a1 = a2 + a2 = a1 + + var b1: (t, t) = o.giveB + var b2: B = o.giveB + b1 = b2 + b2 = b1 + + summon[A =:= t] // ERROR: Cannot prove that A =:= t. + summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t). + + () + ) diff --git a/tests/pos/i19570.orig.scala b/tests/pos/i19570.orig.scala new file mode 100644 index 000000000000..6e574f52be91 --- /dev/null +++ b/tests/pos/i19570.orig.scala @@ -0,0 +1,14 @@ +enum Op[A, B]: + case Dup[T]() extends Op[T, (T, T)] + +def foo[R](f: [A, B] => (o: Op[A, B]) => R): R = + f(Op.Dup()) + +def test = + foo([A, B] => (o: Op[A, B]) => { + o match + case o: Op.Dup[t] => + summon[A =:= t] // ERROR: Cannot prove that A =:= t. + summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t). + 42 + })