Skip to content

Commit

Permalink
More cleanup of poly context bound desugaring
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Nov 14, 2024
1 parent 7755e3b commit 24e3fa0
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 59 deletions.
133 changes: 77 additions & 56 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ object desugar {
* def f$default$2[T](x: Int) = x + "m"
*/
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef])

/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
* get added to a buffer.
Expand Down Expand Up @@ -309,10 +309,8 @@ object desugar {
tdef1
end desugarContextBounds

private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
val DefDef(_, paramss, tpt, rhs) = meth
def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree =
val evidenceParamBuf = mutable.ListBuffer[ValDef]()

var seenContextBounds: Int = 0
def freshName(unused: Tree) =
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
Expand All @@ -322,7 +320,7 @@ object desugar {
// parameters of the method since shadowing does not affect
// implicit resolution in Scala 3.

val paramssNoContextBounds =
def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] =
val iflag = paramss.lastOption.flatMap(_.headOption) match
case Some(param) if param.mods.isOneOf(GivenOrImplicit) =>
param.mods.flags & GivenOrImplicit
Expand All @@ -334,16 +332,29 @@ object desugar {
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
}(identity)

rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = paramssNoContextBounds),
evidenceParamBuf.toList
)
meth match
case meth @ DefDef(_, paramss, tpt, rhs) =>
val newParamss = paramssNoContextBounds(paramss)
rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = newParamss
),
evidenceParamBuf.toList
)
case meth @ PolyFunction(tparams, fun) =>
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked
val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked
val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil)
val params = evidenceParamBuf.toList
val boundNames = getBoundNames(params, newParamss)
val recur = fitEvidenceParams(params, nme.apply, boundNames)
val (paramsFst, paramsSnd) = recur(newParamss)
functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
end elimContextBounds

def addDefaultGetters(meth: DefDef)(using Context): Tree =
Expand Down Expand Up @@ -471,6 +482,55 @@ object desugar {
case _ =>
(Nil, tree)

private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => names.contains(name)
case _ => false

/** Fit evidence `params` into the `mparamss` parameter lists */
private def fitEvidenceParams(params: List[ValDef], methName: Name, boundNames: Set[TermName])(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
(params :: Nil) -> mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar)
else params
((normParams ++ mparams) :: Nil) -> Nil
case mparams :: mparamss1 =>
val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
(mparams :: fst) -> snd
case Nil =>
Nil -> (params :: Nil)

/** Create a chain of possibly contextual functions from the parameter lists */
private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match
case Nil => rhs
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
val paramTpts = head.map(_.tpt)
val paramNames = head.map(_.name)
val paramsErased = head.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
case ValDefs(head) :: rest =>
Function(head, functionsOf(rest, rhs))
case TypeDefs(head) :: rest =>
PolyFunction(head, functionsOf(rest, rhs))
case _ =>
assert(false, i"unexpected paramss $paramss")
EmptyTree

private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] =
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
boundNames

/** Add all evidence parameters in `params` as implicit parameters to `meth`.
* The position of the added parameters is determined as follows:
*
Expand All @@ -485,48 +545,9 @@ object desugar {
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
if params.isEmpty then return meth

var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- meth.paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
val boundNames = getBoundNames(params, meth.paramss)

def referencesBoundName(vdef: ValDef): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => boundNames.contains(name)
case _ => false

def recur(mparamss: List[ParamClause]): (List[ParamClause], List[ParamClause]) = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
(params :: Nil) -> mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
else params
((normParams ++ mparams) :: Nil) -> Nil
case mparams :: mparamss1 =>
val (fst, snd) = recur(mparamss1)
(mparams :: fst) -> snd
case Nil =>
Nil -> (params :: Nil)

def functionsOf(paramss: List[ParamClause], rhs: Tree): Tree = paramss match
case Nil => rhs
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
val paramTpts = params.map(_.tpt)
val paramNames = params.map(_.name)
val paramsErased = params.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
case ValDefs(head) :: rest =>
Function(head, functionsOf(rest, rhs))
case head :: _ =>
assert(false, i"unexpected type parameters when adding evidence parameters to $meth")
EmptyTree
val recur = fitEvidenceParams(params, meth.name, boundNames)

if meth.hasAttachment(PolyFunctionApply) then
meth.removeAttachment(PolyFunctionApply)
Expand Down
4 changes: 1 addition & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1920,7 +1920,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val tree1 = desugar.normalizePolyFunction(tree)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
else typedPolyFunctionValue(tree1, pt)
else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt)

def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Expand All @@ -1946,7 +1946,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
typed(desugared, pt)
else
val msg =
Expand All @@ -1955,7 +1954,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
errorTree(EmptyTree, msg, tree.srcPos)
case _ =>
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
typed(desugared, pt)
end typedPolyFunctionValue

Expand Down
5 changes: 5 additions & 0 deletions tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,8 @@ val namedConstraintRef = [X: {Ord as ord}] => (x: ord.T) => x
type DependentCmp = [X: {Ord as ord}] => ord.T => Boolean
type DependentCmp1 = [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean
val dependentCmp: DependentCmp = [X: {Ord as ord}] => (x: ord.T) => true
val dependentCmp_1: [X: {Ord as ord}] => ord.T => Boolean = [X: {Ord as ord}] => (x: ord.T) => true

val dependentCmp1: DependentCmp1 = [X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true
val dependentCmp1_1: [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean =
[X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true

0 comments on commit 24e3fa0

Please sign in to comment.