From 980213bea070472b9088a60a7d796cc8a2001f51 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Fri, 4 Oct 2024 13:21:40 +0200 Subject: [PATCH] Make the expandion of context bounds for poly types slightly more elegant --- .../src/dotty/tools/dotc/ast/Desugar.scala | 55 ++++++++++--------- .../dotty/tools/dotc/parsing/Parsers.scala | 2 + .../src/dotty/tools/dotc/typer/Typer.scala | 51 ++++++++++++++--- .../contextbounds-for-poly-functions.scala | 2 +- 4 files changed, 77 insertions(+), 33 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 6edf3846dfb3..5b3cc9b04049 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -527,8 +527,7 @@ object desugar { makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span) if meth.hasAttachment(PolyFunctionApply) then - meth.removeAttachment(PolyFunctionApply) - // (kπ): deffer this until we can type the result? + // meth.removeAttachment(PolyFunctionApply) if ctx.mode.is(Mode.Type) then cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params)) else @@ -1238,29 +1237,35 @@ object desugar { /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } */ - def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = - val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked - val paramFlags = fun match - case fun: FunctionWithMods => - // TODO: make use of this in the desugaring when pureFuns is enabled. - // val isImpure = funFlags.is(Impure) - - // Function flags to be propagated to each parameter in the desugared method type. - val givenFlag = fun.mods.flags.toTermFlags & Given - fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag) - case _ => - vparamTypes.map(_ => EmptyFlags) - - val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map { - case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags) - case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) - }.toList - - RefinedTypeTree(ref(defn.PolyFunctionType), List( - DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) - .withFlags(Synthetic) - .withAttachment(PolyFunctionApply, List.empty) - )).withSpan(tree.span) + def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match + case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) => + val paramFlags = fun match + case fun: FunctionWithMods => + // TODO: make use of this in the desugaring when pureFuns is enabled. + // val isImpure = funFlags.is(Impure) + + // Function flags to be propagated to each parameter in the desugared method type. + val givenFlag = fun.mods.flags.toTermFlags & Given + fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag) + case _ => + vparamTypes.map(_ => EmptyFlags) + + val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map { + case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags) + case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) + }.toList + + RefinedTypeTree(ref(defn.PolyFunctionType), List( + DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) + .withFlags(Synthetic) + .withAttachment(PolyFunctionApply, List.empty) + )).withSpan(tree.span) + case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) => + RefinedTypeTree(ref(defn.PolyFunctionType), List( + DefDef(nme.apply, tparams :: Nil, res, EmptyTree) + .withFlags(Synthetic) + .withAttachment(PolyFunctionApply, List.empty) + )).withSpan(tree.span) end makePolyFunctionType /** Invent a name for an anonympus given of type or template `impl`. */ diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index f51d4ffd3bc7..adef5bd5717b 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1761,6 +1761,8 @@ object Parsers { getFunction(body) match case Some(f) => PolyFunction(tparams, body) + case None if tparams.exists(_.rhs.isInstanceOf[ContextBounds]) => + PolyFunction(tparams, body) case None => syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset) Ident(nme.ERROR.toTypeName) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 24d9c7d591e1..137bd8d7be8a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3590,14 +3590,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match { case tpe: MethodType => - MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) case tpe: PolyType => - PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) case tpe: RefinedType => - // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement - RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)) + tpe.derivedRefinedType( + pushDownDeferredEvidenceParams(tpe.parent, params, span), + tpe.refinedName, + pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span) + ) case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 => - AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) + tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) case tpe => val paramNames = params.map(_.name) val paramTpts = params.map(_.tpt) @@ -3606,18 +3609,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typed(ctxFunction).tpe } - private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { + private def extractTopMethodTermParams(tpe: Type)(using Context): (List[TermName], List[Type]) = tpe match { + case tpe: MethodType => + tpe.paramNames -> tpe.paramInfos + case tpe: RefinedType if defn.isFunctionType(tpe.parent) => + extractTopMethodTermParams(tpe.refinedInfo) + case _ => + Nil -> Nil + } + + private def removeTopMethodTermParams(tpe: Type)(using Context): Type = tpe match { + case tpe: MethodType => + tpe.resultType + case tpe: RefinedType if defn.isFunctionType(tpe.parent) => + tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo)) + case tpe: AppliedType if defn.isFunctionType(tpe) => + tpe.args.last + case _ => + tpe + } + + private def healToPolyFunctionType(tree: Tree)(using Context): Tree = tree match { + case defdef: DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam))) && defdef.paramss.size == 1 => + val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe) + val newTpe = removeTopMethodTermParams(defdef.tpt.tpe) + val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef(name, TypeTree(tpe), flags = SyntheticTermParam)) + val newDefDef = cpy.DefDef(defdef)(paramss = defdef.paramss ++ List(newParams), tpt = untpd.TypeTree(newTpe)) + val nestedCtx = ctx.fresh.setNewTyperState() + typed(newDefDef)(using nestedCtx) + case _ => tree + } + + private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { tree.getAttachment(desugar.PolyFunctionApply) match case Some(params) if params.nonEmpty => tree.removeAttachment(desugar.PolyFunctionApply) val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) TypeTree(tpe).withSpan(tree.span) -> tpe + // case Some(params) if params.isEmpty => + // println(s"tree: $tree") + // healToPolyFunctionType(tree) -> pt case _ => tree -> pt } /** Interpolate and simplify the type of the given tree. */ protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = - val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt) + val (tree1, pt1) = addDeferredEvidenceParams(tree, pt) if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied || tree1.isDef // ... unless tree is a definition diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index adaf6c035406..8c7bead36633 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -32,7 +32,7 @@ type CmpWeak[X] = X => Boolean type Comparer2Weak = [X: Ord] => X => CmpWeak[X] val less4_0: [X: Ord] => X => X => Boolean = [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 -val less4: Comparer2Weak = +val less4_1: Comparer2Weak = [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0