Skip to content

Commit ef0e3ac

Browse files
committed
Require array element types to be sealed
1 parent 1065bd1 commit ef0e3ac

File tree

13 files changed

+109
-26
lines changed

13 files changed

+109
-26
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ extension (tp: Type)
206206
case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot)
207207
case _ => false
208208

209+
def isSealed(using Context): Boolean = tp match
210+
case tp: TypeParamRef => tp.underlying.isSealed
211+
case tp: TypeBounds => tp.hi.hasAnnotation(defn.Caps_SealedAnnot)
212+
case tp: TypeRef => tp.symbol.is(Sealed) || tp.info.isSealed // TODO: drop symbol flag?
213+
case _ => false
214+
209215
/** Drop @retains annotations everywhere */
210216
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
211217
val tm = new TypeMap:

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,7 @@ object CaptureSet:
872872
upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
873873
if variance > 0 || isExact then upper
874874
else if variance < 0 then CaptureSet.empty
875+
else if ctx.mode.is(Mode.Printing) then upper
875876
else assert(false, i"trying to add $upper from $r via ${tm.getClass} in a non-variant setting")
876877

877878
/** Apply `f` to each element in `xs`, and join result sets with `++` */

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ object CheckCaptures:
148148
val check = new TypeTraverser:
149149

150150
extension (tparam: Symbol) def isParametricIn(carrier: Symbol): Boolean =
151-
val encl = carrier.owner.enclosingMethodOrClass
151+
val encl = carrier.maybeOwner.enclosingMethodOrClass
152152
if encl.isClass then tparam.isParametricIn(encl)
153153
else
154154
def recur(encl: Symbol): Boolean =
@@ -160,11 +160,9 @@ object CheckCaptures:
160160
def traverse(t: Type) =
161161
t.dealiasKeepAnnots match
162162
case t: TypeRef =>
163-
capt.println(i"disallow $t, $tp, $what, ${t.symbol.is(Sealed)}")
163+
capt.println(i"disallow $t, $tp, $what, ${t.isSealed}")
164164
t.info match
165-
case TypeBounds(_, hi)
166-
if !t.symbol.is(Sealed) && !hi.hasAnnotation(defn.Caps_SealedAnnot)
167-
&& !t.symbol.isParametricIn(carrier) =>
165+
case TypeBounds(_, hi) if !t.isSealed && !t.symbol.isParametricIn(carrier) =>
168166
if hi.isAny then
169167
report.error(
170168
em"""$what cannot $have $tp since
@@ -543,8 +541,8 @@ class CheckCaptures extends Recheck, SymTransformer:
543541
val TypeApply(fn, args) = tree
544542
val polyType = atPhase(thisPhase.prev):
545543
fn.tpe.widen.asInstanceOf[TypeLambda]
546-
for case (arg: TypeTree, pinfo, pname) <- args.lazyZip(polyType.paramInfos).lazyZip((polyType.paramNames)) do
547-
if pinfo.bounds.hi.hasAnnotation(defn.Caps_SealedAnnot) then
544+
for case (arg: TypeTree, formal, pname) <- args.lazyZip(polyType.paramRefs).lazyZip((polyType.paramNames)) do
545+
if formal.isSealed then
548546
def where = if fn.symbol.exists then i" in an argument of ${fn.symbol}" else ""
549547
disallowRootCapabilitiesIn(arg.knownType, fn.symbol,
550548
i"Sealed type variable $pname", "be instantiated to",
@@ -1315,6 +1313,23 @@ class CheckCaptures extends Recheck, SymTransformer:
13151313
traverseChildren(tp)
13161314
check.traverse(info)
13171315

1316+
def checkArraysAreSealedIn(tp: Type, pos: SrcPos)(using Context): Unit =
1317+
val check = new TypeTraverser:
1318+
def traverse(t: Type): Unit =
1319+
t match
1320+
case AppliedType(tycon, arg :: Nil) if tycon.typeSymbol == defn.ArrayClass =>
1321+
if !(pos.span.isSynthetic && ctx.reporter.errorsReported) then
1322+
CheckCaptures.disallowRootCapabilitiesIn(arg, NoSymbol,
1323+
"Array", "have element type",
1324+
"Since arrays are mutable, they have to be treated like variables,\nso their element type must be sealed.",
1325+
pos)
1326+
traverseChildren(t)
1327+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
1328+
traverse(rinfo)
1329+
case _ =>
1330+
traverseChildren(t)
1331+
check.traverse(tp)
1332+
13181333
/** Perform the following kinds of checks
13191334
* - Check all explicitly written capturing types for well-formedness using `checkWellFormedPost`.
13201335
* - Check that arguments of TypeApplys and AppliedTypes conform to their bounds.
@@ -1340,6 +1355,8 @@ class CheckCaptures extends Recheck, SymTransformer:
13401355
case _ =>
13411356
case _: ValOrDefDef | _: TypeDef =>
13421357
checkNoLocalRootIn(tree.symbol, tree.symbol.info, tree.symbol.srcPos)
1358+
case tree: TypeTree =>
1359+
checkArraysAreSealedIn(tree.tpe, tree.srcPos)
13431360
case _ =>
13441361
end check
13451362
end checker

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,9 @@ abstract class Recheck extends Phase, SymTransformer:
596596

597597
/** Show tree with rechecked types instead of the types stored in the `.tpe` field */
598598
override def show(tree: untpd.Tree)(using Context): String =
599-
atPhase(thisPhase) {
600-
super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree]))
601-
}
599+
atPhase(thisPhase):
600+
withMode(Mode.Printing):
601+
super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree]))
602602
end Recheck
603603

604604
/** A class that can be used to test basic rechecking without any customaization */
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
-- Error: tests/neg-custom-args/captures/buffers.scala:11:6 ------------------------------------------------------------
2+
11 | var elems: Array[A] = new Array[A](10) // error // error
3+
| ^
4+
| mutable variable elems cannot have type Array[A] since
5+
| that type refers to the type variable A, which is not sealed.
6+
-- Error: tests/neg-custom-args/captures/buffers.scala:16:38 -----------------------------------------------------------
7+
16 | def make[A: ClassTag](xs: A*) = new ArrayBuffer: // error
8+
| ^^^^^^^^^^^
9+
| Sealed type variable A cannot be instantiated to box A^? since
10+
| that type refers to the type variable A, which is not sealed.
11+
| This is often caused by a local capability in an argument of constructor ArrayBuffer
12+
| leaking as part of its result.
13+
-- Error: tests/neg-custom-args/captures/buffers.scala:11:13 -----------------------------------------------------------
14+
11 | var elems: Array[A] = new Array[A](10) // error // error
15+
| ^^^^^^^^
16+
| Array cannot have element type A since
17+
| that type refers to the type variable A, which is not sealed.
18+
| Since arrays are mutable, they have to be treated like variables,
19+
| so their element type must be sealed.
20+
-- Error: tests/neg-custom-args/captures/buffers.scala:22:9 ------------------------------------------------------------
21+
22 | val x: Array[A] = new Array[A](10) // error
22+
| ^^^^^^^^
23+
| Array cannot have element type A since
24+
| that type refers to the type variable A, which is not sealed.
25+
| Since arrays are mutable, they have to be treated like variables,
26+
| so their element type must be sealed.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import reflect.ClassTag
2+
3+
class Buffer[A]
4+
5+
class ArrayBuffer[sealed A: ClassTag] extends Buffer[A]:
6+
var elems: Array[A] = new Array[A](10)
7+
def add(x: A): this.type = ???
8+
def at(i: Int): A = ???
9+
10+
class ArrayBufferBAD[A: ClassTag] extends Buffer[A]:
11+
var elems: Array[A] = new Array[A](10) // error // error
12+
def add(x: A): this.type = ???
13+
def at(i: Int): A = ???
14+
15+
object ArrayBuffer:
16+
def make[A: ClassTag](xs: A*) = new ArrayBuffer: // error
17+
elems = xs.toArray
18+
def apply[sealed A: ClassTag](xs: A*) = new ArrayBuffer:
19+
elems = xs.toArray // ok
20+
21+
class EncapsArray[A: ClassTag]:
22+
val x: Array[A] = new Array[A](10) // error
23+
24+
25+
26+
27+
28+
29+
30+

tests/pos-special/stdlib/collection/IterableOnce.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ final class IterableOnceExtensionMethods[A](private val it: IterableOnce[A]) ext
165165
def toBuffer[sealed B >: A]: mutable.Buffer[B] = mutable.ArrayBuffer.from(it)
166166

167167
@deprecated("Use .iterator.toArray", "2.13.0")
168-
def toArray[B >: A: ClassTag]: Array[B] = it match {
168+
def toArray[sealed B >: A: ClassTag]: Array[B] = it match {
169169
case it: Iterable[B] => it.toArray[B]
170170
case _ => it.iterator.toArray[B]
171171
}
@@ -272,10 +272,11 @@ object IterableOnce {
272272
math.max(math.min(math.min(len, srcLen), destLen - start), 0)
273273

274274
/** Calls `copyToArray` on the given collection, regardless of whether or not it is an `Iterable`. */
275-
@inline private[collection] def copyElemsToArray[A, B >: A](elems: IterableOnce[A]^,
276-
xs: Array[B],
277-
start: Int = 0,
278-
len: Int = Int.MaxValue): Int =
275+
@inline private[collection] def copyElemsToArray[A, sealed B >: A](
276+
elems: IterableOnce[A]^,
277+
xs: Array[B],
278+
start: Int = 0,
279+
len: Int = Int.MaxValue): Int =
279280
elems match {
280281
case src: Iterable[A] => src.copyToArray[B](xs, start, len)
281282
case src => src.iterator.copyToArray[B](xs, start, len)
@@ -889,7 +890,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
889890
* @note Reuse: $consumesIterator
890891
*/
891892
@deprecatedOverriding("This should always forward to the 3-arg version of this method", since = "2.13.4")
892-
def copyToArray[B >: A](xs: Array[B]): Int = copyToArray(xs, 0, Int.MaxValue)
893+
def copyToArray[sealed B >: A](xs: Array[B]): Int = copyToArray(xs, 0, Int.MaxValue)
893894

894895
/** Copy elements to an array, returning the number of elements written.
895896
*
@@ -906,7 +907,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
906907
* @note Reuse: $consumesIterator
907908
*/
908909
@deprecatedOverriding("This should always forward to the 3-arg version of this method", since = "2.13.4")
909-
def copyToArray[B >: A](xs: Array[B], start: Int): Int = copyToArray(xs, start, Int.MaxValue)
910+
def copyToArray[sealed B >: A](xs: Array[B], start: Int): Int = copyToArray(xs, start, Int.MaxValue)
910911

911912
/** Copy elements to an array, returning the number of elements written.
912913
*
@@ -923,7 +924,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
923924
*
924925
* @note Reuse: $consumesIterator
925926
*/
926-
def copyToArray[B >: A](xs: Array[B], start: Int, len: Int): Int = {
927+
def copyToArray[sealed B >: A](xs: Array[B], start: Int, len: Int): Int = {
927928
val it = iterator
928929
var i = start
929930
val end = start + math.min(len, xs.length - start)
@@ -1318,7 +1319,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
13181319
*
13191320
* Implementation note: DO NOT call [[Array.from]] from this method.
13201321
*/
1321-
def toArray[B >: A: ClassTag]: Array[B] =
1322+
def toArray[sealed B >: A: ClassTag]: Array[B] =
13221323
if (knownSize >= 0) {
13231324
val destination = new Array[B](knownSize)
13241325
copyToArray(destination, 0)

tests/pos-special/stdlib/collection/Iterator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
259259
}
260260
// segment must have data, and must be complete unless they allow partial
261261
val ok = index > 0 && (partial || index == size)
262-
if (ok) buffer = builder.result().asInstanceOf[Array[B]]
262+
if (ok) buffer = builder.result().asInstanceOf[Array[B @uncheckedCaptures]]
263263
else prev = null
264264
ok
265265
}

tests/pos-special/stdlib/collection/SeqView.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package collection
1616
import scala.annotation.nowarn
1717
import language.experimental.captureChecking
1818
import caps.unsafe.unsafeAssumePure
19+
import scala.annotation.unchecked.uncheckedCaptures
1920

2021
/** !!! Scala 2 difference: Need intermediate trait SeqViewOps to collect the
2122
* necessary functionality over which SeqViews are defined, and at the same
@@ -195,7 +196,7 @@ object SeqView {
195196
// contains items of another type, we'd get a CCE anyway)
196197
// - the cast doesn't actually do anything in the runtime because the
197198
// type of A is not known and Array[_] is Array[AnyRef]
198-
immutable.ArraySeq.unsafeWrapArray(arr.asInstanceOf[Array[A]])
199+
immutable.ArraySeq.unsafeWrapArray(arr.asInstanceOf[Array[A @uncheckedCaptures]])
199200
}
200201
}
201202
evaluated = true

tests/pos-special/stdlib/collection/StringOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ final class StringOps(private val s: String) extends AnyVal {
964964
else if (s.equalsIgnoreCase("false")) false
965965
else throw new IllegalArgumentException("For input string: \""+s+"\"")
966966

967-
def toArray[B >: Char](implicit tag: ClassTag[B]): Array[B] =
967+
def toArray[sealed B >: Char](implicit tag: ClassTag[B]): Array[B] =
968968
if (tag == ClassTag.Char) s.toCharArray.asInstanceOf[Array[B]]
969969
else new WrappedString(s).toArray[B]
970970

0 commit comments

Comments
 (0)