From a76470f9f34c878e759a20a4e2c76027a1ec6c22 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 24 Sep 2024 10:18:35 +0200 Subject: [PATCH 01/19] Implement basic version of desugaring context bounds for poly functions --- .../src/dotty/tools/dotc/ast/Desugar.scala | 27 +++++++++++++++++++ .../dotty/tools/dotc/parsing/Parsers.scala | 4 +-- .../src/dotty/tools/dotc/typer/Typer.scala | 5 ++-- .../contextbounds-for-poly-functions.scala | 15 +++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 tests/pos/contextbounds-for-poly-functions.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index e66c71731b4f..901fbd1bb601 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1221,6 +1221,33 @@ object desugar { case _ => body cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction] + /** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R + * Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R + */ + def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction = + val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked + val newTParams = tparams.map { + case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) => + TypeDef(name, ContextBounds(bounds, List.empty)) + } + var idx = -1 + val collecedContextBounds = tparams.collect { + case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty => + // TOOD(kπ) Should we handle non empty normal bounds here? + name -> ctxBounds + }.flatMap { case (name, ctxBounds) => + ctxBounds.map { ctxBound => + idx = idx + 1 + makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given) + } + } + val contextFunctionResult = + if collecedContextBounds.isEmpty then + fun + else + Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span) + PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span) + /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } */ diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 47391a4114cf..7a5facf38b67 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -68,7 +68,7 @@ object Parsers { def acceptsVariance = this == Class || this == CaseClass || this == Hk def acceptsCtxBounds = - !(this == Type || this == Hk) + !(this == Hk) def acceptsWildcard = this == Type || this == Hk @@ -3429,7 +3429,7 @@ object Parsers { * * TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’ * TypTypeParam ::= {Annotation} - * (id | ‘_’) [HkTypeParamClause] TypeBounds + * (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds * * HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’ * HkTypeParam ::= {Annotation} [‘+’ | ‘-’] diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 817e7baf1c8c..a669d555617d 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1919,8 +1919,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val tree1 = desugar.normalizePolyFunction(tree) - if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt) - else typedPolyFunctionValue(tree1, pt) + val tree2 = desugar.expandPolyFunctionContextBounds(tree1) + if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt) + else typedPolyFunctionValue(tree2, pt) def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala new file mode 100644 index 000000000000..66c177cf6c89 --- /dev/null +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -0,0 +1,15 @@ +import scala.language.experimental.modularity +import scala.language.future + + +trait Ord[X]: + def compare(x: X, y: X): Int + +val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +// type Comparer = [X: Ord] => (x: X, y: X) => Boolean +// val less2: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 + +// type Cmp[X] = (x: X, y: X) => Boolean +// type Comparer2 = [X: Ord] => Cmp[X] +// val less3: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 From 8dc68c3303c70b8fbb8fbea2a127bd11f5667507 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 24 Sep 2024 10:27:47 +0200 Subject: [PATCH 02/19] Handle named context bounds in poly function context bound desugaring --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 8 ++++++-- tests/pos/contextbounds-for-poly-functions.scala | 6 ++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 901fbd1bb601..91adf3c97733 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1230,7 +1230,7 @@ object desugar { case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) => TypeDef(name, ContextBounds(bounds, List.empty)) } - var idx = -1 + var idx = 0 val collecedContextBounds = tparams.collect { case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty => // TOOD(kπ) Should we handle non empty normal bounds here? @@ -1238,7 +1238,11 @@ object desugar { }.flatMap { case (name, ctxBounds) => ctxBounds.map { ctxBound => idx = idx + 1 - makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given) + ctxBound match + case ContextBoundTypeTree(_, _, ownName) => + ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) + case _ => + makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given) } } val contextFunctionResult = diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 66c177cf6c89..00feedd66d71 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -7,9 +7,11 @@ trait Ord[X]: val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 + // type Comparer = [X: Ord] => (x: X, y: X) => Boolean -// val less2: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +// val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Cmp[X] = (x: X, y: X) => Boolean // type Comparer2 = [X: Ord] => Cmp[X] -// val less3: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 From 3ac5cec843d2f276cb851b67ea7d46871853bdde Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 24 Sep 2024 16:12:28 +0200 Subject: [PATCH 03/19] Correctly-ish desugar poly function context bounds in function types --- .../src/dotty/tools/dotc/ast/Desugar.scala | 23 +++++++++++-------- .../src/dotty/tools/dotc/typer/Typer.scala | 19 ++++++++------- .../contextbounds-for-poly-functions.scala | 9 ++++---- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 91adf3c97733..5db72d2f5a09 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1226,31 +1226,36 @@ object desugar { */ def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction = val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked - val newTParams = tparams.map { + val newTParams = tparams.mapConserve { case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) => TypeDef(name, ContextBounds(bounds, List.empty)) + case t => t } var idx = 0 - val collecedContextBounds = tparams.collect { + val collectedContextBounds = tparams.collect { case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty => - // TOOD(kπ) Should we handle non empty normal bounds here? name -> ctxBounds }.flatMap { case (name, ctxBounds) => ctxBounds.map { ctxBound => idx = idx + 1 ctxBound match - case ContextBoundTypeTree(_, _, ownName) => - ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) + case ctxBound @ ContextBoundTypeTree(tycon, paramName, ownName) => + if tree.isTerm then + ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) + else + ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType case _ => makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given) } } val contextFunctionResult = - if collecedContextBounds.isEmpty then - fun + if collectedContextBounds.isEmpty then fun else - Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span) - PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span) + val mods = EmptyModifiers.withFlags(Given) + val erasedParams = collectedContextBounds.map(_ => false) + Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span) + if collectedContextBounds.isEmpty then tree + else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span) /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index a669d555617d..4dfd2c05c808 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -40,7 +40,7 @@ import annotation.tailrec import Implicits.* import util.Stats.record import config.Printers.{gadts, typr} -import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration} +import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration} import config.SourceVersion.* import rewrites.Rewrites, Rewrites.patch import staging.StagingLevel @@ -1145,7 +1145,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if templ1.parents.isEmpty && isFullyDefined(pt, ForceDegree.flipBottom) && isSkolemFree(pt) - && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity))) + && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))) then templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil) for case parent: RefTree <- templ1.parents do @@ -1720,7 +1720,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt) else val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure) - val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt) + val args1 = args.mapConserve { + case cb: untpd.ContextBoundTypeTree => typed(cb) + case t => t + } + val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt) // if there are any erased classes, we need to re-do the typecheck. result match case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) => @@ -2451,12 +2455,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if tycon.tpe.typeParams.nonEmpty then val tycon0 = tycon.withType(tycon.tpe.etaCollapse) typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil)) - else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then + else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then val tparamSplice = untpd.TypedSplice(typedExpr(tparam)) typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice)))) else def selfNote = - if Feature.enabled(modularity) then + if Feature.enabled(Feature.modularity) then " and\ndoes not have an abstract type member named `Self` either" else "" errorTree(tree, @@ -2475,7 +2479,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked val refinements1 = impl.body val seen = mutable.Set[Symbol]() - for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions + for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions typr.println(s"adding refinement $refinement") checkRefinementNonCyclic(refinement, refineCls, seen) val rsym = refinement.symbol @@ -2489,7 +2493,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val member = refineCls.info.member(rsym.name) if (member.isOverloaded) report.error(OverloadInRefinement(rsym), refinement.srcPos) - } assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls) } @@ -4706,7 +4709,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName) case _ => errorTree(tree, em"cannot convert from $tree to an instance creation expression") - val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity)) + val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)) typed( untpd.Select( untpd.New(untpd.TypedSplice(tpt.withType(tycon))), diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 00feedd66d71..90bd01ce6b6d 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -5,12 +5,13 @@ import scala.language.future trait Ord[X]: def compare(x: X, y: X): Int -val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +// val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +// val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 -// type Comparer = [X: Ord] => (x: X, y: X) => Boolean -// val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean +type Comparer = [X: Ord] => (x: X, y: X) => Boolean +val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Cmp[X] = (x: X, y: X) => Boolean // type Comparer2 = [X: Ord] => Cmp[X] From 408aa74c158290a69f3411f12c719daaa45e8021 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 24 Sep 2024 16:17:26 +0200 Subject: [PATCH 04/19] Fix pickling issue --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 2 +- tests/pos/contextbounds-for-poly-functions.scala | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 5db72d2f5a09..12c237701d62 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1228,7 +1228,7 @@ object desugar { val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked val newTParams = tparams.mapConserve { case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) => - TypeDef(name, ContextBounds(bounds, List.empty)) + cpy.TypeDef(td)(name, bounds) case t => t } var idx = 0 diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 90bd01ce6b6d..6a3ec9935a65 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -5,10 +5,11 @@ import scala.language.future trait Ord[X]: def compare(x: X, y: X): Int -// val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -// val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +type CtxFunctionRef = Ord[Int] ?=> Boolean type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean type Comparer = [X: Ord] => (x: X, y: X) => Boolean val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 From 134c0150617a091d5ea482374c3a42e2572e78ad Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 24 Sep 2024 16:47:13 +0200 Subject: [PATCH 05/19] Hide context bounds expansion for poly functions under modularity feature --- compiler/src/dotty/tools/dotc/parsing/Parsers.scala | 6 ++++-- compiler/src/dotty/tools/dotc/typer/Typer.scala | 3 ++- tests/pos/contextbounds-for-poly-functions.scala | 1 - 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 7a5facf38b67..e54caff9f47d 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -68,7 +68,7 @@ object Parsers { def acceptsVariance = this == Class || this == CaseClass || this == Hk def acceptsCtxBounds = - !(this == Hk) + !(this == Type || this == Hk) def acceptsWildcard = this == Type || this == Hk @@ -3460,7 +3460,9 @@ object Parsers { else ident().toTypeName val hkparams = typeParamClauseOpt(ParamOwner.Hk) val bounds = - if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) else typeBounds() + if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) + else if in.featureEnabled(Feature.modularity) && paramOwner == ParamOwner.Type then typeAndCtxBounds(name) + else typeBounds() TypeDef(name, lambdaAbstract(hkparams, bounds)).withMods(mods) } } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 4dfd2c05c808..4a656c15a9ea 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1923,7 +1923,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val tree1 = desugar.normalizePolyFunction(tree) - val tree2 = desugar.expandPolyFunctionContextBounds(tree1) + val tree2 = if Feature.enabled(Feature.modularity) then desugar.expandPolyFunctionContextBounds(tree1) + else tree1 if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt) else typedPolyFunctionValue(tree2, pt) diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 6a3ec9935a65..c293fd0d9780 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -1,7 +1,6 @@ import scala.language.experimental.modularity import scala.language.future - trait Ord[X]: def compare(x: X, y: X): Int From 309034e5579dd5dcf2667d5cd096c1067681b724 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Wed, 25 Sep 2024 09:43:30 +0200 Subject: [PATCH 06/19] Small cleanup --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 14 +++++--------- tests/pos/contextbounds-for-poly-functions.scala | 2 ++ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 12c237701d62..d37af4aaedae 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1237,15 +1237,11 @@ object desugar { name -> ctxBounds }.flatMap { case (name, ctxBounds) => ctxBounds.map { ctxBound => - idx = idx + 1 - ctxBound match - case ctxBound @ ContextBoundTypeTree(tycon, paramName, ownName) => - if tree.isTerm then - ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) - else - ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType - case _ => - makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given) + val ContextBoundTypeTree(tycon, paramName, ownName) = ctxBound: @unchecked + if tree.isTerm then + ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) + else + ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType } } val contextFunctionResult = diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index c293fd0d9780..7da7405c9225 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -16,3 +16,5 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Cmp[X] = (x: X, y: X) => Boolean // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 \ No newline at end of file From 64bd03eff0d4a901de85b563e681e1ffc5d8eb24 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Wed, 25 Sep 2024 09:49:23 +0200 Subject: [PATCH 07/19] Add more test cases --- tests/pos/contextbounds-for-poly-functions.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 7da7405c9225..7db41628e57d 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -4,6 +4,9 @@ import scala.language.future trait Ord[X]: def compare(x: X, y: X): Int +trait Show[X]: + def show(x: X): String + val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 @@ -17,4 +20,12 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 \ No newline at end of file +val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less6 = [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less7 = [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +val less8 = [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0 From 5196efde72f40ebca4f516cf4febace2ce9cda15 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Wed, 2 Oct 2024 17:23:15 +0200 Subject: [PATCH 08/19] Change the implementation of context bound expansion for poly functions to reuse some of the existing context bound expansion --- .../src/dotty/tools/dotc/ast/Desugar.scala | 58 ++++++++----------- .../src/dotty/tools/dotc/typer/Typer.scala | 34 ++++++----- .../contextbounds-for-poly-functions.scala | 27 +++++++++ 3 files changed, 68 insertions(+), 51 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index d37af4aaedae..488755f81848 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -52,6 +52,11 @@ object desugar { */ val ContextBoundParam: Property.Key[Unit] = Property.StickyKey() + /** An attachment key to indicate that a DefDef is a poly function apply + * method definition. + */ + val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() + /** What static check should be applied to a Match? */ enum MatchCheck { case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom @@ -337,7 +342,8 @@ object desugar { cpy.DefDef(meth)( name = normalizeName(meth, tpt).asTermName, paramss = paramssNoContextBounds), - evidenceParamBuf.toList) + evidenceParamBuf.toList + ) end elimContextBounds def addDefaultGetters(meth: DefDef)(using Context): Tree = @@ -508,7 +514,19 @@ object desugar { case Nil => params :: Nil - cpy.DefDef(meth)(paramss = recur(meth.paramss)) + if meth.hasAttachment(PolyFunctionApply) then + meth.removeAttachment(PolyFunctionApply) + val paramTpts = params.map(_.tpt) + val paramNames = params.map(_.name) + val paramsErased = params.map(_.mods.flags.is(Erased)) + if ctx.mode.is(Mode.Type) then + val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased) + cpy.DefDef(meth)(tpt = ctxFunction) + else + val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased) + cpy.DefDef(meth)(rhs = ctxFunction) + else + cpy.DefDef(meth)(paramss = recur(meth.paramss)) end addEvidenceParams /** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */ @@ -1221,38 +1239,6 @@ object desugar { case _ => body cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction] - /** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R - * Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R - */ - def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction = - val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked - val newTParams = tparams.mapConserve { - case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) => - cpy.TypeDef(td)(name, bounds) - case t => t - } - var idx = 0 - val collectedContextBounds = tparams.collect { - case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty => - name -> ctxBounds - }.flatMap { case (name, ctxBounds) => - ctxBounds.map { ctxBound => - val ContextBoundTypeTree(tycon, paramName, ownName) = ctxBound: @unchecked - if tree.isTerm then - ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) - else - ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType - } - } - val contextFunctionResult = - if collectedContextBounds.isEmpty then fun - else - val mods = EmptyModifiers.withFlags(Given) - val erasedParams = collectedContextBounds.map(_ => false) - Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span) - if collectedContextBounds.isEmpty then tree - else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span) - /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } */ @@ -1275,7 +1261,9 @@ object desugar { }.toList RefinedTypeTree(ref(defn.PolyFunctionType), List( - DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic) + DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) + .withFlags(Synthetic) + .withAttachment(PolyFunctionApply, ()) )).withSpan(tree.span) end makePolyFunctionType diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 4a656c15a9ea..00f22c874f7c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -40,7 +40,7 @@ import annotation.tailrec import Implicits.* import util.Stats.record import config.Printers.{gadts, typr} -import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration} +import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration} import config.SourceVersion.* import rewrites.Rewrites, Rewrites.patch import staging.StagingLevel @@ -53,6 +53,7 @@ import config.MigrationVersion import transform.CheckUnused.OriginalName import scala.annotation.constructorOnly +import dotty.tools.dotc.ast.desugar.PolyFunctionApply object Typer { @@ -1145,7 +1146,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if templ1.parents.isEmpty && isFullyDefined(pt, ForceDegree.flipBottom) && isSkolemFree(pt) - && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))) + && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity))) then templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil) for case parent: RefTree <- templ1.parents do @@ -1720,11 +1721,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt) else val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure) - val args1 = args.mapConserve { - case cb: untpd.ContextBoundTypeTree => typed(cb) - case t => t - } - val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt) + // val args1 = args.mapConserve { + // case cb: untpd.ContextBoundTypeTree => typed(cb) + // case t => t + // } + val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt) // if there are any erased classes, we need to re-do the typecheck. result match case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) => @@ -1923,10 +1924,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val tree1 = desugar.normalizePolyFunction(tree) - val tree2 = if Feature.enabled(Feature.modularity) then desugar.expandPolyFunctionContextBounds(tree1) - else tree1 - if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt) - else typedPolyFunctionValue(tree2, pt) + if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt) + else typedPolyFunctionValue(tree1, pt) def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked @@ -1951,7 +1950,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val resultTpt = untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) - val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) + val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) + defdef.putAttachment(PolyFunctionApply, ()) typed(desugared, pt) else val msg = @@ -1959,7 +1959,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer |Expected type should be a polymorphic function with the same number of type and value parameters.""" errorTree(EmptyTree, msg, tree.srcPos) case _ => - val desugared = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) + val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) + defdef.putAttachment(PolyFunctionApply, ()) typed(desugared, pt) end typedPolyFunctionValue @@ -2456,12 +2457,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if tycon.tpe.typeParams.nonEmpty then val tycon0 = tycon.withType(tycon.tpe.etaCollapse) typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil)) - else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then + else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then val tparamSplice = untpd.TypedSplice(typedExpr(tparam)) typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice)))) else def selfNote = - if Feature.enabled(Feature.modularity) then + if Feature.enabled(modularity) then " and\ndoes not have an abstract type member named `Self` either" else "" errorTree(tree, @@ -3610,6 +3611,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked + println(i"make contextual function $tree / $pt") val paramNamesOrNil = pt match case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames case _ => Nil @@ -4710,7 +4712,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName) case _ => errorTree(tree, em"cannot convert from $tree to an instance creation expression") - val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)) + val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity)) typed( untpd.Select( untpd.New(untpd.TypedSplice(tpt.withType(tycon))), diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 7db41628e57d..a5a035754b08 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -7,10 +7,18 @@ trait Ord[X]: trait Show[X]: def show(x: X): String +val less0: [X: Ord] => (X, X) => Boolean = ??? + val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less1_type_test: [X: Ord] => (X, X) => Boolean = + [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +val less2_type_test: [X: Ord as ord] => (X, X) => Boolean = + [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 + type CtxFunctionRef = Ord[Int] ?=> Boolean type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean type Comparer = [X: Ord] => (x: X, y: X) => Boolean @@ -20,12 +28,31 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +// type CmpWeak[X] = (x: X, y: X) => Boolean +// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X] +// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less5_type_test: [X: [X] =>> Ord[X]] => (X, X) => Boolean = + [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less6 = [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less6_type_test: [X: {Ord, Show}] => (X, X) => Boolean = + [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less7 = [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0 +val less7_type_test: [X: {Ord as ord, Show}] => (X, X) => Boolean = + [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0 + val less8 = [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less8_type_test: [X: {Ord, Show as show}] => (X, X) => Boolean = + [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +val less9_type_test: [X: {Ord as ord, Show as show}] => (X, X) => Boolean = + [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0 From 5f0d4a7205ce80d3576c92df1c7243e67031f430 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Thu, 3 Oct 2024 15:49:53 +0200 Subject: [PATCH 09/19] Add support for some type aliases, when expanding context bounds for poly functions --- .../src/dotty/tools/dotc/ast/Desugar.scala | 26 ++++++--- .../src/dotty/tools/dotc/typer/Typer.scala | 56 ++++++++++++++----- .../contextbounds-for-poly-functions.scala | 9 ++- 3 files changed, 64 insertions(+), 27 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 488755f81848..56e519737cb1 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -55,7 +55,7 @@ object desugar { /** An attachment key to indicate that a DefDef is a poly function apply * method definition. */ - val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() + val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey() /** What static check should be applied to a Match? */ enum MatchCheck { @@ -514,17 +514,25 @@ object desugar { case Nil => params :: Nil + // TODO(kπ) is this enough? SHould this be a TreeTraverse-thing? + def pushDownEvidenceParams(tree: Tree): Tree = tree match + case Function(params, body) => + cpy.Function(tree)(params, pushDownEvidenceParams(body)) + case Block(stats, expr) => + cpy.Block(tree)(stats, pushDownEvidenceParams(expr)) + case tree => + val paramTpts = params.map(_.tpt) + val paramNames = params.map(_.name) + val paramsErased = params.map(_.mods.flags.is(Erased)) + makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span) + if meth.hasAttachment(PolyFunctionApply) then meth.removeAttachment(PolyFunctionApply) - val paramTpts = params.map(_.tpt) - val paramNames = params.map(_.name) - val paramsErased = params.map(_.mods.flags.is(Erased)) + // (kπ): deffer this until we can type the result? if ctx.mode.is(Mode.Type) then - val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased) - cpy.DefDef(meth)(tpt = ctxFunction) + cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params)) else - val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased) - cpy.DefDef(meth)(rhs = ctxFunction) + cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs)) else cpy.DefDef(meth)(paramss = recur(meth.paramss)) end addEvidenceParams @@ -1263,7 +1271,7 @@ object desugar { RefinedTypeTree(ref(defn.PolyFunctionType), List( DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) .withFlags(Synthetic) - .withAttachment(PolyFunctionApply, ()) + .withAttachment(PolyFunctionApply, List.empty) )).withSpan(tree.span) end makePolyFunctionType diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 00f22c874f7c..56e62bccf83b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -53,7 +53,6 @@ import config.MigrationVersion import transform.CheckUnused.OriginalName import scala.annotation.constructorOnly -import dotty.tools.dotc.ast.desugar.PolyFunctionApply object Typer { @@ -1951,7 +1950,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) - defdef.putAttachment(PolyFunctionApply, ()) + defdef.putAttachment(desugar.PolyFunctionApply, List.empty) typed(desugared, pt) else val msg = @@ -1960,7 +1959,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer errorTree(EmptyTree, msg, tree.srcPos) case _ => val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) - defdef.putAttachment(PolyFunctionApply, ()) + defdef.putAttachment(desugar.PolyFunctionApply, List.empty) typed(desugared, pt) end typedPolyFunctionValue @@ -3588,30 +3587,57 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case xtree => typedUnnamed(xtree) val unsimplifiedType = result.tpe - simplify(result, pt, locked) - result.tpe.stripTypeVar match + val result1 = simplify(result, pt, locked) + result1.tpe.stripTypeVar match case e: ErrorType if !unsimplifiedType.isErroneous => errorTree(xtree, e.msg, xtree.srcPos) - case _ => result + case _ => result1 catch case ex: TypeError => handleTypeError(ex) } } + private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match { + case tpe: MethodType => + MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + case tpe: PolyType => + PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + case tpe: RefinedType => + // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement + RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)) + case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 => + AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) + case tpe => + val paramNames = params.map(_.name) + val paramTpts = params.map(_.tpt) + val paramsErased = params.map(_.mods.flags.is(Erased)) + val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span) + typed(ctxFunction).tpe + } + + private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { + tree.getAttachment(desugar.PolyFunctionApply) match + case Some(params) if params.nonEmpty => + tree.removeAttachment(desugar.PolyFunctionApply) + val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) + TypeTree(tpe).withSpan(tree.span) -> tpe + case _ => tree -> pt + } + /** Interpolate and simplify the type of the given tree. */ - protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type = - if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying - if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied - || tree.isDef // ... unless tree is a definition + protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = + val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt) + if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying + if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied + || tree1.isDef // ... unless tree is a definition then - interpolateTypeVars(tree, pt, locked) - val simplified = tree.tpe.simplified - if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743 + interpolateTypeVars(tree1, pt1, locked) + val simplified = tree1.tpe.simplified + if !MatchType.thatReducesUsingGadt(tree1.tpe) then // needs a GADT cast. i15743 tree.overwriteType(simplified) - tree + tree1 protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked - println(i"make contextual function $tree / $pt") val paramNamesOrNil = pt match case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames case _ => Nil diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index a5a035754b08..adaf6c035406 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -28,9 +28,12 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -// type CmpWeak[X] = (x: X, y: X) => Boolean -// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X] -// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +type CmpWeak[X] = X => Boolean +type Comparer2Weak = [X: Ord] => X => CmpWeak[X] +val less4_0: [X: Ord] => X => X => Boolean = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 +val less4: Comparer2Weak = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 From a736592dcb95a35ddeaef12a84324733b0e4a7b5 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Fri, 4 Oct 2024 13:21:40 +0200 Subject: [PATCH 10/19] Make the expandion of context bounds for poly types slightly more elegant --- .../src/dotty/tools/dotc/ast/Desugar.scala | 55 ++++++++++--------- .../src/dotty/tools/dotc/typer/Typer.scala | 51 ++++++++++++++--- .../contextbounds-for-poly-functions.scala | 2 +- 3 files changed, 75 insertions(+), 33 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 56e519737cb1..1d2fd32fe103 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -527,8 +527,7 @@ object desugar { makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span) if meth.hasAttachment(PolyFunctionApply) then - meth.removeAttachment(PolyFunctionApply) - // (kπ): deffer this until we can type the result? + // meth.removeAttachment(PolyFunctionApply) if ctx.mode.is(Mode.Type) then cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params)) else @@ -1250,29 +1249,35 @@ object desugar { /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } */ - def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = - val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked - val paramFlags = fun match - case fun: FunctionWithMods => - // TODO: make use of this in the desugaring when pureFuns is enabled. - // val isImpure = funFlags.is(Impure) - - // Function flags to be propagated to each parameter in the desugared method type. - val givenFlag = fun.mods.flags.toTermFlags & Given - fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag) - case _ => - vparamTypes.map(_ => EmptyFlags) - - val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map { - case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags) - case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) - }.toList - - RefinedTypeTree(ref(defn.PolyFunctionType), List( - DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) - .withFlags(Synthetic) - .withAttachment(PolyFunctionApply, List.empty) - )).withSpan(tree.span) + def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match + case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) => + val paramFlags = fun match + case fun: FunctionWithMods => + // TODO: make use of this in the desugaring when pureFuns is enabled. + // val isImpure = funFlags.is(Impure) + + // Function flags to be propagated to each parameter in the desugared method type. + val givenFlag = fun.mods.flags.toTermFlags & Given + fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag) + case _ => + vparamTypes.map(_ => EmptyFlags) + + val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map { + case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags) + case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) + }.toList + + RefinedTypeTree(ref(defn.PolyFunctionType), List( + DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) + .withFlags(Synthetic) + .withAttachment(PolyFunctionApply, List.empty) + )).withSpan(tree.span) + case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) => + RefinedTypeTree(ref(defn.PolyFunctionType), List( + DefDef(nme.apply, tparams :: Nil, res, EmptyTree) + .withFlags(Synthetic) + .withAttachment(PolyFunctionApply, List.empty) + )).withSpan(tree.span) end makePolyFunctionType /** Invent a name for an anonympus given of type or template `impl`. */ diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 56e62bccf83b..2951452e44f9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3598,14 +3598,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match { case tpe: MethodType => - MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) case tpe: PolyType => - PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) case tpe: RefinedType => - // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement - RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)) + tpe.derivedRefinedType( + pushDownDeferredEvidenceParams(tpe.parent, params, span), + tpe.refinedName, + pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span) + ) case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 => - AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) + tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) case tpe => val paramNames = params.map(_.name) val paramTpts = params.map(_.tpt) @@ -3614,18 +3617,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typed(ctxFunction).tpe } - private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { + private def extractTopMethodTermParams(tpe: Type)(using Context): (List[TermName], List[Type]) = tpe match { + case tpe: MethodType => + tpe.paramNames -> tpe.paramInfos + case tpe: RefinedType if defn.isFunctionType(tpe.parent) => + extractTopMethodTermParams(tpe.refinedInfo) + case _ => + Nil -> Nil + } + + private def removeTopMethodTermParams(tpe: Type)(using Context): Type = tpe match { + case tpe: MethodType => + tpe.resultType + case tpe: RefinedType if defn.isFunctionType(tpe.parent) => + tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo)) + case tpe: AppliedType if defn.isFunctionType(tpe) => + tpe.args.last + case _ => + tpe + } + + private def healToPolyFunctionType(tree: Tree)(using Context): Tree = tree match { + case defdef: DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam))) && defdef.paramss.size == 1 => + val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe) + val newTpe = removeTopMethodTermParams(defdef.tpt.tpe) + val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef(name, TypeTree(tpe), flags = SyntheticTermParam)) + val newDefDef = cpy.DefDef(defdef)(paramss = defdef.paramss ++ List(newParams), tpt = untpd.TypeTree(newTpe)) + val nestedCtx = ctx.fresh.setNewTyperState() + typed(newDefDef)(using nestedCtx) + case _ => tree + } + + private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { tree.getAttachment(desugar.PolyFunctionApply) match case Some(params) if params.nonEmpty => tree.removeAttachment(desugar.PolyFunctionApply) val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) TypeTree(tpe).withSpan(tree.span) -> tpe + // case Some(params) if params.isEmpty => + // println(s"tree: $tree") + // healToPolyFunctionType(tree) -> pt case _ => tree -> pt } /** Interpolate and simplify the type of the given tree. */ protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = - val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt) + val (tree1, pt1) = addDeferredEvidenceParams(tree, pt) if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied || tree1.isDef // ... unless tree is a definition diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index adaf6c035406..8c7bead36633 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -32,7 +32,7 @@ type CmpWeak[X] = X => Boolean type Comparer2Weak = [X: Ord] => X => CmpWeak[X] val less4_0: [X: Ord] => X => X => Boolean = [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 -val less4: Comparer2Weak = +val less4_1: Comparer2Weak = [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 From 0af839783a356e2d42133585f6f63759d6df156c Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 8 Oct 2024 10:51:53 +0200 Subject: [PATCH 11/19] Add more aliases tests for context bounds with poly functions --- tests/pos/contextbounds-for-poly-functions.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 8c7bead36633..a3b79043c01a 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -24,6 +24,12 @@ type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean type Comparer = [X: Ord] => (x: X, y: X) => Boolean val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +type CmpRest[X] = X => Boolean +type CmpMid[X] = X => CmpRest[X] +type Cmp3 = [X: Ord] => X => CmpMid[X] +val lessCmp3: Cmp3 = [X: Ord] => (x: X) => (y: X) => (z: X) => summon[Ord[X]].compare(x, y) < 0 +val lessCmp3_1: Cmp3 = [X: Ord as ord] => (x: X) => (y: X) => (z: X) => ord.compare(x, y) < 0 + // type Cmp[X] = (x: X, y: X) => Boolean // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 From 9c66069d330423f46d340c99f630ef8c375dae1d Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 14 Oct 2024 12:18:25 +0200 Subject: [PATCH 12/19] Bring back the restriction for requiring value parameters in poly function type definitions --- .../src/dotty/tools/dotc/ast/Desugar.scala | 2 - .../src/dotty/tools/dotc/typer/Typer.scala | 47 +++++-------------- 2 files changed, 13 insertions(+), 36 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 1d2fd32fe103..e0a906e06dcc 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -514,7 +514,6 @@ object desugar { case Nil => params :: Nil - // TODO(kπ) is this enough? SHould this be a TreeTraverse-thing? def pushDownEvidenceParams(tree: Tree): Tree = tree match case Function(params, body) => cpy.Function(tree)(params, pushDownEvidenceParams(body)) @@ -527,7 +526,6 @@ object desugar { makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span) if meth.hasAttachment(PolyFunctionApply) then - // meth.removeAttachment(PolyFunctionApply) if ctx.mode.is(Mode.Type) then cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params)) else diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 2951452e44f9..076fc2e4369f 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3596,6 +3596,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer } } + /** Push down the deferred evidence parameters up until the result type is not + * a method type, poly type or a function type + */ private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match { case tpe: MethodType => tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) @@ -3617,46 +3620,22 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typed(ctxFunction).tpe } - private def extractTopMethodTermParams(tpe: Type)(using Context): (List[TermName], List[Type]) = tpe match { - case tpe: MethodType => - tpe.paramNames -> tpe.paramInfos - case tpe: RefinedType if defn.isFunctionType(tpe.parent) => - extractTopMethodTermParams(tpe.refinedInfo) - case _ => - Nil -> Nil - } - - private def removeTopMethodTermParams(tpe: Type)(using Context): Type = tpe match { - case tpe: MethodType => - tpe.resultType - case tpe: RefinedType if defn.isFunctionType(tpe.parent) => - tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo)) - case tpe: AppliedType if defn.isFunctionType(tpe) => - tpe.args.last - case _ => - tpe - } - - private def healToPolyFunctionType(tree: Tree)(using Context): Tree = tree match { - case defdef: DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam))) && defdef.paramss.size == 1 => - val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe) - val newTpe = removeTopMethodTermParams(defdef.tpt.tpe) - val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef(name, TypeTree(tpe), flags = SyntheticTermParam)) - val newDefDef = cpy.DefDef(defdef)(paramss = defdef.paramss ++ List(newParams), tpt = untpd.TypeTree(newTpe)) - val nestedCtx = ctx.fresh.setNewTyperState() - typed(newDefDef)(using nestedCtx) - case _ => tree - } - + /** If the tree has a `PolyFunctionApply` attachment, add the deferred + * evidence parameters as the last argument list before the result type. This + * follows aliases, so the following two types will be expanded to (up to the + * context bound encoding): + * type CmpWeak[X] = X => Boolean + * type Comparer2Weak = [X: Ord] => X => CmpWeak[X] + * ===> + * type CmpWeak[X] = X => Boolean type Comparer2Weak = [X] => X => X ?=> + * Ord[X] => Boolean + */ private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { tree.getAttachment(desugar.PolyFunctionApply) match case Some(params) if params.nonEmpty => tree.removeAttachment(desugar.PolyFunctionApply) val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) TypeTree(tpe).withSpan(tree.span) -> tpe - // case Some(params) if params.isEmpty => - // println(s"tree: $tree") - // healToPolyFunctionType(tree) -> pt case _ => tree -> pt } From ec6d7effc62129508c68f0337d9c854f94390645 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Fri, 18 Oct 2024 10:10:06 +0200 Subject: [PATCH 13/19] Cleanup dead code --- compiler/src/dotty/tools/dotc/typer/Typer.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 076fc2e4369f..4ae706fd7b0c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1720,10 +1720,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt) else val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure) - // val args1 = args.mapConserve { - // case cb: untpd.ContextBoundTypeTree => typed(cb) - // case t => t - // } val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt) // if there are any erased classes, we need to re-do the typecheck. result match From dfa92409a1f6ed9ea0226c145ef1f383e67c43bd Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Thu, 14 Nov 2024 11:25:08 +0100 Subject: [PATCH 14/19] Reuse addEvidenceParams logic, but no aliases --- .../src/dotty/tools/dotc/ast/Desugar.scala | 89 +++++++++++---- .../src/dotty/tools/dotc/typer/Typer.scala | 101 +++++++++--------- .../contextbounds-for-poly-functions.scala | 33 ++++-- 3 files changed, 147 insertions(+), 76 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index e0a906e06dcc..d82150d8f9da 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -22,6 +22,7 @@ import parsing.Parsers import scala.annotation.internal.sharable import scala.annotation.threadUnsafe +import dotty.tools.dotc.quoted.QuoteUtils.treeOwner object desugar { import untpd.* @@ -52,8 +53,12 @@ object desugar { */ val ContextBoundParam: Property.Key[Unit] = Property.StickyKey() - /** An attachment key to indicate that a DefDef is a poly function apply - * method definition. + /** When first desugaring a PolyFunction, this attachment is added to the + * PolyFunction `apply` method with an empty list value. + * + * Afterwards, the attachment is added to poly function type trees, with the + * list of their context bounds. + * //TODO(kπ) see if it has to be updated */ val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey() @@ -497,9 +502,9 @@ object desugar { case Ident(name: TermName) => boundNames.contains(name) case _ => false - def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match + def recur(mparamss: List[ParamClause]): (List[ParamClause], List[ParamClause]) = mparamss match case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) => - params :: mparamss + (params :: Nil) -> mparamss case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) => val normParams = if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then @@ -508,30 +513,71 @@ object desugar { param.withMods(param.mods.withFlags(normFlags)) .showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar) else params - (normParams ++ mparams) :: Nil + ((normParams ++ mparams) :: Nil) -> Nil case mparams :: mparamss1 => - mparams :: recur(mparamss1) + val (fst, snd) = recur(mparamss1) + (mparams :: fst) -> snd case Nil => - params :: Nil - - def pushDownEvidenceParams(tree: Tree): Tree = tree match - case Function(params, body) => - cpy.Function(tree)(params, pushDownEvidenceParams(body)) - case Block(stats, expr) => - cpy.Block(tree)(stats, pushDownEvidenceParams(expr)) - case tree => + Nil -> (params :: Nil) + + // def pushDownEvidenceParams(tree: Tree): Tree = tree match + // case Function(mparams, body) if mparams.collect { case v: ValDef => v }.exists(referencesBoundName) => + // ctxFunctionWithParams(tree) + // case Function(mparams, body) => + // cpy.Function(tree)(mparams, pushDownEvidenceParams(body)) + // case Block(stats, expr) => + // cpy.Block(tree)(stats, pushDownEvidenceParams(expr)) + // case tree => + // ctxFunctionWithParams(tree) + + // def ctxFunctionWithParams(tree: Tree): Tree = + // val paramTpts = params.map(_.tpt) + // val paramNames = params.map(_.name) + // val paramsErased = params.map(_.mods.flags.is(Erased)) + // Function(params, tree).withSpan(tree.span).withAttachmentsFrom(tree) + + def functionsOf(paramss: List[ParamClause], rhs: Tree): Tree = paramss match + case Nil => rhs + case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) => val paramTpts = params.map(_.tpt) val paramNames = params.map(_.name) val paramsErased = params.map(_.mods.flags.is(Erased)) - makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span) + makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span) + case head :: rest => + Function(head, functionsOf(rest, rhs)) if meth.hasAttachment(PolyFunctionApply) then - if ctx.mode.is(Mode.Type) then - cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params)) - else - cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs)) + println(i"${recur(meth.paramss)}") + recur(meth.paramss) match + case (paramsFst, Nil) => + cpy.DefDef(meth)(paramss = paramsFst) + case (paramsFst, paramsSnd) => + if ctx.mode.is(Mode.Type) then + cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt)) + else + cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs)) + + // if ctx.mode.is(Mode.Type) then + // meth.removeAttachment(PolyFunctionApply) + // // should be kept on meth to see the current param types? + // meth.tpt.putAttachment(PolyFunctionApply, params) + // val newParamss = recur(meth.paramss) + // println(i"added PolyFunctionApply to ${meth.name}.tpt: ${meth.tpt} with $params") + // println(i"new paramss: $newParamss") + // meth + // else + // val newParamss = recur(meth.paramss) + // println(i"added PolyFunctionApply to ${meth.name} with $params") + // println(i"new paramss: $newParamss") + // val DefDef(_, mparamss, _ , _) = meth: @unchecked + // val tparams :: ValDefs(vparams) :: Nil = mparamss: @unchecked + // if vparams.exists(referencesBoundName) then + // cpy.DefDef(meth)(paramss = tparams :: params :: Nil, rhs = Function(vparams, meth.rhs)) + // else + // cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs)) else - cpy.DefDef(meth)(paramss = recur(meth.paramss)) + val (paramsFst, paramsSnd) = recur(meth.paramss) + cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd) end addEvidenceParams /** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */ @@ -1265,17 +1311,20 @@ object desugar { case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) }.toList + vparams.foreach(p => println(i" $p, ${p.mods.flags.flagsString}")) RefinedTypeTree(ref(defn.PolyFunctionType), List( DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) .withFlags(Synthetic) .withAttachment(PolyFunctionApply, List.empty) )).withSpan(tree.span) + .withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty)) case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) => RefinedTypeTree(ref(defn.PolyFunctionType), List( DefDef(nme.apply, tparams :: Nil, res, EmptyTree) .withFlags(Synthetic) .withAttachment(PolyFunctionApply, List.empty) )).withSpan(tree.span) + .withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty)) end makePolyFunctionType /** Invent a name for an anonympus given of type or template `impl`. */ diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 4ae706fd7b0c..cfa921f500a2 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3592,61 +3592,62 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer } } - /** Push down the deferred evidence parameters up until the result type is not - * a method type, poly type or a function type - */ - private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match { - case tpe: MethodType => - tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) - case tpe: PolyType => - tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) - case tpe: RefinedType => - tpe.derivedRefinedType( - pushDownDeferredEvidenceParams(tpe.parent, params, span), - tpe.refinedName, - pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span) - ) - case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 => - tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) - case tpe => - val paramNames = params.map(_.name) - val paramTpts = params.map(_.tpt) - val paramsErased = params.map(_.mods.flags.is(Erased)) - val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span) - typed(ctxFunction).tpe - } - - /** If the tree has a `PolyFunctionApply` attachment, add the deferred - * evidence parameters as the last argument list before the result type. This - * follows aliases, so the following two types will be expanded to (up to the - * context bound encoding): - * type CmpWeak[X] = X => Boolean - * type Comparer2Weak = [X: Ord] => X => CmpWeak[X] - * ===> - * type CmpWeak[X] = X => Boolean type Comparer2Weak = [X] => X => X ?=> - * Ord[X] => Boolean - */ - private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { - tree.getAttachment(desugar.PolyFunctionApply) match - case Some(params) if params.nonEmpty => - tree.removeAttachment(desugar.PolyFunctionApply) - val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) - TypeTree(tpe).withSpan(tree.span) -> tpe - case _ => tree -> pt - } + // /** Push down the deferred evidence parameters up until the result type is not + // * a method type, poly type or a function type + // */ + // private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = + // tpe.dealias match { + // case tpe if tpe.baseClasses.contains(defn.PolyFunctionClass) => + // attachEvidenceParams(tpe, params, span) + // case tpe: MethodType => + // tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + // case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 => + // tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) + // case tpe => + // attachEvidenceParams(tpe, params, span) + // } + + // /** (params) ?=> tpe */ + // private def attachEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = + // val paramNames = params.map(_.name) + // val paramTpts = params.map(_.tpt) + // val paramsErased = params.map(_.mods.flags.is(Erased)) + // val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span) + // typed(ctxFunction).tpe + + // /** If the tree has a `PolyFunctionApply` attachment, add the deferred + // * evidence parameters as the last argument list before the result type or a next poly type. + // * This follows aliases, so the following two types will be expanded to (up to the + // * context bound encoding): + // * type CmpWeak[X] = X => Boolean + // * type Comparer2Weak = [X: Ord] => X => CmpWeak[X] + // * ===> + // * type CmpWeak[X] = X => Boolean type Comparer2Weak = [X] => X => X ?=> + // * Ord[X] => Boolean + // */ + // private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { + // tree.getAttachment(desugar.PolyFunctionApply) match + // case Some(params) if params.nonEmpty => + // tree.putAttachment(desugar.PolyFunctionApply, Nil) + // val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) + // TypeTree(tpe).withSpan(tree.span) -> tpe + // case Some(params) => + // tree -> pt + // case _ => tree -> pt + // } /** Interpolate and simplify the type of the given tree. */ protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = - val (tree1, pt1) = addDeferredEvidenceParams(tree, pt) - if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying - if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied - || tree1.isDef // ... unless tree is a definition + // val (tree1, pt1) = addDeferredEvidenceParams(tree, pt) + if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying + if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied + || tree.isDef // ... unless tree is a definition then - interpolateTypeVars(tree1, pt1, locked) - val simplified = tree1.tpe.simplified - if !MatchType.thatReducesUsingGadt(tree1.tpe) then // needs a GADT cast. i15743 + interpolateTypeVars(tree, pt, locked) + val simplified = tree.tpe.simplified + if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743 tree.overwriteType(simplified) - tree1 + tree protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index a3b79043c01a..44eb978b6c52 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -3,6 +3,7 @@ import scala.language.future trait Ord[X]: def compare(x: X, y: X): Int + type T trait Show[X]: def show(x: X): String @@ -11,6 +12,8 @@ val less0: [X: Ord] => (X, X) => Boolean = ??? val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +type PolyTest1 = [X] => X => Ord[X] ?=> Boolean + val less1_type_test: [X: Ord] => (X, X) => Boolean = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 @@ -34,12 +37,12 @@ val lessCmp3_1: Cmp3 = [X: Ord as ord] => (x: X) => (y: X) => (z: X) => ord.comp // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -type CmpWeak[X] = X => Boolean -type Comparer2Weak = [X: Ord] => X => CmpWeak[X] -val less4_0: [X: Ord] => X => X => Boolean = - [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 -val less4_1: Comparer2Weak = - [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 +// type CmpWeak[X] = X => Boolean +// type Comparer2Weak = [X: Ord] => X => CmpWeak[X] +// val less4_0: [X: Ord] => X => X => Boolean = +// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 +// val less4_1: Comparer2Weak = +// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 @@ -65,3 +68,21 @@ val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) val less9_type_test: [X: {Ord as ord, Show as show}] => (X, X) => Boolean = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +type CmpNested = [X: Ord] => X => [Y: Ord] => Y => Boolean +val less10: CmpNested = [X: Ord] => (x: X) => [Y: Ord] => (y: Y) => true +val less10Explicit: CmpNested = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> true + +// type CmpAlias[X] = X => Boolean +// type CmpNestedAliased = [X: Ord] => X => [Y] => Y => CmpAlias[Y] + +// val less11: CmpNestedAliased = [X: Ord] => (x: X) => [Y] => (y: Y) => (y1: Y) => true +// val less11Explicit: CmpNestedAliased = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (y1: Y) => true + +val notationalExample: [X: Ord] => X => [Y: Ord] => Y => Int = + [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> 1 + +val namedConstraintRef = [X: {Ord as ord}] => (x: ord.T) => x +type DependentCmp = [X: {Ord as ord}] => ord.T => Boolean +type DependentCmp1 = [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean +val dependentCmp: DependentCmp = [X: {Ord as ord}] => (x: ord.T) => true From 7755e3bc166637f6753df2d2f74b127dc97b35f8 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Thu, 14 Nov 2024 11:46:12 +0100 Subject: [PATCH 15/19] Cleanup context bounds for poly functions implementation, make the implementation consistent with addEvidenceParams --- .../src/dotty/tools/dotc/ast/Desugar.scala | 79 ++++--------------- .../src/dotty/tools/dotc/typer/Typer.scala | 45 ----------- .../contextbounds-for-poly-functions.scala | 20 ++--- 3 files changed, 26 insertions(+), 118 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index d82150d8f9da..768d598987f0 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -53,14 +53,9 @@ object desugar { */ val ContextBoundParam: Property.Key[Unit] = Property.StickyKey() - /** When first desugaring a PolyFunction, this attachment is added to the - * PolyFunction `apply` method with an empty list value. - * - * Afterwards, the attachment is added to poly function type trees, with the - * list of their context bounds. - * //TODO(kπ) see if it has to be updated + /** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way */ - val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey() + val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() /** What static check should be applied to a Match? */ enum MatchCheck { @@ -520,22 +515,6 @@ object desugar { case Nil => Nil -> (params :: Nil) - // def pushDownEvidenceParams(tree: Tree): Tree = tree match - // case Function(mparams, body) if mparams.collect { case v: ValDef => v }.exists(referencesBoundName) => - // ctxFunctionWithParams(tree) - // case Function(mparams, body) => - // cpy.Function(tree)(mparams, pushDownEvidenceParams(body)) - // case Block(stats, expr) => - // cpy.Block(tree)(stats, pushDownEvidenceParams(expr)) - // case tree => - // ctxFunctionWithParams(tree) - - // def ctxFunctionWithParams(tree: Tree): Tree = - // val paramTpts = params.map(_.tpt) - // val paramNames = params.map(_.name) - // val paramsErased = params.map(_.mods.flags.is(Erased)) - // Function(params, tree).withSpan(tree.span).withAttachmentsFrom(tree) - def functionsOf(paramss: List[ParamClause], rhs: Tree): Tree = paramss match case Nil => rhs case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) => @@ -543,38 +522,21 @@ object desugar { val paramNames = params.map(_.name) val paramsErased = params.map(_.mods.flags.is(Erased)) makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span) - case head :: rest => + case ValDefs(head) :: rest => Function(head, functionsOf(rest, rhs)) + case head :: _ => + assert(false, i"unexpected type parameters when adding evidence parameters to $meth") + EmptyTree if meth.hasAttachment(PolyFunctionApply) then - println(i"${recur(meth.paramss)}") - recur(meth.paramss) match - case (paramsFst, Nil) => - cpy.DefDef(meth)(paramss = paramsFst) - case (paramsFst, paramsSnd) => - if ctx.mode.is(Mode.Type) then - cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt)) - else - cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs)) - - // if ctx.mode.is(Mode.Type) then - // meth.removeAttachment(PolyFunctionApply) - // // should be kept on meth to see the current param types? - // meth.tpt.putAttachment(PolyFunctionApply, params) - // val newParamss = recur(meth.paramss) - // println(i"added PolyFunctionApply to ${meth.name}.tpt: ${meth.tpt} with $params") - // println(i"new paramss: $newParamss") - // meth - // else - // val newParamss = recur(meth.paramss) - // println(i"added PolyFunctionApply to ${meth.name} with $params") - // println(i"new paramss: $newParamss") - // val DefDef(_, mparamss, _ , _) = meth: @unchecked - // val tparams :: ValDefs(vparams) :: Nil = mparamss: @unchecked - // if vparams.exists(referencesBoundName) then - // cpy.DefDef(meth)(paramss = tparams :: params :: Nil, rhs = Function(vparams, meth.rhs)) - // else - // cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs)) + meth.removeAttachment(PolyFunctionApply) + // for PolyFunctions we are limited to a single term param list, so we reuse the recur logic to compute the new parameter lists + // and then we add the other parameter lists as function types to the return type + val (paramsFst, paramsSnd) = recur(meth.paramss) + if ctx.mode.is(Mode.Type) then + cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt)) + else + cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs)) else val (paramsFst, paramsSnd) = recur(meth.paramss) cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd) @@ -1293,7 +1255,7 @@ object desugar { /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } */ - def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match + def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) => val paramFlags = fun match case fun: FunctionWithMods => @@ -1311,20 +1273,11 @@ object desugar { case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) }.toList - vparams.foreach(p => println(i" $p, ${p.mods.flags.flagsString}")) RefinedTypeTree(ref(defn.PolyFunctionType), List( DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) .withFlags(Synthetic) - .withAttachment(PolyFunctionApply, List.empty) - )).withSpan(tree.span) - .withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty)) - case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) => - RefinedTypeTree(ref(defn.PolyFunctionType), List( - DefDef(nme.apply, tparams :: Nil, res, EmptyTree) - .withFlags(Synthetic) - .withAttachment(PolyFunctionApply, List.empty) + .withAttachment(PolyFunctionApply, ()) )).withSpan(tree.span) - .withAttachment(PolyFunctionApply, tree.attachmentOrElse(PolyFunctionApply, List.empty)) end makePolyFunctionType /** Invent a name for an anonympus given of type or template `impl`. */ diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index cfa921f500a2..f7610520f61c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3592,53 +3592,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer } } - // /** Push down the deferred evidence parameters up until the result type is not - // * a method type, poly type or a function type - // */ - // private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = - // tpe.dealias match { - // case tpe if tpe.baseClasses.contains(defn.PolyFunctionClass) => - // attachEvidenceParams(tpe, params, span) - // case tpe: MethodType => - // tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span)) - // case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 => - // tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) - // case tpe => - // attachEvidenceParams(tpe, params, span) - // } - - // /** (params) ?=> tpe */ - // private def attachEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = - // val paramNames = params.map(_.name) - // val paramTpts = params.map(_.tpt) - // val paramsErased = params.map(_.mods.flags.is(Erased)) - // val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span) - // typed(ctxFunction).tpe - - // /** If the tree has a `PolyFunctionApply` attachment, add the deferred - // * evidence parameters as the last argument list before the result type or a next poly type. - // * This follows aliases, so the following two types will be expanded to (up to the - // * context bound encoding): - // * type CmpWeak[X] = X => Boolean - // * type Comparer2Weak = [X: Ord] => X => CmpWeak[X] - // * ===> - // * type CmpWeak[X] = X => Boolean type Comparer2Weak = [X] => X => X ?=> - // * Ord[X] => Boolean - // */ - // private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { - // tree.getAttachment(desugar.PolyFunctionApply) match - // case Some(params) if params.nonEmpty => - // tree.putAttachment(desugar.PolyFunctionApply, Nil) - // val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) - // TypeTree(tpe).withSpan(tree.span) -> tpe - // case Some(params) => - // tree -> pt - // case _ => tree -> pt - // } - /** Interpolate and simplify the type of the given tree. */ protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = - // val (tree1, pt1) = addDeferredEvidenceParams(tree, pt) if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied || tree.isDef // ... unless tree is a definition diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 44eb978b6c52..6fadcda2b43e 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -37,12 +37,12 @@ val lessCmp3_1: Cmp3 = [X: Ord as ord] => (x: X) => (y: X) => (z: X) => ord.comp // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -// type CmpWeak[X] = X => Boolean -// type Comparer2Weak = [X: Ord] => X => CmpWeak[X] -// val less4_0: [X: Ord] => X => X => Boolean = -// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 -// val less4_1: Comparer2Weak = -// [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 +type CmpWeak[X] = X => Boolean +type Comparer2Weak = [X: Ord] => X => CmpWeak[X] +val less4_0: [X: Ord] => X => X => Boolean = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 +val less4_1: Comparer2Weak = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 @@ -73,11 +73,11 @@ type CmpNested = [X: Ord] => X => [Y: Ord] => Y => Boolean val less10: CmpNested = [X: Ord] => (x: X) => [Y: Ord] => (y: Y) => true val less10Explicit: CmpNested = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> true -// type CmpAlias[X] = X => Boolean -// type CmpNestedAliased = [X: Ord] => X => [Y] => Y => CmpAlias[Y] +type CmpAlias[X] = X => Boolean +type CmpNestedAliased = [X: Ord] => X => [Y] => Y => CmpAlias[Y] -// val less11: CmpNestedAliased = [X: Ord] => (x: X) => [Y] => (y: Y) => (y1: Y) => true -// val less11Explicit: CmpNestedAliased = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (y1: Y) => true +val less11: CmpNestedAliased = [X: Ord] => (x: X) => [Y] => (y: Y) => (y1: Y) => true +val less11Explicit: CmpNestedAliased = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (y1: Y) => true val notationalExample: [X: Ord] => X => [Y: Ord] => Y => Int = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> 1 From 24e3fa0fe7810b1bc5db0560e7ff7f047e504603 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Thu, 14 Nov 2024 14:14:03 +0100 Subject: [PATCH 16/19] More cleanup of poly context bound desugaring --- .../src/dotty/tools/dotc/ast/Desugar.scala | 133 ++++++++++-------- .../src/dotty/tools/dotc/typer/Typer.scala | 4 +- .../contextbounds-for-poly-functions.scala | 5 + 3 files changed, 83 insertions(+), 59 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 768d598987f0..0bf3ba71b84d 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -247,7 +247,7 @@ object desugar { * def f$default$2[T](x: Int) = x + "m" */ private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree = - addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor)) + addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef]) /** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that * get added to a buffer. @@ -309,10 +309,8 @@ object desugar { tdef1 end desugarContextBounds - private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef = - val DefDef(_, paramss, tpt, rhs) = meth + def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree = val evidenceParamBuf = mutable.ListBuffer[ValDef]() - var seenContextBounds: Int = 0 def freshName(unused: Tree) = seenContextBounds += 1 // Start at 1 like FreshNameCreator. @@ -322,7 +320,7 @@ object desugar { // parameters of the method since shadowing does not affect // implicit resolution in Scala 3. - val paramssNoContextBounds = + def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] = val iflag = paramss.lastOption.flatMap(_.headOption) match case Some(param) if param.mods.isOneOf(GivenOrImplicit) => param.mods.flags & GivenOrImplicit @@ -334,16 +332,29 @@ object desugar { tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss) }(identity) - rhs match - case MacroTree(call) => - cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased) - case _ => - addEvidenceParams( - cpy.DefDef(meth)( - name = normalizeName(meth, tpt).asTermName, - paramss = paramssNoContextBounds), - evidenceParamBuf.toList - ) + meth match + case meth @ DefDef(_, paramss, tpt, rhs) => + val newParamss = paramssNoContextBounds(paramss) + rhs match + case MacroTree(call) => + cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased) + case _ => + addEvidenceParams( + cpy.DefDef(meth)( + name = normalizeName(meth, tpt).asTermName, + paramss = newParamss + ), + evidenceParamBuf.toList + ) + case meth @ PolyFunction(tparams, fun) => + val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked + val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked + val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil) + val params = evidenceParamBuf.toList + val boundNames = getBoundNames(params, newParamss) + val recur = fitEvidenceParams(params, nme.apply, boundNames) + val (paramsFst, paramsSnd) = recur(newParamss) + functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs) end elimContextBounds def addDefaultGetters(meth: DefDef)(using Context): Tree = @@ -471,6 +482,55 @@ object desugar { case _ => (Nil, tree) + private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean = + vdef.tpt.existsSubTree: + case Ident(name: TermName) => names.contains(name) + case _ => false + + /** Fit evidence `params` into the `mparamss` parameter lists */ + private def fitEvidenceParams(params: List[ValDef], methName: Name, boundNames: Set[TermName])(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match + case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) => + (params :: Nil) -> mparamss + case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) => + val normParams = + if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then + params.map: param => + val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit)) + param.withMods(param.mods.withFlags(normFlags)) + .showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar) + else params + ((normParams ++ mparams) :: Nil) -> Nil + case mparams :: mparamss1 => + val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1) + (mparams :: fst) -> snd + case Nil => + Nil -> (params :: Nil) + + /** Create a chain of possibly contextual functions from the parameter lists */ + private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match + case Nil => rhs + case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) => + val paramTpts = head.map(_.tpt) + val paramNames = head.map(_.name) + val paramsErased = head.map(_.mods.flags.is(Erased)) + makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span) + case ValDefs(head) :: rest => + Function(head, functionsOf(rest, rhs)) + case TypeDefs(head) :: rest => + PolyFunction(head, functionsOf(rest, rhs)) + case _ => + assert(false, i"unexpected paramss $paramss") + EmptyTree + + private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] = + var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names + for mparams <- paramss; mparam <- mparams do + mparam match + case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) => + boundNames += tparam.name.toTermName + case _ => + boundNames + /** Add all evidence parameters in `params` as implicit parameters to `meth`. * The position of the added parameters is determined as follows: * @@ -485,48 +545,9 @@ object desugar { private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef = if params.isEmpty then return meth - var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names - for mparams <- meth.paramss; mparam <- mparams do - mparam match - case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) => - boundNames += tparam.name.toTermName - case _ => + val boundNames = getBoundNames(params, meth.paramss) - def referencesBoundName(vdef: ValDef): Boolean = - vdef.tpt.existsSubTree: - case Ident(name: TermName) => boundNames.contains(name) - case _ => false - - def recur(mparamss: List[ParamClause]): (List[ParamClause], List[ParamClause]) = mparamss match - case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) => - (params :: Nil) -> mparamss - case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) => - val normParams = - if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then - params.map: param => - val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit)) - param.withMods(param.mods.withFlags(normFlags)) - .showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar) - else params - ((normParams ++ mparams) :: Nil) -> Nil - case mparams :: mparamss1 => - val (fst, snd) = recur(mparamss1) - (mparams :: fst) -> snd - case Nil => - Nil -> (params :: Nil) - - def functionsOf(paramss: List[ParamClause], rhs: Tree): Tree = paramss match - case Nil => rhs - case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) => - val paramTpts = params.map(_.tpt) - val paramNames = params.map(_.name) - val paramsErased = params.map(_.mods.flags.is(Erased)) - makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span) - case ValDefs(head) :: rest => - Function(head, functionsOf(rest, rhs)) - case head :: _ => - assert(false, i"unexpected type parameters when adding evidence parameters to $meth") - EmptyTree + val recur = fitEvidenceParams(params, meth.name, boundNames) if meth.hasAttachment(PolyFunctionApply) then meth.removeAttachment(PolyFunctionApply) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index f7610520f61c..bc4981ef11a4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1920,7 +1920,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val tree1 = desugar.normalizePolyFunction(tree) if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt) - else typedPolyFunctionValue(tree1, pt) + else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt) def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked @@ -1946,7 +1946,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) - defdef.putAttachment(desugar.PolyFunctionApply, List.empty) typed(desugared, pt) else val msg = @@ -1955,7 +1954,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer errorTree(EmptyTree, msg, tree.srcPos) case _ => val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) - defdef.putAttachment(desugar.PolyFunctionApply, List.empty) typed(desugared, pt) end typedPolyFunctionValue diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 6fadcda2b43e..13411a3ad769 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -86,3 +86,8 @@ val namedConstraintRef = [X: {Ord as ord}] => (x: ord.T) => x type DependentCmp = [X: {Ord as ord}] => ord.T => Boolean type DependentCmp1 = [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean val dependentCmp: DependentCmp = [X: {Ord as ord}] => (x: ord.T) => true +val dependentCmp_1: [X: {Ord as ord}] => ord.T => Boolean = [X: {Ord as ord}] => (x: ord.T) => true + +val dependentCmp1: DependentCmp1 = [X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true +val dependentCmp1_1: [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean = + [X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true From f292ac54869e9b407a389ea75f891515ad31e07b Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Thu, 14 Nov 2024 16:15:08 +0100 Subject: [PATCH 17/19] Short circuit adding evidence params to poly functions, when there are no context bounds --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 0bf3ba71b84d..e8ebd77b0423 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -351,6 +351,7 @@ object desugar { val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil) val params = evidenceParamBuf.toList + if params.isEmpty then return meth val boundNames = getBoundNames(params, newParamss) val recur = fitEvidenceParams(params, nme.apply, boundNames) val (paramsFst, paramsSnd) = recur(newParamss) From f9db9fa0c963de738b39d8230cd30053db671f14 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Fri, 15 Nov 2024 08:35:03 +0100 Subject: [PATCH 18/19] Add a run test for poly context bounds; cleanup typer changes --- .../src/dotty/tools/dotc/ast/Desugar.scala | 1 - .../src/dotty/tools/dotc/typer/Typer.scala | 12 ++++---- .../contextbounds-for-poly-functions.check | 6 ++++ .../contextbounds-for-poly-functions.scala | 30 +++++++++++++++++++ 4 files changed, 42 insertions(+), 7 deletions(-) create mode 100644 tests/run/contextbounds-for-poly-functions.check create mode 100644 tests/run/contextbounds-for-poly-functions.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index e8ebd77b0423..6e54dee51c89 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -22,7 +22,6 @@ import parsing.Parsers import scala.annotation.internal.sharable import scala.annotation.threadUnsafe -import dotty.tools.dotc.quoted.QuoteUtils.treeOwner object desugar { import untpd.* diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index bc4981ef11a4..d9b29e8c5f17 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1945,7 +1945,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val resultTpt = untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) - val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) + val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) typed(desugared, pt) else val msg = @@ -1953,7 +1953,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer |Expected type should be a polymorphic function with the same number of type and value parameters.""" errorTree(EmptyTree, msg, tree.srcPos) case _ => - val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) + val desugared = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) typed(desugared, pt) end typedPolyFunctionValue @@ -3581,17 +3581,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case xtree => typedUnnamed(xtree) val unsimplifiedType = result.tpe - val result1 = simplify(result, pt, locked) - result1.tpe.stripTypeVar match + simplify(result, pt, locked) + result.tpe.stripTypeVar match case e: ErrorType if !unsimplifiedType.isErroneous => errorTree(xtree, e.msg, xtree.srcPos) - case _ => result1 + case _ => result catch case ex: TypeError => handleTypeError(ex) } } /** Interpolate and simplify the type of the given tree. */ - protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = + protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type = if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied || tree.isDef // ... unless tree is a definition diff --git a/tests/run/contextbounds-for-poly-functions.check b/tests/run/contextbounds-for-poly-functions.check new file mode 100644 index 000000000000..2e7f62a3914f --- /dev/null +++ b/tests/run/contextbounds-for-poly-functions.check @@ -0,0 +1,6 @@ +42 +a string +Kate is 27 years old +42 and a string +a string and Kate is 27 years old +Kate is 27 years old and 42 diff --git a/tests/run/contextbounds-for-poly-functions.scala b/tests/run/contextbounds-for-poly-functions.scala new file mode 100644 index 000000000000..dcc974fce198 --- /dev/null +++ b/tests/run/contextbounds-for-poly-functions.scala @@ -0,0 +1,30 @@ +import scala.language.experimental.modularity +import scala.language.future + +trait Show[X]: + def show(x: X): String + +given Show[Int] with + def show(x: Int) = x.toString + +given Show[String] with + def show(x: String) = x + +case class Person(name: String, age: Int) + +given Show[Person] with + def show(x: Person) = s"${x.name} is ${x.age} years old" + +type Shower = [X: Show] => X => String +val shower: Shower = [X: {Show as show}] => (x: X) => show.show(x) + +type DoubleShower = [X: Show] => X => [Y: Show] => Y => String +val doubleShower: DoubleShower = [X: {Show as show1}] => (x: X) => [Y: {Show as show2}] => (y: Y) => s"${show1.show(x)} and ${show2.show(y)}" + +object Test extends App: + println(shower(42)) + println(shower("a string")) + println(shower(Person("Kate", 27))) + println(doubleShower(42)("a string")) + println(doubleShower("a string")(Person("Kate", 27))) + println(doubleShower(Person("Kate", 27))(42)) From 952eff71f4b2f49ca1870b4ef568cd8da97ee1e0 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 18 Nov 2024 16:49:08 +0100 Subject: [PATCH 19/19] Cleanup context bounds for poly functions implementation after review --- .../src/dotty/tools/dotc/ast/Desugar.scala | 50 +++++++++++++------ 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 6e54dee51c89..56c153498f87 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -350,11 +350,13 @@ object desugar { val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil) val params = evidenceParamBuf.toList - if params.isEmpty then return meth - val boundNames = getBoundNames(params, newParamss) - val recur = fitEvidenceParams(params, nme.apply, boundNames) - val (paramsFst, paramsSnd) = recur(newParamss) - functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs) + if params.isEmpty then + meth + else + val boundNames = getBoundNames(params, newParamss) + val recur = fitEvidenceParams(params, nme.apply, boundNames) + val (paramsFst, paramsSnd) = recur(newParamss) + functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs) end elimContextBounds def addDefaultGetters(meth: DefDef)(using Context): Tree = @@ -487,8 +489,27 @@ object desugar { case Ident(name: TermName) => names.contains(name) case _ => false - /** Fit evidence `params` into the `mparamss` parameter lists */ - private def fitEvidenceParams(params: List[ValDef], methName: Name, boundNames: Set[TermName])(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match + /** Fit evidence `params` into the `mparamss` parameter lists, making sure + * that all parameters referencing `params` are after them. + * - for methods the final parameter lists are := result._1 ++ result._2 + * - for poly functions, each element of the pair contains at most one term + * parameter list + * + * @param params the evidence parameters list that should fit into `mparamss` + * @param methName the name of the method that `mparamss` belongs to + * @param boundNames the names of the evidence parameters + * @param mparamss the original parameter lists of the method + * @return a pair of parameter lists containing all parameter lists in a + * reference-correct order; make sure that `params` is always at the + * intersection of the pair elements; this is relevant, for poly functions + * where `mparamss` is guaranteed to have exectly one term parameter list, + * then each pair element will have at most one term parameter list + */ + private def fitEvidenceParams( + params: List[ValDef], + methName: Name, + boundNames: Set[TermName] + )(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) => (params :: Nil) -> mparamss case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) => @@ -547,19 +568,20 @@ object desugar { val boundNames = getBoundNames(params, meth.paramss) - val recur = fitEvidenceParams(params, meth.name, boundNames) + val fitParams = fitEvidenceParams(params, meth.name, boundNames) - if meth.hasAttachment(PolyFunctionApply) then - meth.removeAttachment(PolyFunctionApply) - // for PolyFunctions we are limited to a single term param list, so we reuse the recur logic to compute the new parameter lists - // and then we add the other parameter lists as function types to the return type - val (paramsFst, paramsSnd) = recur(meth.paramss) + if meth.removeAttachment(PolyFunctionApply).isDefined then + // for PolyFunctions we are limited to a single term param list, so we + // reuse the fitEvidenceParams logic to compute the new parameter lists + // and then we add the other parameter lists as function types to the + // return type + val (paramsFst, paramsSnd) = fitParams(meth.paramss) if ctx.mode.is(Mode.Type) then cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt)) else cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs)) else - val (paramsFst, paramsSnd) = recur(meth.paramss) + val (paramsFst, paramsSnd) = fitParams(meth.paramss) cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd) end addEvidenceParams