Skip to content

Commit

Permalink
Make the expandion of context bounds for poly types slightly more ele…
Browse files Browse the repository at this point in the history
…gant
  • Loading branch information
KacperFKorban committed Oct 4, 2024
1 parent 42d914e commit 980213b
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 33 deletions.
55 changes: 30 additions & 25 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,7 @@ object desugar {
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)

if meth.hasAttachment(PolyFunctionApply) then
meth.removeAttachment(PolyFunctionApply)
// (kπ): deffer this until we can type the result?
// meth.removeAttachment(PolyFunctionApply)
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
else
Expand Down Expand Up @@ -1238,29 +1237,35 @@ object desugar {
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, List.empty)
)).withSpan(tree.span)
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, List.empty)
)).withSpan(tree.span)
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) =>
RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, List.empty)
)).withSpan(tree.span)
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,8 @@ object Parsers {
getFunction(body) match
case Some(f) =>
PolyFunction(tparams, body)
case None if tparams.exists(_.rhs.isInstanceOf[ContextBounds]) =>
PolyFunction(tparams, body)
case None =>
syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
Ident(nme.ERROR.toTypeName)
Expand Down
51 changes: 44 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3590,14 +3590,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
case tpe: MethodType =>
MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
case tpe: PolyType =>
PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
case tpe: RefinedType =>
// TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
tpe.derivedRefinedType(
pushDownDeferredEvidenceParams(tpe.parent, params, span),
tpe.refinedName,
pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
)
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
case tpe =>
val paramNames = params.map(_.name)
val paramTpts = params.map(_.tpt)
Expand All @@ -3606,18 +3609,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
typed(ctxFunction).tpe
}

private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
private def extractTopMethodTermParams(tpe: Type)(using Context): (List[TermName], List[Type]) = tpe match {
case tpe: MethodType =>
tpe.paramNames -> tpe.paramInfos
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
extractTopMethodTermParams(tpe.refinedInfo)
case _ =>
Nil -> Nil
}

private def removeTopMethodTermParams(tpe: Type)(using Context): Type = tpe match {
case tpe: MethodType =>
tpe.resultType
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
case tpe: AppliedType if defn.isFunctionType(tpe) =>
tpe.args.last
case _ =>
tpe
}

private def healToPolyFunctionType(tree: Tree)(using Context): Tree = tree match {
case defdef: DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam))) && defdef.paramss.size == 1 =>
val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef(name, TypeTree(tpe), flags = SyntheticTermParam))
val newDefDef = cpy.DefDef(defdef)(paramss = defdef.paramss ++ List(newParams), tpt = untpd.TypeTree(newTpe))
val nestedCtx = ctx.fresh.setNewTyperState()
typed(newDefDef)(using nestedCtx)
case _ => tree
}

private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
tree.getAttachment(desugar.PolyFunctionApply) match
case Some(params) if params.nonEmpty =>
tree.removeAttachment(desugar.PolyFunctionApply)
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
TypeTree(tpe).withSpan(tree.span) -> tpe
// case Some(params) if params.isEmpty =>
// println(s"tree: $tree")
// healToPolyFunctionType(tree) -> pt
case _ => tree -> pt
}

/** Interpolate and simplify the type of the given tree. */
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt)
val (tree1, pt1) = addDeferredEvidenceParams(tree, pt)
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
|| tree1.isDef // ... unless tree is a definition
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type CmpWeak[X] = X => Boolean
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
val less4_0: [X: Ord] => X => X => Boolean =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
val less4: Comparer2Weak =
val less4_1: Comparer2Weak =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0

val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
Expand Down

0 comments on commit 980213b

Please sign in to comment.