From 538a8bdcdeed720845ce808f184103f0d7be91a1 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Fri, 13 Oct 2023 17:10:35 +0100 Subject: [PATCH] Add some very basic GADT constraints from type cases --- .../tools/dotc/transform/PostTyper.scala | 7 +- .../src/dotty/tools/dotc/typer/Namer.scala | 5 +- .../src/dotty/tools/dotc/typer/Typer.scala | 8 +- tests/pos/mini-onnx/Indices.scala | 29 ++++ tests/pos/mini-onnx/Shape.scala | 68 +++++++++ .../pos/mini-onnx/TensorShapeDenotation.scala | 24 ++++ tests/pos/mini-onnx/Tensors.scala | 130 ++++++++++++++++++ tests/pos/nano-onnx.scala | 20 +++ 8 files changed, 287 insertions(+), 4 deletions(-) create mode 100644 tests/pos/mini-onnx/Indices.scala create mode 100644 tests/pos/mini-onnx/Shape.scala create mode 100644 tests/pos/mini-onnx/TensorShapeDenotation.scala create mode 100644 tests/pos/mini-onnx/Tensors.scala create mode 100644 tests/pos/nano-onnx.scala diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index a9130ea3354c..e4761124763c 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -484,11 +484,14 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case m @ MatchTypeTree(bounds, selector, cases) => // Analog to the case above for match types def transformCase(x: CaseDef): CaseDef = - cpy.CaseDef(tree)( + val gadtCtx = x.pat.removeAttachment(typer.Typer.InferredGadtConstraints) match + case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt)) + case None => ctx + inContext(gadtCtx)(cpy.CaseDef(tree)( withMode(Mode.Pattern)(transform(x.pat)), transform(x.guard), transform(x.body), - ) + )) cpy.MatchTypeTree(tree)( super.transform(bounds), super.transform(selector), diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 8ed881ca0d81..32250d7a7fef 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -984,7 +984,10 @@ class Namer { typer: Typer => override final def typeSig(sym: Symbol): Type = val tparamSyms = completerTypeParams(sym)(using ictx) - given ctx: Context = nestedCtx.nn + given ctx: Context = if tparamSyms.isEmpty then nestedCtx.nn else + given ctx: Context = nestedCtx.nn.fresh.setFreshGADTBounds + ctx.gadtState.addToConstraint(tparamSyms) + ctx def abstracted(tp: TypeBounds): TypeBounds = HKTypeLambda.boundsFromParams(tparamSyms, tp) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 034d2ad1271e..d7aa994ab9bc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1965,13 +1965,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer } if !ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef then withMode(Mode.GadtConstraintInference) { - TypeComparer.constrainPatternType(pat1.tpe, selType) + selType match + case scr: TypeRef if ctx.gadt.contains(scr.symbol) => pat1.tpe match + case pat: TypeRef => scr <:< pat + case _ => TypeComparer.constrainPatternType(pat1.tpe, selType) + case _ => TypeComparer.constrainPatternType(pat1.tpe, selType) } val pat2 = indexPattern(cdef).transform(pat1) var body1 = typedType(cdef.body, pt) if !body1.isType then assert(ctx.reporter.errorsReported) body1 = TypeTree(errorType(em"", cdef.srcPos)) + else if ctx.gadt.isNarrowing then + pat2.putAttachment(InferredGadtConstraints, ctx.gadt) assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1) } caseRest(using ctx.fresh.setFreshGADTBounds.setNewScope) diff --git a/tests/pos/mini-onnx/Indices.scala b/tests/pos/mini-onnx/Indices.scala new file mode 100644 index 000000000000..5e87ca3a384d --- /dev/null +++ b/tests/pos/mini-onnx/Indices.scala @@ -0,0 +1,29 @@ +import scala.compiletime.ops.string.+ +import scala.compiletime.ops.any + +type Index = Int & Singleton + +sealed trait Indices + +final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices: + override def toString = s"$head ::: $tail" + +sealed trait INil extends Indices +case object INil extends INil + +object Indices: + type ToString[X <: Indices] <: String = X match + case INil => "INil" + case head ::: tail => any.ToString[head] + " ::: " + ToString[tail] + + type Contains[Haystack <: Indices, Needle <: Index] <: Boolean = Haystack match + case head ::: tail => head match + case Needle => true + case _ => Contains[tail, Needle] + case INil => false + + type RemoveValue[RemoveFrom <: Indices, Value <: Index] <: Indices = RemoveFrom match + case INil => INil + case head ::: tail => head match + case Value => RemoveValue[tail, Value] + case _ => head ::: RemoveValue[tail, Value] diff --git a/tests/pos/mini-onnx/Shape.scala b/tests/pos/mini-onnx/Shape.scala new file mode 100644 index 000000000000..83e0aef6f6b1 --- /dev/null +++ b/tests/pos/mini-onnx/Shape.scala @@ -0,0 +1,68 @@ +import scala.compiletime.ops.int.{S, +, <, <=, *} +import scala.compiletime.ops.boolean.&& + +type Dimension = Int & Singleton + +sealed trait Shape extends Product with Serializable + +final case class #:[+H <: Dimension, +T <: Shape](head: H, tail: T) extends Shape: + override def toString = (head: Any) match + case _ #: _ => s"($head) #: $tail" + case _ => s"$head #: $tail" + +sealed trait SNil extends Shape +case object SNil extends SNil + +object Shape: + def scalar: SNil = SNil + + type Concat[X <: Shape, Y <: Shape] <: Shape = X match + case SNil => Y + case head #: tail => head #: Concat[tail, Y] + + type Reverse[X <: Shape] <: Shape = X match + case SNil => SNil + case head #: tail => Concat[Reverse[tail], head #: SNil] + + type NumElements[X <: Shape] <: Int = X match + case SNil => 1 + case head #: tail => head * NumElements[tail] + + type Rank[X <: Shape] <: Int = X match + case SNil => 0 + case head #: tail => Rank[tail] + 1 + + type IsEmpty[X <: Shape] <: Boolean = X match + case SNil => true + case _ #: _ => false + + type Head[X <: Shape] <: Dimension = X match { case head #: _ => head } + type Tail[X <: Shape] <: Shape = X match { case _ #: tail => tail } + + type Reduce[S <: Shape, Axes <: None.type | Indices] <: Shape = Axes match + case None.type => SNil + case Indices => ReduceLoop[S, Axes, 0] + + protected type ReduceLoop[RemoveFrom <: Shape, ToRemove <: Indices, I <: Index] <: Shape = RemoveFrom match + case head #: tail => Indices.Contains[ToRemove, I] match + case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]] + case false => head #: ReduceLoop[tail, ToRemove, S[I]] + case SNil => ToRemove match { case INil => SNil } + + type WithinBounds[I <: Index, S <: Shape] = (0 <= I && I < Rank[S]) + + type RemoveIndex[RemoveFrom <: Shape, I <: Index] <: Shape = WithinBounds[I, RemoveFrom] match + case true => RemoveIndexLoop[RemoveFrom, I, 0] + + protected type RemoveIndexLoop[RemoveFrom <: Shape, I <: Index, Current <: Index] <: Shape = RemoveFrom match + case head #: tail => Current match + case I => tail + case _ => head #: RemoveIndexLoop[tail, I, S[Current]] + + type Map[X <: Shape, F[_ <: Dimension] <: Dimension] <: Shape = X match + case SNil => SNil + case head #: tail => F[head] #: Map[tail, F] + + type FoldLeft[B, X <: Shape, Z <: B, F[_ <: B, _ <: Int] <: B] <: B = X match + case SNil => Z + case head #: tail => FoldLeft[B, tail, F[Z, head], F] diff --git a/tests/pos/mini-onnx/TensorShapeDenotation.scala b/tests/pos/mini-onnx/TensorShapeDenotation.scala new file mode 100644 index 000000000000..6d9978c34f98 --- /dev/null +++ b/tests/pos/mini-onnx/TensorShapeDenotation.scala @@ -0,0 +1,24 @@ +import scala.compiletime.ops.int.S + +type DimensionDenotation = String & Singleton + +sealed trait TensorShapeDenotation extends Product with Serializable + +final case class ##:[+H <: DimensionDenotation, +T <: TensorShapeDenotation](head: H, tail: T) extends TensorShapeDenotation: + override def toString = (head: Any) match + case _ ##: _ => s"($head) ##: $tail" + case _ => s"$head ##: $tail" + +sealed trait TSNil extends TensorShapeDenotation +case object TSNil extends TSNil + +object TensorShapeDenotation: + type Reduce[S <: TensorShapeDenotation, Axes <: None.type | Indices] <: TensorShapeDenotation = Axes match + case None.type => TSNil + case Indices => ReduceLoop[S, Axes, 0] + + protected type ReduceLoop[RemoveFrom <: TensorShapeDenotation, ToRemove <: Indices, I <: Index] <: TensorShapeDenotation = RemoveFrom match + case head ##: tail => Indices.Contains[ToRemove, I] match + case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]] + case false => head ##: ReduceLoop[tail, ToRemove, S[I]] + case TSNil => ToRemove match { case INil => TSNil } diff --git a/tests/pos/mini-onnx/Tensors.scala b/tests/pos/mini-onnx/Tensors.scala new file mode 100644 index 000000000000..b4b80271827d --- /dev/null +++ b/tests/pos/mini-onnx/Tensors.scala @@ -0,0 +1,130 @@ +import scala.compiletime.ops.int.* + +object Tensors: + import Shape.Reverse + + type Supported = Int | Long | Float | Double | Byte | Short | Boolean | String + + type TensorTypeDenotation = String & Singleton + + type Axes = Tuple3[TensorTypeDenotation, TensorShapeDenotation, Shape] + + opaque type Tensor[T <: Supported, +Ax <: Axes] = Tuple2[Array[T], Ax] + + type SparseTensor[T <: Supported, A <: Axes] = Tensor[T, A] + + type KeepOrReduceDims[S <: Shape, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: Shape = KeepDims match + case true => ReduceKeepDims[S, AxisIndices] + case false => Shape.Reduce[S, AxisIndices] + + type KeepOrReduceDimDenotations[Td <: TensorShapeDenotation, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: TensorShapeDenotation = KeepDims match + case true => Td + case false => TensorShapeDenotation.Reduce[Td, AxisIndices] + + type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match + case None.type => SNil + case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0] + + protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match + case head #: tail => Indices.Contains[ToReplace, I] match + case true => 1 #: ReduceKeepDimsLoop[tail, Indices.RemoveValue[ToReplace, I], S[I]] + case false => head #: ReduceKeepDimsLoop[tail, ToReplace, S[I]] + case SNil => ToReplace match { case INil => SNil } + + type AddGivenAxisSize[S <: Shape, S1 <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match + case None.type => SNil + case Indices => AddGivenAxisSizeLoop[S, S1, AxisIndices, 0] + + protected type AddGivenAxisSizeLoop[First <: Shape, Second <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = First match + case head #: tail => Indices.Contains[AxisIndex, I] match + case true => Second match + case secondHead #: secondTail => (head + secondHead) #: AddGivenAxisSizeLoop[tail, secondTail, Indices.RemoveValue[AxisIndex, I], S[I]] + case SNil => AxisIndex match { case INil => SNil } + case false => Second match + case secondHead #: secondTail => (head) #: AddGivenAxisSizeLoop[tail, secondTail, AxisIndex, S[I]] + case SNil => AxisIndex match { case INil => SNil } + + type UnsqueezeShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match + case None.type => SNil + case Indices => UnsqueezeShapeLoop[S, AxisIndex, 0] + + protected type UnsqueezeShapeLoop[ToUnsqueeze <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = ToUnsqueeze match + case head #: tail => Indices.Contains[AxisIndex, I] match + case true => 1 #: head #: UnsqueezeShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I]] + case false => head #: UnsqueezeShapeLoop[tail, AxisIndex, S[I]] + case SNil => AxisIndex match { case INil => SNil } + + type GatheredShape[S <: Shape, AxisIndex <: None.type | Indices, AxisIndices <: Indices] <: Shape = AxisIndex match + case None.type => SNil + case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices] + + protected type GatheredShapeLoop[ToGather <: Shape, AxisIndex <: Indices, I <: Index, AxisIndices <: Indices] <: Shape = ToGather match + case head #: tail => Indices.Contains[AxisIndex, I] match + case true => IndicesSize[AxisIndices] #: GatheredShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], AxisIndices] + case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices] + case SNil => AxisIndex match { case INil => SNil } + + type IndicesSize[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0] + + type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] <: Dimension = AxisIndices match + case head ::: tail => IndicesSizeLoop[tail, S[Acc]] + case INil => Acc + + type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match + case None.type => SNil + case Indices => FlattenedShapeLoop[S, AxisIndex, 0, 1] + + protected type FlattenedShapeLoop[ToFlatten <: Shape, AxisIndex <: Indices, I <: Index, Acc <: Index] <: Shape = ToFlatten match + case head #: tail => Indices.Contains[AxisIndex, I] match + case true => Acc #: FlattenedShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], head] + case false => FlattenedShapeLoop[tail, AxisIndex, S[I], head * Acc] + case SNil => AxisIndex match { case INil => Acc #: SNil } + + type SlicedShape[AxisIndicesStarts <: None.type | Indices, AxisIndicesEnds <: None.type | Indices] <: Shape = AxisIndicesStarts match + case None.type => SNil + case Indices => AxisIndicesEnds match + case None.type => SNil + case Indices => SlicedShapeLoop[AxisIndicesStarts, AxisIndicesEnds] + + protected type SlicedShapeLoop[Starts <: Indices, Ends <: Indices] <: Shape = Starts match + case head ::: tail => Ends match + case endsHead ::: endsTail => (endsHead - head) #: SlicedShapeLoop[tail, endsTail] + case INil => SNil + case INil => Ends match { case INil => SNil } + + type PaddedShape[PadFrom <: Shape, AxisBefore <: None.type | Shape, AxisAfter <: None.type | Shape] <: Shape = AxisBefore match + case None.type => PadFrom + case Shape => AxisAfter match + case None.type => PadFrom + case Shape => Reverse[PaddedShapeLoop[Reverse[PadFrom], Reverse[AxisBefore], Reverse[AxisAfter]]] + + protected type PaddedShapeLoop[PadFrom <: Shape, Before <: Shape, After <: Shape] <: Shape = Before match + case head #: tail => After match + case afterHead #: afterTail => PadFrom match + case padFromHead #: padFromTail => (head + padFromHead + afterHead) #: PaddedShapeLoop[padFromTail, tail, afterTail] + case SNil => SNil + case SNil => SNil + case SNil => After match + case SNil => PadFrom match + case padFromHead #: padFromTail => padFromHead #: PaddedShapeLoop[padFromTail, SNil, SNil] + case SNil => SNil + + type TiledShape[TileFrom <: Shape, AxisRepeats <: None.type | Indices] <: Shape = AxisRepeats match + case None.type => SNil + case Indices => TiledShapeLoop[TileFrom, AxisRepeats] + + protected type TiledShapeLoop[TileFrom <: Shape, Repeats <: Indices] <: Shape = Repeats match + case head ::: tail => TileFrom match + case tileFromHead #: tileFromTail => (head * tileFromHead) #: TiledShapeLoop[tileFromTail, tail] + case SNil => SNil + case INil => SNil + + type PoolShape[From <: Shape, KernelShape <: None.type | Shape] <: Shape = KernelShape match + case None.type => SNil + case Shape => Reverse[PoolShapeLoop[Reverse[From], Reverse[KernelShape]]] + + protected type PoolShapeLoop[From <: Shape, KernelShape <: Shape] <: Shape = KernelShape match + case head #: tail => From match + case fromHead #: fromTail => (fromHead - head + 1) #: PoolShapeLoop[fromTail, tail] + case SNil => SNil + case SNil => From diff --git a/tests/pos/nano-onnx.scala b/tests/pos/nano-onnx.scala new file mode 100644 index 000000000000..6ef2d43cd033 --- /dev/null +++ b/tests/pos/nano-onnx.scala @@ -0,0 +1,20 @@ +import scala.compiletime.ops.int.* + +type Index = Int & Singleton +type Dimension = Int & Singleton + +sealed trait Indices extends Product with Serializable +sealed trait Shape extends Product with Serializable +final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices +final case class #:[+H <: Dimension, +T <: Shape ](head: H, tail: T) extends Shape +sealed trait INil extends Indices; case object INil extends INil +sealed trait SNil extends Shape; case object SNil extends SNil + +object Ts: + type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match + case None.type => SNil + case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0] + + protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match + case head #: tail => ReduceKeepDimsLoop[tail, ToReplace, S[I]] + case SNil => ToReplace match { case INil => SNil }