Skip to content

Commit

Permalink
Reuse addEvidenceParams logic, but no aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Nov 14, 2024
1 parent ec6d7ef commit dfa9240
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 76 deletions.
89 changes: 69 additions & 20 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import parsing.Parsers

import scala.annotation.internal.sharable
import scala.annotation.threadUnsafe
import dotty.tools.dotc.quoted.QuoteUtils.treeOwner

object desugar {
import untpd.*
Expand Down Expand Up @@ -52,8 +53,12 @@ object desugar {
*/
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()

/** An attachment key to indicate that a DefDef is a poly function apply
* method definition.
/** When first desugaring a PolyFunction, this attachment is added to the
* PolyFunction `apply` method with an empty list value.
*
* Afterwards, the attachment is added to poly function type trees, with the
* list of their context bounds.
* //TODO(kπ) see if it has to be updated
*/
val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey()

Expand Down Expand Up @@ -497,9 +502,9 @@ object desugar {
case Ident(name: TermName) => boundNames.contains(name)
case _ => false

def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match
def recur(mparamss: List[ParamClause]): (List[ParamClause], List[ParamClause]) = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
params :: mparamss
(params :: Nil) -> mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
Expand All @@ -508,30 +513,71 @@ object desugar {
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
else params
(normParams ++ mparams) :: Nil
((normParams ++ mparams) :: Nil) -> Nil
case mparams :: mparamss1 =>
mparams :: recur(mparamss1)
val (fst, snd) = recur(mparamss1)
(mparams :: fst) -> snd
case Nil =>
params :: Nil

def pushDownEvidenceParams(tree: Tree): Tree = tree match
case Function(params, body) =>
cpy.Function(tree)(params, pushDownEvidenceParams(body))
case Block(stats, expr) =>
cpy.Block(tree)(stats, pushDownEvidenceParams(expr))
case tree =>
Nil -> (params :: Nil)

// def pushDownEvidenceParams(tree: Tree): Tree = tree match
// case Function(mparams, body) if mparams.collect { case v: ValDef => v }.exists(referencesBoundName) =>
// ctxFunctionWithParams(tree)
// case Function(mparams, body) =>
// cpy.Function(tree)(mparams, pushDownEvidenceParams(body))
// case Block(stats, expr) =>
// cpy.Block(tree)(stats, pushDownEvidenceParams(expr))
// case tree =>
// ctxFunctionWithParams(tree)

// def ctxFunctionWithParams(tree: Tree): Tree =
// val paramTpts = params.map(_.tpt)
// val paramNames = params.map(_.name)
// val paramsErased = params.map(_.mods.flags.is(Erased))
// Function(params, tree).withSpan(tree.span).withAttachmentsFrom(tree)

def functionsOf(paramss: List[ParamClause], rhs: Tree): Tree = paramss match
case Nil => rhs
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
val paramTpts = params.map(_.tpt)
val paramNames = params.map(_.name)
val paramsErased = params.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
case head :: rest =>
Function(head, functionsOf(rest, rhs))

if meth.hasAttachment(PolyFunctionApply) then
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
else
cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs))
println(i"${recur(meth.paramss)}")
recur(meth.paramss) match
case (paramsFst, Nil) =>
cpy.DefDef(meth)(paramss = paramsFst)
case (paramsFst, paramsSnd) =>
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
else
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))

// if ctx.mode.is(Mode.Type) then
// meth.removeAttachment(PolyFunctionApply)
// // should be kept on meth to see the current param types?
// meth.tpt.putAttachment(PolyFunctionApply, params)
// val newParamss = recur(meth.paramss)
// println(i"added PolyFunctionApply to ${meth.name}.tpt: ${meth.tpt} with $params")
// println(i"new paramss: $newParamss")
// meth
// else
// val newParamss = recur(meth.paramss)
// println(i"added PolyFunctionApply to ${meth.name} with $params")
// println(i"new paramss: $newParamss")
// val DefDef(_, mparamss, _ , _) = meth: @unchecked
// val tparams :: ValDefs(vparams) :: Nil = mparamss: @unchecked
// if vparams.exists(referencesBoundName) then
// cpy.DefDef(meth)(paramss = tparams :: params :: Nil, rhs = Function(vparams, meth.rhs))
// else
// cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs))
else
cpy.DefDef(meth)(paramss = recur(meth.paramss))
val (paramsFst, paramsSnd) = recur(meth.paramss)
cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd)
end addEvidenceParams

/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
Expand Down Expand Up @@ -1265,17 +1311,20 @@ object desugar {
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

vparams.foreach(p => println(i" $p, ${p.mods.flags.flagsString}"))
RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, List.empty)
)).withSpan(tree.span)
.withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty))
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) =>
RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, List.empty)
)).withSpan(tree.span)
.withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty))
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
Expand Down
101 changes: 51 additions & 50 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3592,61 +3592,62 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
}

/** Push down the deferred evidence parameters up until the result type is not
* a method type, poly type or a function type
*/
private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
case tpe: MethodType =>
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
case tpe: PolyType =>
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
case tpe: RefinedType =>
tpe.derivedRefinedType(
pushDownDeferredEvidenceParams(tpe.parent, params, span),
tpe.refinedName,
pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
)
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
case tpe =>
val paramNames = params.map(_.name)
val paramTpts = params.map(_.tpt)
val paramsErased = params.map(_.mods.flags.is(Erased))
val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span)
typed(ctxFunction).tpe
}

/** If the tree has a `PolyFunctionApply` attachment, add the deferred
* evidence parameters as the last argument list before the result type. This
* follows aliases, so the following two types will be expanded to (up to the
* context bound encoding):
* type CmpWeak[X] = X => Boolean
* type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
* ===>
* type CmpWeak[X] = X => Boolean type Comparer2Weak = [X] => X => X ?=>
* Ord[X] => Boolean
*/
private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
tree.getAttachment(desugar.PolyFunctionApply) match
case Some(params) if params.nonEmpty =>
tree.removeAttachment(desugar.PolyFunctionApply)
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
TypeTree(tpe).withSpan(tree.span) -> tpe
case _ => tree -> pt
}
// /** Push down the deferred evidence parameters up until the result type is not
// * a method type, poly type or a function type
// */
// private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type =
// tpe.dealias match {
// case tpe if tpe.baseClasses.contains(defn.PolyFunctionClass) =>
// attachEvidenceParams(tpe, params, span)
// case tpe: MethodType =>
// tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
// case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
// tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
// case tpe =>
// attachEvidenceParams(tpe, params, span)
// }

// /** (params) ?=> tpe */
// private def attachEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type =
// val paramNames = params.map(_.name)
// val paramTpts = params.map(_.tpt)
// val paramsErased = params.map(_.mods.flags.is(Erased))
// val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span)
// typed(ctxFunction).tpe

// /** If the tree has a `PolyFunctionApply` attachment, add the deferred
// * evidence parameters as the last argument list before the result type or a next poly type.
// * This follows aliases, so the following two types will be expanded to (up to the
// * context bound encoding):
// * type CmpWeak[X] = X => Boolean
// * type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
// * ===>
// * type CmpWeak[X] = X => Boolean type Comparer2Weak = [X] => X => X ?=>
// * Ord[X] => Boolean
// */
// private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
// tree.getAttachment(desugar.PolyFunctionApply) match
// case Some(params) if params.nonEmpty =>
// tree.putAttachment(desugar.PolyFunctionApply, Nil)
// val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
// TypeTree(tpe).withSpan(tree.span) -> tpe
// case Some(params) =>
// tree -> pt
// case _ => tree -> pt
// }

/** Interpolate and simplify the type of the given tree. */
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
val (tree1, pt1) = addDeferredEvidenceParams(tree, pt)
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
|| tree1.isDef // ... unless tree is a definition
// val (tree1, pt1) = addDeferredEvidenceParams(tree, pt)
if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
|| tree.isDef // ... unless tree is a definition
then
interpolateTypeVars(tree1, pt1, locked)
val simplified = tree1.tpe.simplified
if !MatchType.thatReducesUsingGadt(tree1.tpe) then // needs a GADT cast. i15743
interpolateTypeVars(tree, pt, locked)
val simplified = tree.tpe.simplified
if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743
tree.overwriteType(simplified)
tree1
tree

protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked
Expand Down
33 changes: 27 additions & 6 deletions tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import scala.language.future

trait Ord[X]:
def compare(x: X, y: X): Int
type T

trait Show[X]:
def show(x: X): String
Expand All @@ -11,6 +12,8 @@ val less0: [X: Ord] => (X, X) => Boolean = ???

val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

type PolyTest1 = [X] => X => Ord[X] ?=> Boolean

val less1_type_test: [X: Ord] => (X, X) => Boolean =
[X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

Expand All @@ -34,12 +37,12 @@ val lessCmp3_1: Cmp3 = [X: Ord as ord] => (x: X) => (y: X) => (z: X) => ord.comp
// type Comparer2 = [X: Ord] => Cmp[X]
// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

type CmpWeak[X] = X => Boolean
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
val less4_0: [X: Ord] => X => X => Boolean =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
val less4_1: Comparer2Weak =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
// type CmpWeak[X] = X => Boolean
// type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
// val less4_0: [X: Ord] => X => X => Boolean =
// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
// val less4_1: Comparer2Weak =
// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0

val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

Expand All @@ -65,3 +68,21 @@ val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y)

val less9_type_test: [X: {Ord as ord, Show as show}] => (X, X) => Boolean =
[X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0

type CmpNested = [X: Ord] => X => [Y: Ord] => Y => Boolean
val less10: CmpNested = [X: Ord] => (x: X) => [Y: Ord] => (y: Y) => true
val less10Explicit: CmpNested = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> true

// type CmpAlias[X] = X => Boolean
// type CmpNestedAliased = [X: Ord] => X => [Y] => Y => CmpAlias[Y]

// val less11: CmpNestedAliased = [X: Ord] => (x: X) => [Y] => (y: Y) => (y1: Y) => true
// val less11Explicit: CmpNestedAliased = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (y1: Y) => true

val notationalExample: [X: Ord] => X => [Y: Ord] => Y => Int =
[X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> 1

val namedConstraintRef = [X: {Ord as ord}] => (x: ord.T) => x
type DependentCmp = [X: {Ord as ord}] => ord.T => Boolean
type DependentCmp1 = [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean
val dependentCmp: DependentCmp = [X: {Ord as ord}] => (x: ord.T) => true

0 comments on commit dfa9240

Please sign in to comment.