Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check bounds in match type case bodies #17602

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1487,10 +1487,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
* Delegates to compareS if `tycon` is scala.compiletime.S. Otherwise, constant folds if possible.
*/
def compareCompiletimeAppliedType(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = {
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
else {
defn.isCompiletime_S(tp.tycon.typeSymbol) && compareS(tp, other, fromBelow)
|| {
val folded = tp.tryCompiletimeConstantFold
if (fromBelow) recur(other, folded) else recur(folded, other)
} || other.match {
case other: TypeRef if !fromBelow && other.symbol == defn.SingletonClass =>
tp.args.forall(arg => isSubType(arg, defn.SingletonType))
// Compile-time operations with singleton arguments are singletons
case _ => false
}
}

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/TypeEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ object TypeEval:
case tp: TypeProxy =>
val tp1 = tp.superType
if tp1.isStable then tp1.fixForEvaluation else tp
case AndType(tp1: ConstantType, tp2) if tp1.frozen_<:<(tp2) => tp1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AndType is supposed to be commutative, so if we do this we must also do the mirror:

Suggested change
case AndType(tp1: ConstantType, tp2) if tp1.frozen_<:<(tp2) => tp1
case AndType(tp1: ConstantType, tp2) if tp1.frozen_<:<(tp2) => tp1
case AndType(tp1, tp2: ConstantType) if tp2.frozen_<:<(tp1) => tp2

case tp => tp

def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match
Expand Down
13 changes: 10 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,19 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
// case x: (_: Tree[?])
case m @ MatchTypeTree(bounds, selector, cases) =>
// Analog to the case above for match types
def transformIgnoringBoundsCheck(x: CaseDef): CaseDef =
withMode(Mode.Pattern)(super.transform(x)).asInstanceOf[CaseDef]
def transformCase(x: CaseDef): CaseDef =
val gadtCtx = x.pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this enough to ensure that, after PostTyper, the trees are well-kinded without resorting to GADT reasoning? I'm not familiar with what are the consequences of doing setGadtState.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not inserting any casts, if that's what you're asking. And I'm hoping we can handle things like:

class Foo
class Bar
class Res[X <: Bar]
type MT[X <: Foo | Bar] = X match
  case Foo => Unit
  case Bar => Res[X] // X vs bounds check <: Bar 

in part also because I've found adding a type intersection (for types) doesn't fix errors like adding a type casting (for terms) fixes errors...

Copy link
Member Author

@dwijnand dwijnand Feb 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we can either:

  1. disallow such reasoning, force users to insert type intersections, and then decide how much to work on the limitations of those intersections
  2. allow these things to require GADT reasoning
  3. deal with how to force types to honour bounds, and perhaps in a way that is accessible to users too
  4. return to the AssumeInfo tree idea

😭

case None => ctx
inContext(gadtCtx)(cpy.CaseDef(tree)(
withMode(Mode.Pattern)(transform(x.pat)),
transform(x.guard),
transform(x.body),
))
cpy.MatchTypeTree(tree)(
super.transform(bounds),
super.transform(selector),
cases.mapConserve(transformIgnoringBoundsCheck)
cases.mapConserve(transformCase)
)
case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) =>
superAcc.withInvalidCurrentClass(super.transform(tree))
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,10 @@ class Namer { typer: Typer =>

override final def typeSig(sym: Symbol): Type =
val tparamSyms = completerTypeParams(sym)(using ictx)
given ctx: Context = nestedCtx.nn
given ctx: Context = if tparamSyms.isEmpty then nestedCtx.nn else
given ctx: Context = nestedCtx.nn.fresh.setFreshGADTBounds
ctx.gadtState.addToConstraint(tparamSyms)
ctx

def abstracted(tp: TypeBounds): TypeBounds =
HKTypeLambda.boundsFromParams(tparamSyms, tp)
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2005,13 +2005,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
if !ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef then
withMode(Mode.GadtConstraintInference) {
TypeComparer.constrainPatternType(pat1.tpe, selType)
selType match
case scr: TypeRef if ctx.gadt.contains(scr.symbol) => pat1.tpe match
case pat: TypeRef => scr <:< pat
case _ => TypeComparer.constrainPatternType(pat1.tpe, selType)
case _ => TypeComparer.constrainPatternType(pat1.tpe, selType)
}
val pat2 = indexPattern(cdef).transform(pat1)
var body1 = typedType(cdef.body, pt)
if !body1.isType then
assert(ctx.reporter.errorsReported)
body1 = TypeTree(errorType(em"<error: not a type>", cdef.srcPos))
else if ctx.gadt.isNarrowing then
pat2.putAttachment(InferredGadtConstraints, ctx.gadt)
assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1)
}
caseRest(using ctx.fresh.setFreshGADTBounds.setNewScope)
Expand Down
4 changes: 2 additions & 2 deletions library/src/scala/Tuple.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ object Tuple {
/** Converts a tuple `(T1, ..., Tn)` to `(F[T1], ..., F[Tn])` */
type Map[Tup <: Tuple, F[_ <: Union[Tup]]] <: Tuple = Tup match {
case EmptyTuple => EmptyTuple
case h *: t => F[h] *: Map[t, F]
case h *: t => F[h & Union[Tup]] *: Map[t, [x <: Union[t]] =>> F[x & Union[Tup]]]
}

/** Converts a tuple `(T1, ..., Tn)` to a flattened `(..F[T1], ..., ..F[Tn])` */
type FlatMap[Tup <: Tuple, F[_ <: Union[Tup]] <: Tuple] <: Tuple = Tup match {
case EmptyTuple => EmptyTuple
case h *: t => Concat[F[h], FlatMap[t, F]]
case h *: t => Concat[F[h & Union[Tup]], FlatMap[t, [x <: Union[t]] =>> F[x & Union[Tup]]]]
}

/** Filters out those members of the tuple for which the predicate `P` returns `false`.
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/i15272.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
case class Head[+NT, +From <: NT, +To <: NT] (from: From, to: To ) extends EdgeN[NT]
case class Cons[+NT, +From <: NT, +ToE <: EdgeN[NT]](from: From, to: ToE) extends EdgeN[NT]
final type InNodesTupleOf[NT, E <: EdgeN[NT]] <: Tuple = E match
case Cons[nt,from,toE] => from *: InNodesTupleOf[nt,toE]
case Cons[nt,from,toE] => from *: InNodesTupleOf[nt,toE] // error
case Head[nt,from ,to] => from *: EmptyTuple
def inNodesTuple[NT,E <: EdgeN[NT]](edge: E): InNodesTupleOf[NT,E] = edge match
case e: Cons[nt,from,toE] => e.from *: inNodesTuple[nt,toE](e.to) // error
case e: Head[nt,from,to] => e.from *: EmptyTuple
end EdgeN
end EdgeN
3 changes: 2 additions & 1 deletion tests/pos/10867.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
object Test {
// e.g inserts[z, (a, b)] =:= ((z, a, b), (a, z, b), (a, b, z))
type inserts[a, as <: Tuple] <: Tuple =
as match
case EmptyTuple => (a *: EmptyTuple) *: EmptyTuple
case y *: ys => (a *: y *: ys) *: Tuple.Map[inserts[a, ys], [t <: Tuple] =>> y *: t]
case y *: ys => (a *: y *: ys) *: Tuple.Map[inserts[a, ys], [t] =>> y *: (t & Tuple)]

type inserts2[a] =
[as <: Tuple] =>> inserts[a, as]
Expand Down
4 changes: 2 additions & 2 deletions tests/pos/13633.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object Sums extends App:
case false => A
case true => Inc[A]

type PlusLoop[A <: Tuple, B <: Tuple, O] <: Tuple = (A, B) match
type PlusLoop[A <: Tuple, B <: Tuple, O <: Boolean] <: Tuple = (A, B) match
case (EmptyTuple, EmptyTuple) =>
O match
case true => (true *: EmptyTuple)
Expand All @@ -47,4 +47,4 @@ object Sums extends App:
case (A, EmptyTuple) => IncT[A, O]
case (a *: as, b *: bs) =>
PlusTri[a, b, O] match
case (x, y) => y *: PlusLoop[as, bs, x]
case (x, y) => y *: PlusLoop[as, bs, x & Boolean]
6 changes: 3 additions & 3 deletions tests/pos/9239.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ object ABug:
type Zero = B0 :: Nil
type One = B1 :: Nil

type --[B <: Bin] =
type --[B <: Bin] <: Bin =
B match
case B1 :: d => B0 :: d
case B0 :: B1 :: Nil => B1 :: Nil
case B0 :: d => B1 :: --[d]

type ×[N <: Bin, M <: Bin] =
type ×[N <: Bin, M <: Bin] <: Bin =
(N, M) match
case (Zero, ?) => Zero

type ![N <: Bin] =
type ![N <: Bin] <: Bin =
N match
case Zero => One
case One => One
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/9890.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object Test {

type TupleMap[Tup <: Tuple, Bound, F[_ <: Bound]] <: Tuple = Tup match {
case EmptyTuple => EmptyTuple
case h *: t => F[h] *: TupleMap[t, Bound, F]
case h *: t => F[h & Bound] *: TupleMap[t, Bound, F]
}
type TupleDedup[Tup <: Tuple, Mask] <: Tuple = Tup match {
case EmptyTuple => EmptyTuple
Expand Down
3 changes: 3 additions & 0 deletions tests/pos/Tuple.FlatMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
type FlatMap[Tup <: Tuple, F[_] <: Tuple] <: Tuple = Tup match
case EmptyTuple => EmptyTuple
case h *: t => Tuple.Concat[F[h], FlatMap[t, F]]
53 changes: 53 additions & 0 deletions tests/pos/Tuple.Map.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
type Fold0[Tup <: Tuple, Z, F[_, _]] = Tup match
case EmptyTuple => Z
case h *: t => F[h, Fold0[t, Z, F]]


type Union0[T <: Tuple] = Fold0[T, Nothing, [x, y] =>> x | y]

type Union1[Tup <: Tuple] = Tup match
case EmptyTuple => Nothing
case h *: t => h | Union1[t]


import Tuple.Map as Map0

type Map1[Tup <: Tuple, F[_ <: Union0[Tup]]] <: Tuple = Tup match
case EmptyTuple => EmptyTuple
case h *: t => F[h & Union0[Tup]] *: Map1[t, [x <: Union0[t]] =>> F[x & Union0[Tup]]]

type Map2[Tup <: Tuple, F[_]] <: Tuple = Tup match
case EmptyTuple => EmptyTuple
case h *: t => F[h] *: Map2[t, F]

//type Map3[Tup <: Tuple, F[_ <: Union1[Tup]]] <: Tuple = Tup match
// case EmptyTuple => EmptyTuple
// case h *: t => F[h] *: Map3[t, F]

type Map4 [Tup <: Tuple, F[_ <: Union1[Tup]]] = Map4UB[Tup, F, Union1[Tup]]
type Map4UB[Tup <: Tuple, F[_ <: UB], UB] <: Tuple = Tup match
case EmptyTuple => EmptyTuple
case h *: t => F[h & UB] *: Map4UB[t, F, UB]

type Map5 [Tup <: Tuple, F[_ <: Union1[Tup]]] = Map5UB[Tup, Union1[Tup], F, Tup]
type Map5UB[Tup <: Tuple, UB, F[_ <: UB], Tup1 <: Tuple] <: Tuple = Tup1 match
case EmptyTuple => EmptyTuple
case h *: t => F[h & UB] *: Map5UB[Tup, UB, F, t]

trait Dt[T]
case class IBox[A <: Int](v: A)

class Test[H, T <: Tuple]:
//def t0 = { val x: Dt[H] *: Map0[T, Dt] = ???; val y: Map0[H *: T, Dt] = x }
def t1 = { val x: Dt[H] *: Map1[T, Dt] = ???; val y: Map1[H *: T, Dt] = x }
def t2 = { val x: Dt[H] *: Map2[T, Dt] = ???; val y: Map2[H *: T, Dt] = x }
//def t3 = { val x: Dt[H] *: Map3[T, Dt] = ???; val y: Map3[H *: T, Dt] = x }
//def t4 = { val x: Dt[H] *: Map4[T, Dt] = ???; val y: Map4[H *: T, Dt] = x }
//def t5 = { val x: Dt[H] *: Map5[T, Dt] = ???; val y: Map5[H *: T, Dt] = x }

def i0 = { val x: Map0[(1, 2), IBox] = (IBox(1), IBox(2)) }
def i1 = { val x: Map1[(1, 2), IBox] = (IBox(1), IBox(2)) }
//def i2 = { val x: Map2[(1, 2), IBox] = (IBox(1), IBox(2)) }
//def i3 = { val x: Map3[(1, 2), IBox] = (IBox(1), IBox(2)) }
def i4 = { val x: Map4[(1, 2), IBox] = (IBox(1), IBox(2)) }
def i5 = { val x: Map5[(1, 2), IBox] = (IBox(1), IBox(2)) }
8 changes: 6 additions & 2 deletions tests/pos/i15926.extract.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@ final case class Succ[+N <: Nat]() extends Nat

final case class Neg[+N <: Succ[Nat]]()

type Sum[X, Y] = Y match
type Sum[X <: Nat, Y] = Y match
case Zero => X
case Succ[y] => Sum[Succ[X], y]

type IntSum[A, B] = B match
case Neg[b] => IntSumNeg[A, b]

type IntSumNeg[A, B] = A match
case Neg[a] => Neg[Sum[a, B]]
case Neg[a] => Negate[Sum[a, B]]

type Negate[A] = A match
case Zero => Zero
case Succ[x] => Neg[A & Succ[x]]

type One = Succ[Zero]
type Two = Succ[One]
Expand Down
8 changes: 6 additions & 2 deletions tests/pos/i15926.min.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ final case class Succ[+N <: Nat]() extends Nat

final case class Neg[+N <: Succ[Nat]]()

type Sum[X, Y] = Y match
type Sum[X <: Nat, Y] <: Nat = Y match
case Zero => X
case Succ[y] => Sum[Succ[X], y]

type IntSum[A, B] = B match
case Neg[b] => A match
case Neg[a] => Neg[Sum[a, b]]
case Neg[a] => Negate[Sum[a, b]]

type Negate[A] = A match
case Zero => Zero
case Succ[x] => Neg[A & Succ[x]]

type One = Succ[Zero]
type Two = Succ[One]
Expand Down
14 changes: 9 additions & 5 deletions tests/pos/i15926.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@ type NatSum[X <: NatT, Y <: NatT] <: NatT = Y match
type NatDif[X <: NatT, Y <: NatT] <: IntT = Y match
case Zero => X
case Succ[y] => X match
case Zero => Minus[Y]
case Zero => Minus[Y & Succ[y]]
case Succ[x] => NatDif[x, y]

type Sum[X <: IntT, Y <: IntT] <: IntT = Y match
case Zero => X
case Minus[y] => X match
case Minus[x] => Minus[NatSum[x, y]]
case _ => NatDif[X, y]
case Minus[x] => Negate[NatSum[x, y]]
case _ => NatDif[X & NatT, y]
case _ => X match
case Minus[x] => NatDif[Y, x]
case _ => NatSum[X, Y]
case Minus[x] => NatDif[Y & NatT, x]
case _ => NatSum[X & NatT, Y & NatT]

type Negate[A] <: IntT = A match
case Zero => Zero
case Succ[x] => Minus[A & Succ[x]]
4 changes: 2 additions & 2 deletions tests/pos/i16596.orig.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import scala.compiletime.ops.int

type Count0[N,T] <: Tuple = (N,T) match
type Count0[N <: Int,T] <: Tuple = (N,T) match
case (0,_) => EmptyTuple
case (N,String) => String *: Count0[int.-[N, 1], String]
case (N,Int) => Int *: Count0[int.-[N, 1], Int]
case (N,Float) => Float *: Count0[int.-[N, 1], Float]
case (N,Double) => Double *: Count0[int.-[N, 1], Double]


type Count1[N,T] <: Tuple = (N,T) match
type Count1[N <: Int,T] <: Tuple = (N,T) match
case (0,T) => EmptyTuple
case (N,String) => String *: Count1[int.-[N, 1], String]
case (N,Int) => Int *: Count1[int.-[N, 1], Int]
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/i16596.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import scala.compiletime.ops.int, int.-

type Count[N, T] <: Tuple = (N, T) match
type Count[N <: Int, T] <: Tuple = (N, T) match
case (0, T) => EmptyTuple
case (N, T) => T *: Count[N - 1, T]

Expand Down
2 changes: 1 addition & 1 deletion tests/pos/i16706.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import scala.deriving.Mirror
import scala.reflect.ClassTag

type TupleUnionLub[T <: Tuple, Lub, Acc <: Lub] <: Lub = T match {
case (h & Lub) *: t => TupleUnionLub[t, Lub, Acc | h]
case (h & Lub) *: t => TupleUnionLub[t, Lub, (Acc | h) & Lub]
case EmptyTuple => Acc
}

Expand Down
16 changes: 0 additions & 16 deletions tests/pos/i17257.min.scala

This file was deleted.

29 changes: 29 additions & 0 deletions tests/pos/mini-onnx/Indices.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import scala.compiletime.ops.string.+
import scala.compiletime.ops.any

type Index = Int & Singleton

sealed trait Indices

final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices:
override def toString = s"$head ::: $tail"

sealed trait INil extends Indices
case object INil extends INil

object Indices:
type ToString[X <: Indices] <: String = X match
case INil => "INil"
case head ::: tail => any.ToString[head] + " ::: " + ToString[tail]

type Contains[Haystack <: Indices, Needle <: Index] <: Boolean = Haystack match
case head ::: tail => head match
case Needle => true
case _ => Contains[tail, Needle]
case INil => false

type RemoveValue[RemoveFrom <: Indices, Value <: Index] <: Indices = RemoveFrom match
case INil => INil
case head ::: tail => head match
case Value => RemoveValue[tail, Value]
case _ => head ::: RemoveValue[tail, Value]
Loading
Loading