Skip to content

Commit

Permalink
Cleanup context bounds for poly functions implementation, make the im…
Browse files Browse the repository at this point in the history
…plementation consistent with addEvidenceParams
  • Loading branch information
KacperFKorban committed Nov 14, 2024
1 parent dfa9240 commit 7755e3b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 118 deletions.
79 changes: 16 additions & 63 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,9 @@ object desugar {
*/
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()

/** When first desugaring a PolyFunction, this attachment is added to the
* PolyFunction `apply` method with an empty list value.
*
* Afterwards, the attachment is added to poly function type trees, with the
* list of their context bounds.
* //TODO(kπ) see if it has to be updated
/** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
*/
val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey()
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
Expand Down Expand Up @@ -520,61 +515,28 @@ object desugar {
case Nil =>
Nil -> (params :: Nil)

// def pushDownEvidenceParams(tree: Tree): Tree = tree match
// case Function(mparams, body) if mparams.collect { case v: ValDef => v }.exists(referencesBoundName) =>
// ctxFunctionWithParams(tree)
// case Function(mparams, body) =>
// cpy.Function(tree)(mparams, pushDownEvidenceParams(body))
// case Block(stats, expr) =>
// cpy.Block(tree)(stats, pushDownEvidenceParams(expr))
// case tree =>
// ctxFunctionWithParams(tree)

// def ctxFunctionWithParams(tree: Tree): Tree =
// val paramTpts = params.map(_.tpt)
// val paramNames = params.map(_.name)
// val paramsErased = params.map(_.mods.flags.is(Erased))
// Function(params, tree).withSpan(tree.span).withAttachmentsFrom(tree)

def functionsOf(paramss: List[ParamClause], rhs: Tree): Tree = paramss match
case Nil => rhs
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
val paramTpts = params.map(_.tpt)
val paramNames = params.map(_.name)
val paramsErased = params.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
case head :: rest =>
case ValDefs(head) :: rest =>
Function(head, functionsOf(rest, rhs))
case head :: _ =>
assert(false, i"unexpected type parameters when adding evidence parameters to $meth")
EmptyTree

if meth.hasAttachment(PolyFunctionApply) then
println(i"${recur(meth.paramss)}")
recur(meth.paramss) match
case (paramsFst, Nil) =>
cpy.DefDef(meth)(paramss = paramsFst)
case (paramsFst, paramsSnd) =>
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
else
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))

// if ctx.mode.is(Mode.Type) then
// meth.removeAttachment(PolyFunctionApply)
// // should be kept on meth to see the current param types?
// meth.tpt.putAttachment(PolyFunctionApply, params)
// val newParamss = recur(meth.paramss)
// println(i"added PolyFunctionApply to ${meth.name}.tpt: ${meth.tpt} with $params")
// println(i"new paramss: $newParamss")
// meth
// else
// val newParamss = recur(meth.paramss)
// println(i"added PolyFunctionApply to ${meth.name} with $params")
// println(i"new paramss: $newParamss")
// val DefDef(_, mparamss, _ , _) = meth: @unchecked
// val tparams :: ValDefs(vparams) :: Nil = mparamss: @unchecked
// if vparams.exists(referencesBoundName) then
// cpy.DefDef(meth)(paramss = tparams :: params :: Nil, rhs = Function(vparams, meth.rhs))
// else
// cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs))
meth.removeAttachment(PolyFunctionApply)
// for PolyFunctions we are limited to a single term param list, so we reuse the recur logic to compute the new parameter lists
// and then we add the other parameter lists as function types to the return type
val (paramsFst, paramsSnd) = recur(meth.paramss)
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
else
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
else
val (paramsFst, paramsSnd) = recur(meth.paramss)
cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd)
Expand Down Expand Up @@ -1293,7 +1255,7 @@ object desugar {
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
val paramFlags = fun match
case fun: FunctionWithMods =>
Expand All @@ -1311,20 +1273,11 @@ object desugar {
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

vparams.foreach(p => println(i" $p, ${p.mods.flags.flagsString}"))
RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, List.empty)
)).withSpan(tree.span)
.withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty))
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) =>
RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, List.empty)
.withAttachment(PolyFunctionApply, ())
)).withSpan(tree.span)
.withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty))
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
Expand Down
45 changes: 0 additions & 45 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3592,53 +3592,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
}

// /** Push down the deferred evidence parameters up until the result type is not
// * a method type, poly type or a function type
// */
// private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type =
// tpe.dealias match {
// case tpe if tpe.baseClasses.contains(defn.PolyFunctionClass) =>
// attachEvidenceParams(tpe, params, span)
// case tpe: MethodType =>
// tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
// case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
// tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
// case tpe =>
// attachEvidenceParams(tpe, params, span)
// }

// /** (params) ?=> tpe */
// private def attachEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type =
// val paramNames = params.map(_.name)
// val paramTpts = params.map(_.tpt)
// val paramsErased = params.map(_.mods.flags.is(Erased))
// val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span)
// typed(ctxFunction).tpe

// /** If the tree has a `PolyFunctionApply` attachment, add the deferred
// * evidence parameters as the last argument list before the result type or a next poly type.
// * This follows aliases, so the following two types will be expanded to (up to the
// * context bound encoding):
// * type CmpWeak[X] = X => Boolean
// * type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
// * ===>
// * type CmpWeak[X] = X => Boolean type Comparer2Weak = [X] => X => X ?=>
// * Ord[X] => Boolean
// */
// private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
// tree.getAttachment(desugar.PolyFunctionApply) match
// case Some(params) if params.nonEmpty =>
// tree.putAttachment(desugar.PolyFunctionApply, Nil)
// val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
// TypeTree(tpe).withSpan(tree.span) -> tpe
// case Some(params) =>
// tree -> pt
// case _ => tree -> pt
// }

/** Interpolate and simplify the type of the given tree. */
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
// val (tree1, pt1) = addDeferredEvidenceParams(tree, pt)
if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
|| tree.isDef // ... unless tree is a definition
Expand Down
20 changes: 10 additions & 10 deletions tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ val lessCmp3_1: Cmp3 = [X: Ord as ord] => (x: X) => (y: X) => (z: X) => ord.comp
// type Comparer2 = [X: Ord] => Cmp[X]
// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

// type CmpWeak[X] = X => Boolean
// type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
// val less4_0: [X: Ord] => X => X => Boolean =
// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
// val less4_1: Comparer2Weak =
// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
type CmpWeak[X] = X => Boolean
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
val less4_0: [X: Ord] => X => X => Boolean =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
val less4_1: Comparer2Weak =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0

val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

Expand Down Expand Up @@ -73,11 +73,11 @@ type CmpNested = [X: Ord] => X => [Y: Ord] => Y => Boolean
val less10: CmpNested = [X: Ord] => (x: X) => [Y: Ord] => (y: Y) => true
val less10Explicit: CmpNested = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> true

// type CmpAlias[X] = X => Boolean
// type CmpNestedAliased = [X: Ord] => X => [Y] => Y => CmpAlias[Y]
type CmpAlias[X] = X => Boolean
type CmpNestedAliased = [X: Ord] => X => [Y] => Y => CmpAlias[Y]

// val less11: CmpNestedAliased = [X: Ord] => (x: X) => [Y] => (y: Y) => (y1: Y) => true
// val less11Explicit: CmpNestedAliased = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (y1: Y) => true
val less11: CmpNestedAliased = [X: Ord] => (x: X) => [Y] => (y: Y) => (y1: Y) => true
val less11Explicit: CmpNestedAliased = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (y1: Y) => true

val notationalExample: [X: Ord] => X => [Y: Ord] => Y => Int =
[X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> 1
Expand Down

0 comments on commit 7755e3b

Please sign in to comment.