diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index ac6bf0252e47..a49bd9f79351 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -16,7 +16,7 @@ import typer.ErrorReporting.{Addenda, err} import typer.ProtoTypes.{AnySelectionProto, LhsProto} import util.{SimpleIdentitySet, EqHashMap, EqHashSet, SrcPos, Property} import transform.SymUtils.* -import transform.{Recheck, PreRecheck} +import transform.{Recheck, PreRecheck, CapturedVars} import Recheck.* import scala.collection.mutable import CaptureSet.{withCaptureSetsExplained, IdempotentCaptRefMap, CompareResult} @@ -149,15 +149,25 @@ object CheckCaptures: private val seen = new EqHashSet[TypeRef] + /** Check that there is at least one method containing carrier and defined + * in the scope of tparam. E.g. this is OK: + * def f[T] = { ... var x: T ... } + * So is this: + * class C[T] { def f() = { class D { var x: T }}} + * But this is not OK: + * class C[T] { object o { var x: T }} + */ extension (tparam: Symbol) def isParametricIn(carrier: Symbol): Boolean = - val encl = carrier.maybeOwner.enclosingMethodOrClass - if encl.isClass then tparam.isParametricIn(encl) - else - def recur(encl: Symbol): Boolean = - if tparam.owner == encl then true - else if encl.isStatic || !encl.exists then false - else recur(encl.owner.enclosingMethodOrClass) - recur(encl) + carrier.exists && { + val encl = carrier.owner.enclosingMethodOrClass + if encl.isClass then tparam.isParametricIn(encl) + else + def recur(encl: Symbol): Boolean = + if tparam.owner == encl then true + else if encl.isStatic || !encl.exists then false + else recur(encl.owner.enclosingMethodOrClass) + recur(encl) + } def traverse(t: Type) = t.dealiasKeepAnnots match @@ -168,9 +178,12 @@ object CheckCaptures: t.info match case TypeBounds(_, hi) if !t.isSealed && !t.symbol.isParametricIn(carrier) => if hi.isAny then + val detailStr = + if t eq tp then "variable" + else i"refers to the type variable $t, which" report.error( em"""$what cannot $have $tp since - |that type refers to the type variable $t, which is not sealed. + |that type $detailStr is not sealed. |$addendum""", pos) else @@ -549,7 +562,7 @@ class CheckCaptures extends Recheck, SymTransformer: for case (arg: TypeTree, formal, pname) <- args.lazyZip(polyType.paramRefs).lazyZip((polyType.paramNames)) do if formal.isSealed then def where = if fn.symbol.exists then i" in an argument of ${fn.symbol}" else "" - disallowRootCapabilitiesIn(arg.knownType, fn.symbol, + disallowRootCapabilitiesIn(arg.knownType, NoSymbol, i"Sealed type variable $pname", "be instantiated to", i"This is often caused by a local capability$where\nleaking as part of its result.", tree.srcPos) @@ -590,13 +603,58 @@ class CheckCaptures extends Recheck, SymTransformer: openClosures = openClosures.tail end recheckClosureBlock + /** Maps mutable variables to the symbols that capture them (in the + * CheckCaptures sense, i.e. symbol is referred to from a different method + * than the one it is defined in). + */ + private val capturedBy = util.HashMap[Symbol, Symbol]() + + /** Maps anonymous functions appearing as function arguments to + * the function that is called. + */ + private val anonFunCallee = util.HashMap[Symbol, Symbol]() + + /** Populates `capturedBy` and `anonFunCallee`. Called by `checkUnit`. + */ + private def collectCapturedMutVars(using Context) = new TreeTraverser: + def traverse(tree: Tree)(using Context) = tree match + case id: Ident => + val sym = id.symbol + if sym.is(Mutable, butNot = Method) && sym.owner.isTerm then + val enclMeth = ctx.owner.enclosingMethod + if sym.enclosingMethod != enclMeth then + capturedBy(sym) = enclMeth + case Apply(fn, args) => + for case closureDef(mdef) <- args do + anonFunCallee(mdef.symbol) = fn.symbol + traverseChildren(tree) + case Inlined(_, bindings, expansion) => + traverse(bindings) + traverse(expansion) + case mdef: DefDef => + if !mdef.symbol.isInlineMethod then traverseChildren(tree) + case _ => + traverseChildren(tree) + override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type = try if sym.is(Module) then sym.info // Modules are checked by checking the module class else if sym.is(Mutable) && !sym.hasAnnotation(defn.UncheckedCapturesAnnot) then - disallowRootCapabilitiesIn(tree.tpt.knownType, sym, - i"mutable $sym", "have type", "", sym.srcPos) + val (carrier, addendum) = capturedBy.get(sym) match + case Some(encl) => + val enclStr = + if encl.isAnonymousFunction then + val location = anonFunCallee.get(encl) match + case Some(meth) if meth.exists => i" argument in a call to $meth" + case _ => "" + s"an anonymous function$location" + else encl.show + (NoSymbol, i"\nNote that $sym does not count as local since it is captured by $enclStr") + case _ => + (sym, "") + disallowRootCapabilitiesIn( + tree.tpt.knownType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos) checkInferredResult(super.recheckValDef(tree, sym), tree) finally if !sym.is(Param) then @@ -1170,11 +1228,12 @@ class CheckCaptures extends Recheck, SymTransformer: private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup] override def checkUnit(unit: CompilationUnit)(using Context): Unit = - setup.setupUnit(ctx.compilationUnit.tpdTree, completeDef) + setup.setupUnit(unit.tpdTree, completeDef) + collectCapturedMutVars.traverse(unit.tpdTree) if ctx.settings.YccPrintSetup.value then val echoHeader = "[[syntax tree at end of cc setup]]" - val treeString = show(ctx.compilationUnit.tpdTree) + val treeString = show(unit.tpdTree) report.echo(s"$echoHeader\n$treeString\n") withCaptureSetsExplained: diff --git a/tests/neg-custom-args/captures/buffers.check b/tests/neg-custom-args/captures/buffers.check index cdb7baa852fb..07acea3c48e3 100644 --- a/tests/neg-custom-args/captures/buffers.check +++ b/tests/neg-custom-args/captures/buffers.check @@ -1,7 +1,7 @@ -- Error: tests/neg-custom-args/captures/buffers.scala:11:6 ------------------------------------------------------------ 11 | var elems: Array[A] = new Array[A](10) // error // error | ^ - | mutable variable elems cannot have type Array[A] since + | Mutable variable elems cannot have type Array[A] since | that type refers to the type variable A, which is not sealed. -- Error: tests/neg-custom-args/captures/buffers.scala:16:38 ----------------------------------------------------------- 16 | def make[A: ClassTag](xs: A*) = new ArrayBuffer: // error @@ -14,13 +14,13 @@ 11 | var elems: Array[A] = new Array[A](10) // error // error | ^^^^^^^^ | Array cannot have element type A since - | that type refers to the type variable A, which is not sealed. + | that type variable is not sealed. | Since arrays are mutable, they have to be treated like variables, | so their element type must be sealed. -- Error: tests/neg-custom-args/captures/buffers.scala:22:9 ------------------------------------------------------------ 22 | val x: Array[A] = new Array[A](10) // error | ^^^^^^^^ | Array cannot have element type A since - | that type refers to the type variable A, which is not sealed. + | that type variable is not sealed. | Since arrays are mutable, they have to be treated like variables, | so their element type must be sealed. diff --git a/tests/neg-custom-args/captures/levels.check b/tests/neg-custom-args/captures/levels.check index f91f90fb652f..c0cc7f0a759c 100644 --- a/tests/neg-custom-args/captures/levels.check +++ b/tests/neg-custom-args/captures/levels.check @@ -1,8 +1,8 @@ -- Error: tests/neg-custom-args/captures/levels.scala:6:16 ------------------------------------------------------------- 6 | private var v: T = init // error | ^ - | mutable variable v cannot have type T since - | that type refers to the type variable T, which is not sealed. + | Mutable variable v cannot have type T since + | that type variable is not sealed. -- Error: tests/neg-custom-args/captures/levels.scala:17:13 ------------------------------------------------------------ 17 | val _ = Ref[String => String]((x: String) => x) // error | ^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/neg-custom-args/captures/sealed-leaks.check b/tests/neg-custom-args/captures/sealed-leaks.check new file mode 100644 index 000000000000..f7098eba32b6 --- /dev/null +++ b/tests/neg-custom-args/captures/sealed-leaks.check @@ -0,0 +1,50 @@ +-- [E129] Potential Issue Warning: tests/neg-custom-args/captures/sealed-leaks.scala:31:6 ------------------------------ +31 | () + | ^^ + | A pure expression does nothing in statement position + | + | longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:12:27 ------------------------------------------------------ +12 | val later2 = usingLogFile[(() => Unit) | Null] { f => () => f.write(0) } // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Sealed type variable T cannot be instantiated to (() => Unit) | Null since + | that type captures the root capability `cap`. + | This is often caused by a local capability in an argument of method usingLogFile + | leaking as part of its result. +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/sealed-leaks.scala:19:26 --------------------------------- +19 | usingLogFile { f => x = f } // error + | ^ + | Found: (f : java.io.FileOutputStream^) + | Required: (java.io.FileOutputStream | Null)^{cap[Test2]} + | + | longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:30:10 ------------------------------------------------------ +30 | var x: T = y // error + | ^ + | Mutable variable x cannot have type T since + | that type variable is not sealed. +-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:39:8 ------------------------------------------------------- +39 | var x: T = y // error + | ^ + | Mutable variable x cannot have type T since + | that type variable is not sealed. + | + | Note that variable x does not count as local since it is captured by an anonymous function +-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:43:8 ------------------------------------------------------- +43 | var x: T = y // error + | ^ + |Mutable variable x cannot have type T since + |that type variable is not sealed. + | + |Note that variable x does not count as local since it is captured by an anonymous function argument in a call to method identity +-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:47:8 ------------------------------------------------------- +47 | var x: T = y // error + | ^ + | Mutable variable x cannot have type T since + | that type variable is not sealed. + | + | Note that variable x does not count as local since it is captured by method foo +-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:11:14 ------------------------------------------------------ +11 | val later = usingLogFile { f => () => f.write(0) } // error + | ^^^^^^^^^^^^ + | local reference f leaks into outer capture set of type parameter T of method usingLogFile diff --git a/tests/neg-custom-args/captures/sealed-leaks.scala b/tests/neg-custom-args/captures/sealed-leaks.scala index a7acf77b5678..2555ba8a3e07 100644 --- a/tests/neg-custom-args/captures/sealed-leaks.scala +++ b/tests/neg-custom-args/captures/sealed-leaks.scala @@ -18,4 +18,34 @@ def Test2 = usingLogFile { f => x = f } // error - later() \ No newline at end of file + later() + +def Test3 = + def f[T](y: T) = + var x: T = y + () + + class C[T](y: T): + object o: + var x: T = y // error + () + + class C2[T](y: T): + def f = + var x: T = y // ok + () + + def g1[T](y: T): T => Unit = + var x: T = y // error + y => x = y + + def g2[T](y: T): T => Unit = + var x: T = y // error + identity(y => x = y) + + def g3[T](y: T): Unit = + var x: T = y // error + def foo = + x = y + () + diff --git a/tests/pos-special/stdlib/collection/Iterator.scala b/tests/pos-special/stdlib/collection/Iterator.scala index 993e2fc0cfea..90fd387069b0 100644 --- a/tests/pos-special/stdlib/collection/Iterator.scala +++ b/tests/pos-special/stdlib/collection/Iterator.scala @@ -868,7 +868,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite */ def duplicate: (Iterator[A]^{this}, Iterator[A]^{this}) = { val gap = new scala.collection.mutable.Queue[A @uncheckedCaptures] - var ahead: Iterator[A] = null + var ahead: Iterator[A @uncheckedCaptures] = null // ahead is captured by Partner, so A is not recognized as parametric class Partner extends AbstractIterator[A] { override def knownSize: Int = self.synchronized { val thisSize = self.knownSize diff --git a/tests/pos-special/stdlib/collection/immutable/LazyListIterable.scala b/tests/pos-special/stdlib/collection/immutable/LazyListIterable.scala index 8d804bad13de..5684130b6048 100644 --- a/tests/pos-special/stdlib/collection/immutable/LazyListIterable.scala +++ b/tests/pos-special/stdlib/collection/immutable/LazyListIterable.scala @@ -852,9 +852,9 @@ final class LazyListIterable[+A] private(private[this] var lazyState: () => Lazy else if (!isEmpty) { b.append(head) var cursor = this - @inline def appendCursorElement(): Unit = b.append(sep).append(cursor.head) + inline def appendCursorElement(): Unit = b.append(sep).append(cursor.head) var scout = tail - @inline def scoutNonEmpty: Boolean = scout.stateDefined && !scout.isEmpty + inline def scoutNonEmpty: Boolean = scout.stateDefined && !scout.isEmpty if ((cursor ne scout) && (!scout.stateDefined || (cursor.state ne scout.state))) { cursor = scout if (scoutNonEmpty) { @@ -998,7 +998,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { private def filterImpl[A](ll: LazyListIterable[A]^, p: A => Boolean, isFlipped: Boolean): LazyListIterable[A]^{ll, p} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD - var restRef = ll // val restRef = new ObjectRef(ll) + var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[filterImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric newLL { var elem: A = null.asInstanceOf[A] var found = false @@ -1015,7 +1015,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { private def collectImpl[A, B](ll: LazyListIterable[A]^, pf: PartialFunction[A, B]^): LazyListIterable[B]^{ll, pf} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD - var restRef = ll // val restRef = new ObjectRef(ll) + var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[collectImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric newLL { val marker = Statics.pfMarker val toMarker = anyToMarker.asInstanceOf[A => B] // safe because Function1 is erased @@ -1034,9 +1034,9 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD - var restRef = ll // val restRef = new ObjectRef(ll) + var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[flatMapImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric newLL { - var it: Iterator[B]^{ll, f} = null + var it: Iterator[B @uncheckedCaptures]^{ll, f} = null var itHasNext = false var rest = restRef // var rest = restRef.elem while (!itHasNext && !rest.isEmpty) { @@ -1058,7 +1058,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { private def dropImpl[A](ll: LazyListIterable[A]^, n: Int): LazyListIterable[A]^{ll} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD - var restRef = ll // val restRef = new ObjectRef(ll) + var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[dropImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric var iRef = n // val iRef = new IntRef(n) newLL { var rest = restRef // var rest = restRef.elem @@ -1075,7 +1075,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { private def dropWhileImpl[A](ll: LazyListIterable[A]^, p: A => Boolean): LazyListIterable[A]^{ll, p} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD - var restRef = ll // val restRef = new ObjectRef(ll) + var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[dropWhileImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric newLL { var rest = restRef // var rest = restRef.elem while (!rest.isEmpty && p(rest.head)) { @@ -1088,8 +1088,8 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { private def takeRightImpl[A](ll: LazyListIterable[A]^, n: Int): LazyListIterable[A]^{ll} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD - var restRef = ll // val restRef = new ObjectRef(ll) - var scoutRef = ll // val scoutRef = new ObjectRef(ll) + var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[takeRightImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric + var scoutRef: LazyListIterable[A @uncheckedCaptures]^{cap[takeRightImpl]} = ll // same situation var remainingRef = n // val remainingRef = new IntRef(n) newLL { var scout = scoutRef // var scout = scoutRef.elem diff --git a/tests/pos-special/stdlib/collection/immutable/TreeSeqMap.scala b/tests/pos-special/stdlib/collection/immutable/TreeSeqMap.scala index d7cceb54cca3..91233669e5ca 100644 --- a/tests/pos-special/stdlib/collection/immutable/TreeSeqMap.scala +++ b/tests/pos-special/stdlib/collection/immutable/TreeSeqMap.scala @@ -609,7 +609,7 @@ object TreeSeqMap extends MapFactory[TreeSeqMap] { } final def splitAt(n: Int): (Ordering[T], Ordering[T]) = { - var rear = Ordering.empty[T] + var rear: Ordering[T @uncheckedCaptures] = Ordering.empty[T] var i = n (modifyOrRemove { (o, v) => i -= 1 diff --git a/tests/pos-special/stdlib/collection/immutable/Vector.scala b/tests/pos-special/stdlib/collection/immutable/Vector.scala index 1bde30406fd9..d9d33add512d 100644 --- a/tests/pos-special/stdlib/collection/immutable/Vector.scala +++ b/tests/pos-special/stdlib/collection/immutable/Vector.scala @@ -229,7 +229,7 @@ sealed abstract class Vector[+A] private[immutable] (private[immutable] final va // k >= 0, k = suffix.knownSize val tinyAppendLimit = 4 + vectorSliceCount if (k < tinyAppendLimit) { - var v: Vector[B] = this + var v: Vector[B @uncheckedCaptures] = this suffix match { case it: Iterable[_] => it.asInstanceOf[Iterable[B]].foreach(x => v = v.appended(x)) case _ => suffix.iterator.foreach(x => v = v.appended(x))