Skip to content

Commit

Permalink
Improve type inference for functions like fold
Browse files Browse the repository at this point in the history
When calling a fold with an accumulator like `Nil` or `List()` one used to have add
an explicit type ascription. This is now no longer necessary. When instantiating
type variables that occur invariantly in the expected type of a lambda, we now replace
covariant occurrences of `Nothing` in the (possibly widened) type of the accumulator
with fresh type variables.

The idea is that a fresh type variable in such places is always better than Nothing. For
module values such as `Nil` we widen to `List[<fresh var>]`. This does possibly cause a new
type error if the fold really wanted a `Nil` instance. But that case seems very rare,
so it looks like a good bet in general to do the widening.
  • Loading branch information
odersky committed Oct 28, 2023
1 parent 38559d7 commit bbf5579
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 48 deletions.
12 changes: 3 additions & 9 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Flags._
import config.Config
import config.Printers.typr
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import transform.TypeUtils.isTransparent
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet
Expand Down Expand Up @@ -566,13 +567,6 @@ trait ConstraintHandling {
inst
end approximation

private def isTransparent(tp: Type, traitOnly: Boolean)(using Context): Boolean = tp match
case AndType(tp1, tp2) =>
isTransparent(tp1, traitOnly) && isTransparent(tp2, traitOnly)
case _ =>
val cls = tp.underlyingClassRef(refinementOK = false).typeSymbol
cls.isTransparentClass && (!traitOnly || cls.is(Trait))

/** If `tp` is an intersection such that some operands are transparent trait instances
* and others are not, replace as many transparent trait instances as possible with Any
* as long as the result is still a subtype of `bound`. But fall back to the
Expand All @@ -585,7 +579,7 @@ trait ConstraintHandling {
var dropped: List[Type] = List() // the types dropped so far, last one on top

def dropOneTransparentTrait(tp: Type): Type =
if isTransparent(tp, traitOnly = true) && !kept.contains(tp) then
if tp.isTransparent(traitOnly = true) && !kept.contains(tp) then
dropped = tp :: dropped
defn.AnyType
else tp match
Expand Down Expand Up @@ -658,7 +652,7 @@ trait ConstraintHandling {
def widenOr(tp: Type) =
if widenUnions then
val tpw = tp.widenUnion
if (tpw ne tp) && !isTransparent(tpw, traitOnly = false) && (tpw <:< bound) then tpw else tp
if (tpw ne tp) && !tpw.isTransparent() && (tpw <:< bound) then tpw else tp
else tp.hardenUnions

def widenSingle(tp: Type) =
Expand Down
25 changes: 12 additions & 13 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4895,19 +4895,22 @@ object Types {
/** Instantiate variable with given type */
def instantiateWith(tp: Type)(using Context): Type = {
assert(tp ne this, i"self instantiation of $origin, constraint = ${ctx.typerState.constraint}")
assert(!myInst.exists, i"$origin is already instantiated to $myInst but we attempted to instantiate it to $tp")
typr.println(i"instantiating $this with $tp")
if !myInst.exists then
typr.println(i"instantiating $this with $tp")

if Config.checkConstraintsSatisfiable then
assert(currentEntry.bounds.contains(tp),
i"$origin is constrained to be $currentEntry but attempted to instantiate it to $tp")
if Config.checkConstraintsSatisfiable then
assert(currentEntry.bounds.contains(tp),
i"$origin is constrained to be $currentEntry but attempted to instantiate it to $tp")

if ((ctx.typerState eq owningState.nn.get.uncheckedNN) && !TypeComparer.subtypeCheckInProgress)
setInst(tp)
ctx.typerState.constraint = ctx.typerState.constraint.replace(origin, tp)
if ((ctx.typerState eq owningState.nn.get.uncheckedNN) && !TypeComparer.subtypeCheckInProgress)
setInst(tp)
ctx.typerState.constraint = ctx.typerState.constraint.replace(origin, tp)
tp
}

def typeToInstantiateWith(fromBelow: Boolean)(using Context): Type =
TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)

/** Instantiate variable from the constraints over its `origin`.
* If `fromBelow` is true, the variable is instantiated to the lub
* of its lower bounds in the current constraint; otherwise it is
Expand All @@ -4916,11 +4919,7 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(using Context): Type =
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
if myInst.exists then // The line above might have triggered instantiation of the current type variable
myInst
else
instantiateWith(tp)
instantiateWith(typeToInstantiateWith(fromBelow))

/** Widen unions when instantiating this variable in the current context? */
def widenUnions(using Context): Boolean = !ctx.typerState.constraint.isHard(this)
Expand Down
18 changes: 12 additions & 6 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@ package dotty.tools
package dotc
package transform

import core._
import core.*
import TypeErasure.ErasedValueType
import Types._
import Contexts._
import Symbols._
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
import Names.Name

import dotty.tools.dotc.core.Decorators.*

object TypeUtils {
/** A decorator that provides methods on types
* that are needed in the transformer pipeline.
Expand Down Expand Up @@ -98,5 +94,15 @@ object TypeUtils {
def takesImplicitParams(using Context): Boolean = self.stripPoly match
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
case _ => false

/** Is this a type deriving only from transparent classes?
* @param traitOnly if true, all class symbols must be transparent traits
*/
def isTransparent(traitOnly: Boolean = false)(using Context): Boolean = self match
case AndType(tp1, tp2) =>
tp1.isTransparent(traitOnly) && tp2.isTransparent(traitOnly)
case _ =>
val cls = self.underlyingClassRef(refinementOK = false).typeSymbol
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
}
}
101 changes: 82 additions & 19 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import ProtoTypes._
import NameKinds.UniqueName
import util.Spans._
import util.{Stats, SimpleIdentityMap, SimpleIdentitySet, SrcPos}
import transform.TypeUtils.isTransparent
import Decorators._
import config.Printers.{gadts, typr}
import annotation.tailrec
Expand Down Expand Up @@ -60,7 +61,9 @@ object Inferencing {
def instantiateSelected(tp: Type, tvars: List[Type])(using Context): Unit =
if (tvars.nonEmpty)
IsFullyDefinedAccumulator(
ForceDegree.Value(tvars.contains, IfBottom.flip), minimizeSelected = true
new ForceDegree.Value(IfBottom.flip):
override def appliesTo(tvar: TypeVar) = tvars.contains(tvar),
minimizeSelected = true
).process(tp)

/** Instantiate any type variables in `tp` whose bounds contain a reference to
Expand Down Expand Up @@ -154,15 +157,58 @@ object Inferencing {
* their lower bound. Record whether successful.
* 2nd Phase: If first phase was successful, instantiate all remaining type variables
* to their upper bound.
*
* Instance types can be improved by replacing covariant occurrences of Nothing
* with fresh type variables, if `force` allows this in its `canImprove` implementation.
*/
private class IsFullyDefinedAccumulator(force: ForceDegree.Value, minimizeSelected: Boolean = false)
(using Context) extends TypeAccumulator[Boolean] {

private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
/** Replace toplevel-covariant occurrences (i.e. covariant without double flips)
* of Nothing by fresh type variables.
* For singleton types and references to module classes: try to
* improve the widened type. For module classes, the widened type
* is the intersection of all its non-transparent parent types.
*/
private def improve(tvar: TypeVar) = new TypeMap:
def apply(t: Type) = trace(i"improve $t", show = true):
def tryWidened(widened: Type): Type =
val improved = apply(widened)
if improved ne widened then improved else mapOver(t)
if variance > 0 then
t match
case t: TypeRef =>
if t.symbol == defn.NothingClass then
newTypeVar(TypeBounds.empty, nestingLevel = tvar.nestingLevel)
else if t.symbol.is(ModuleClass) then
tryWidened(t.parents.filter(!_.isTransparent())
.foldLeft(defn.AnyType: Type)(TypeComparer.andType(_, _)))
else
mapOver(t)
case t: TermRef =>
tryWidened(t.widen)
case _ =>
mapOver(t)
else t

/** Instantiate type variable with possibly improved computed instance type.
* @return true if variable was instantiated with improved type, which
* in this case should not be instantiated further, false otherwise.
*/
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Boolean =
if fromBelow && force.canImprove(tvar) then
val inst = tvar.typeToInstantiateWith(fromBelow = true)
if apply(true, inst) then
// need to recursively check before improving, since improving adds type vars
// which should not be instantiated at this point
val better = improve(tvar)(inst)
if better <:< TypeComparer.fullUpperBound(tvar.origin) then
typr.println(i"forced instantiation of invariant ${tvar.origin} = $inst, improved to $better")
tvar.instantiateWith(better)
return true
val inst = tvar.instantiate(fromBelow)
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
inst
}
false

private var toMaximize: List[TypeVar] = Nil

Expand All @@ -178,26 +224,27 @@ object Inferencing {
&& ctx.typerState.constraint.contains(tvar)
&& {
var fail = false
var skip = false
val direction = instDirection(tvar.origin)
if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if direction >= 0 && tvar.hasUpperBound then
instantiate(tvar, fromBelow = false)
skip = instantiate(tvar, fromBelow = false)
// else hold off instantiating unbounded unconstrained variable
else if direction != 0 then
instantiate(tvar, fromBelow = direction < 0)
skip = instantiate(tvar, fromBelow = direction < 0)
else if variance >= 0 && tvar.hasLowerBound then
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
&& force.ifBottom == IfBottom.ok
then // if variance == 0, prefer upper bound if one is given
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
fail = true
else
toMaximize = tvar :: toMaximize
!fail && foldOver(x, tvar)
!fail && (skip || foldOver(x, tvar))
}
case tp => foldOver(x, tp)
}
Expand Down Expand Up @@ -467,7 +514,7 @@ object Inferencing {
*
* we want to instantiate U to x.type right away. No need to wait further.
*/
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
Stats.record("variances")
val constraint = ctx.typerState.constraint

Expand Down Expand Up @@ -769,14 +816,30 @@ trait Inferencing { this: Typer =>
}

/** An enumeration controlling the degree of forcing in "is-fully-defined" checks. */
@sharable object ForceDegree {
class Value(val appliesTo: TypeVar => Boolean, val ifBottom: IfBottom):
override def toString = s"ForceDegree.Value(.., $ifBottom)"
val none: Value = new Value(_ => false, IfBottom.ok) { override def toString = "ForceDegree.none" }
val all: Value = new Value(_ => true, IfBottom.ok) { override def toString = "ForceDegree.all" }
val failBottom: Value = new Value(_ => true, IfBottom.fail) { override def toString = "ForceDegree.failBottom" }
val flipBottom: Value = new Value(_ => true, IfBottom.flip) { override def toString = "ForceDegree.flipBottom" }
}
@sharable object ForceDegree:
class Value(val ifBottom: IfBottom):

/** Does `tv` need to be instantiated? */
def appliesTo(tv: TypeVar): Boolean = true

/** Should we try to improve the computed instance type by replacing bottom types
* with fresh type variables?
*/
def canImprove(tv: TypeVar): Boolean = false

override def toString = s"ForceDegree.Value($ifBottom)"
end Value

val none: Value = new Value(IfBottom.ok):
override def appliesTo(tv: TypeVar) = false
override def toString = "ForceDegree.none"
val all: Value = new Value(IfBottom.ok):
override def toString = "ForceDegree.all"
val failBottom: Value = new Value(IfBottom.fail):
override def toString = "ForceDegree.failBottom"
val flipBottom: Value = new Value(IfBottom.flip):
override def toString = "ForceDegree.flipBottom"
end ForceDegree

enum IfBottom:
case ok, fail, flip
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1622,14 +1622,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case _ =>

if desugared.isEmpty then
val forceDegree =
if pt.isValueType then
// Allow variables that appear invariantly in `pt` to be improved by mapping
// bottom types in their instance types to fresh type variables
new ForceDegree.Value(IfBottom.fail):
val tvmap = variances(pt)
override def canImprove(tvar: TypeVar) =
tvmap.computedVariance(tvar) == (0: Integer)
else
ForceDegree.failBottom

val inferredParams: List[untpd.ValDef] =
for ((param, i) <- params.zipWithIndex) yield
if (!param.tpt.isEmpty) param
else
val (formalBounds, isErased) = protoFormal(i)
val formal = formalBounds.loBound
val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
val knownFormal = isFullyDefined(formal, forceDegree)
// If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound,
// try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala)
val paramType =
Expand Down
34 changes: 34 additions & 0 deletions tests/pos/folds.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

object Test:
extension [A](xs: List[A])
def foldl[B](acc: B)(f: (A, B) => B): B = ???

val xs = List(1, 2, 3)

val _ = xs.foldl(List())((y, ys) => y :: ys)

val _ = xs.foldl(Nil)((y, ys) => y :: ys)

def partition[a](xs: List[a], pred: a => Boolean): Tuple2[List[a], List[a]] = {
xs.foldRight/*[Tuple2[List[a], List[a]]]*/((List(), List())) {
(x, p) => if (pred (x)) (x :: p._1, p._2) else (p._1, x :: p._2)
}
}

def snoc[A](xs: List[A], x: A) = x :: xs

def reverse[A](xs: List[A]) =
xs.foldLeft(Nil)(snoc)

def reverse2[A](xs: List[A]) =
xs.foldLeft(List())(snoc)

val ys: Seq[Int] = xs
ys.foldLeft(Seq())((ys, y) => y +: ys)
ys.foldLeft(Nil)((ys, y) => y +: ys)

def dup[A](xs: List[A]) =
xs.foldRight(Nil)((x, xs) => x :: x :: xs)

def toSet[A](xs: Seq[A]) =
xs.foldLeft(Set.empty)(_ + _)

0 comments on commit bbf5579

Please sign in to comment.