Skip to content

Commit

Permalink
Add some very basic GADT constraints from type cases
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Oct 15, 2023
1 parent 1705c34 commit 538a8bd
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 4 deletions.
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,14 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case m @ MatchTypeTree(bounds, selector, cases) =>
// Analog to the case above for match types
def transformCase(x: CaseDef): CaseDef =
cpy.CaseDef(tree)(
val gadtCtx = x.pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
case None => ctx
inContext(gadtCtx)(cpy.CaseDef(tree)(
withMode(Mode.Pattern)(transform(x.pat)),
transform(x.guard),
transform(x.body),
)
))
cpy.MatchTypeTree(tree)(
super.transform(bounds),
super.transform(selector),
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,10 @@ class Namer { typer: Typer =>

override final def typeSig(sym: Symbol): Type =
val tparamSyms = completerTypeParams(sym)(using ictx)
given ctx: Context = nestedCtx.nn
given ctx: Context = if tparamSyms.isEmpty then nestedCtx.nn else
given ctx: Context = nestedCtx.nn.fresh.setFreshGADTBounds
ctx.gadtState.addToConstraint(tparamSyms)
ctx

def abstracted(tp: TypeBounds): TypeBounds =
HKTypeLambda.boundsFromParams(tparamSyms, tp)
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1965,13 +1965,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
if !ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef then
withMode(Mode.GadtConstraintInference) {
TypeComparer.constrainPatternType(pat1.tpe, selType)
selType match
case scr: TypeRef if ctx.gadt.contains(scr.symbol) => pat1.tpe match
case pat: TypeRef => scr <:< pat
case _ => TypeComparer.constrainPatternType(pat1.tpe, selType)
case _ => TypeComparer.constrainPatternType(pat1.tpe, selType)
}
val pat2 = indexPattern(cdef).transform(pat1)
var body1 = typedType(cdef.body, pt)
if !body1.isType then
assert(ctx.reporter.errorsReported)
body1 = TypeTree(errorType(em"<error: not a type>", cdef.srcPos))
else if ctx.gadt.isNarrowing then
pat2.putAttachment(InferredGadtConstraints, ctx.gadt)
assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1)
}
caseRest(using ctx.fresh.setFreshGADTBounds.setNewScope)
Expand Down
29 changes: 29 additions & 0 deletions tests/pos/mini-onnx/Indices.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import scala.compiletime.ops.string.+
import scala.compiletime.ops.any

type Index = Int & Singleton

sealed trait Indices

final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices:
override def toString = s"$head ::: $tail"

sealed trait INil extends Indices
case object INil extends INil

object Indices:
type ToString[X <: Indices] <: String = X match
case INil => "INil"
case head ::: tail => any.ToString[head] + " ::: " + ToString[tail]

type Contains[Haystack <: Indices, Needle <: Index] <: Boolean = Haystack match
case head ::: tail => head match
case Needle => true
case _ => Contains[tail, Needle]
case INil => false

type RemoveValue[RemoveFrom <: Indices, Value <: Index] <: Indices = RemoveFrom match
case INil => INil
case head ::: tail => head match
case Value => RemoveValue[tail, Value]
case _ => head ::: RemoveValue[tail, Value]
68 changes: 68 additions & 0 deletions tests/pos/mini-onnx/Shape.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import scala.compiletime.ops.int.{S, +, <, <=, *}
import scala.compiletime.ops.boolean.&&

type Dimension = Int & Singleton

sealed trait Shape extends Product with Serializable

final case class #:[+H <: Dimension, +T <: Shape](head: H, tail: T) extends Shape:
override def toString = (head: Any) match
case _ #: _ => s"($head) #: $tail"
case _ => s"$head #: $tail"

sealed trait SNil extends Shape
case object SNil extends SNil

object Shape:
def scalar: SNil = SNil

type Concat[X <: Shape, Y <: Shape] <: Shape = X match
case SNil => Y
case head #: tail => head #: Concat[tail, Y]

type Reverse[X <: Shape] <: Shape = X match
case SNil => SNil
case head #: tail => Concat[Reverse[tail], head #: SNil]

type NumElements[X <: Shape] <: Int = X match
case SNil => 1
case head #: tail => head * NumElements[tail]

type Rank[X <: Shape] <: Int = X match
case SNil => 0
case head #: tail => Rank[tail] + 1

type IsEmpty[X <: Shape] <: Boolean = X match
case SNil => true
case _ #: _ => false

type Head[X <: Shape] <: Dimension = X match { case head #: _ => head }
type Tail[X <: Shape] <: Shape = X match { case _ #: tail => tail }

type Reduce[S <: Shape, Axes <: None.type | Indices] <: Shape = Axes match
case None.type => SNil
case Indices => ReduceLoop[S, Axes, 0]

protected type ReduceLoop[RemoveFrom <: Shape, ToRemove <: Indices, I <: Index] <: Shape = RemoveFrom match
case head #: tail => Indices.Contains[ToRemove, I] match
case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]]
case false => head #: ReduceLoop[tail, ToRemove, S[I]]
case SNil => ToRemove match { case INil => SNil }

type WithinBounds[I <: Index, S <: Shape] = (0 <= I && I < Rank[S])

type RemoveIndex[RemoveFrom <: Shape, I <: Index] <: Shape = WithinBounds[I, RemoveFrom] match
case true => RemoveIndexLoop[RemoveFrom, I, 0]

protected type RemoveIndexLoop[RemoveFrom <: Shape, I <: Index, Current <: Index] <: Shape = RemoveFrom match
case head #: tail => Current match
case I => tail
case _ => head #: RemoveIndexLoop[tail, I, S[Current]]

type Map[X <: Shape, F[_ <: Dimension] <: Dimension] <: Shape = X match
case SNil => SNil
case head #: tail => F[head] #: Map[tail, F]

type FoldLeft[B, X <: Shape, Z <: B, F[_ <: B, _ <: Int] <: B] <: B = X match
case SNil => Z
case head #: tail => FoldLeft[B, tail, F[Z, head], F]
24 changes: 24 additions & 0 deletions tests/pos/mini-onnx/TensorShapeDenotation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import scala.compiletime.ops.int.S

type DimensionDenotation = String & Singleton

sealed trait TensorShapeDenotation extends Product with Serializable

final case class ##:[+H <: DimensionDenotation, +T <: TensorShapeDenotation](head: H, tail: T) extends TensorShapeDenotation:
override def toString = (head: Any) match
case _ ##: _ => s"($head) ##: $tail"
case _ => s"$head ##: $tail"

sealed trait TSNil extends TensorShapeDenotation
case object TSNil extends TSNil

object TensorShapeDenotation:
type Reduce[S <: TensorShapeDenotation, Axes <: None.type | Indices] <: TensorShapeDenotation = Axes match
case None.type => TSNil
case Indices => ReduceLoop[S, Axes, 0]

protected type ReduceLoop[RemoveFrom <: TensorShapeDenotation, ToRemove <: Indices, I <: Index] <: TensorShapeDenotation = RemoveFrom match
case head ##: tail => Indices.Contains[ToRemove, I] match
case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]]
case false => head ##: ReduceLoop[tail, ToRemove, S[I]]
case TSNil => ToRemove match { case INil => TSNil }
130 changes: 130 additions & 0 deletions tests/pos/mini-onnx/Tensors.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import scala.compiletime.ops.int.*

object Tensors:
import Shape.Reverse

type Supported = Int | Long | Float | Double | Byte | Short | Boolean | String

type TensorTypeDenotation = String & Singleton

type Axes = Tuple3[TensorTypeDenotation, TensorShapeDenotation, Shape]

opaque type Tensor[T <: Supported, +Ax <: Axes] = Tuple2[Array[T], Ax]

type SparseTensor[T <: Supported, A <: Axes] = Tensor[T, A]

type KeepOrReduceDims[S <: Shape, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: Shape = KeepDims match
case true => ReduceKeepDims[S, AxisIndices]
case false => Shape.Reduce[S, AxisIndices]

type KeepOrReduceDimDenotations[Td <: TensorShapeDenotation, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: TensorShapeDenotation = KeepDims match
case true => Td
case false => TensorShapeDenotation.Reduce[Td, AxisIndices]

type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match
case None.type => SNil
case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0]

protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match
case head #: tail => Indices.Contains[ToReplace, I] match
case true => 1 #: ReduceKeepDimsLoop[tail, Indices.RemoveValue[ToReplace, I], S[I]]
case false => head #: ReduceKeepDimsLoop[tail, ToReplace, S[I]]
case SNil => ToReplace match { case INil => SNil }

type AddGivenAxisSize[S <: Shape, S1 <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match
case None.type => SNil
case Indices => AddGivenAxisSizeLoop[S, S1, AxisIndices, 0]

protected type AddGivenAxisSizeLoop[First <: Shape, Second <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = First match
case head #: tail => Indices.Contains[AxisIndex, I] match
case true => Second match
case secondHead #: secondTail => (head + secondHead) #: AddGivenAxisSizeLoop[tail, secondTail, Indices.RemoveValue[AxisIndex, I], S[I]]
case SNil => AxisIndex match { case INil => SNil }
case false => Second match
case secondHead #: secondTail => (head) #: AddGivenAxisSizeLoop[tail, secondTail, AxisIndex, S[I]]
case SNil => AxisIndex match { case INil => SNil }

type UnsqueezeShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match
case None.type => SNil
case Indices => UnsqueezeShapeLoop[S, AxisIndex, 0]

protected type UnsqueezeShapeLoop[ToUnsqueeze <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = ToUnsqueeze match
case head #: tail => Indices.Contains[AxisIndex, I] match
case true => 1 #: head #: UnsqueezeShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I]]
case false => head #: UnsqueezeShapeLoop[tail, AxisIndex, S[I]]
case SNil => AxisIndex match { case INil => SNil }

type GatheredShape[S <: Shape, AxisIndex <: None.type | Indices, AxisIndices <: Indices] <: Shape = AxisIndex match
case None.type => SNil
case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices]

protected type GatheredShapeLoop[ToGather <: Shape, AxisIndex <: Indices, I <: Index, AxisIndices <: Indices] <: Shape = ToGather match
case head #: tail => Indices.Contains[AxisIndex, I] match
case true => IndicesSize[AxisIndices] #: GatheredShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], AxisIndices]
case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices]
case SNil => AxisIndex match { case INil => SNil }

type IndicesSize[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0]

type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] <: Dimension = AxisIndices match
case head ::: tail => IndicesSizeLoop[tail, S[Acc]]
case INil => Acc

type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match
case None.type => SNil
case Indices => FlattenedShapeLoop[S, AxisIndex, 0, 1]

protected type FlattenedShapeLoop[ToFlatten <: Shape, AxisIndex <: Indices, I <: Index, Acc <: Index] <: Shape = ToFlatten match
case head #: tail => Indices.Contains[AxisIndex, I] match
case true => Acc #: FlattenedShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], head]
case false => FlattenedShapeLoop[tail, AxisIndex, S[I], head * Acc]
case SNil => AxisIndex match { case INil => Acc #: SNil }

type SlicedShape[AxisIndicesStarts <: None.type | Indices, AxisIndicesEnds <: None.type | Indices] <: Shape = AxisIndicesStarts match
case None.type => SNil
case Indices => AxisIndicesEnds match
case None.type => SNil
case Indices => SlicedShapeLoop[AxisIndicesStarts, AxisIndicesEnds]

protected type SlicedShapeLoop[Starts <: Indices, Ends <: Indices] <: Shape = Starts match
case head ::: tail => Ends match
case endsHead ::: endsTail => (endsHead - head) #: SlicedShapeLoop[tail, endsTail]
case INil => SNil
case INil => Ends match { case INil => SNil }

type PaddedShape[PadFrom <: Shape, AxisBefore <: None.type | Shape, AxisAfter <: None.type | Shape] <: Shape = AxisBefore match
case None.type => PadFrom
case Shape => AxisAfter match
case None.type => PadFrom
case Shape => Reverse[PaddedShapeLoop[Reverse[PadFrom], Reverse[AxisBefore], Reverse[AxisAfter]]]

protected type PaddedShapeLoop[PadFrom <: Shape, Before <: Shape, After <: Shape] <: Shape = Before match
case head #: tail => After match
case afterHead #: afterTail => PadFrom match
case padFromHead #: padFromTail => (head + padFromHead + afterHead) #: PaddedShapeLoop[padFromTail, tail, afterTail]
case SNil => SNil
case SNil => SNil
case SNil => After match
case SNil => PadFrom match
case padFromHead #: padFromTail => padFromHead #: PaddedShapeLoop[padFromTail, SNil, SNil]
case SNil => SNil

type TiledShape[TileFrom <: Shape, AxisRepeats <: None.type | Indices] <: Shape = AxisRepeats match
case None.type => SNil
case Indices => TiledShapeLoop[TileFrom, AxisRepeats]

protected type TiledShapeLoop[TileFrom <: Shape, Repeats <: Indices] <: Shape = Repeats match
case head ::: tail => TileFrom match
case tileFromHead #: tileFromTail => (head * tileFromHead) #: TiledShapeLoop[tileFromTail, tail]
case SNil => SNil
case INil => SNil

type PoolShape[From <: Shape, KernelShape <: None.type | Shape] <: Shape = KernelShape match
case None.type => SNil
case Shape => Reverse[PoolShapeLoop[Reverse[From], Reverse[KernelShape]]]

protected type PoolShapeLoop[From <: Shape, KernelShape <: Shape] <: Shape = KernelShape match
case head #: tail => From match
case fromHead #: fromTail => (fromHead - head + 1) #: PoolShapeLoop[fromTail, tail]
case SNil => SNil
case SNil => From
20 changes: 20 additions & 0 deletions tests/pos/nano-onnx.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import scala.compiletime.ops.int.*

type Index = Int & Singleton
type Dimension = Int & Singleton

sealed trait Indices extends Product with Serializable
sealed trait Shape extends Product with Serializable
final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices
final case class #:[+H <: Dimension, +T <: Shape ](head: H, tail: T) extends Shape
sealed trait INil extends Indices; case object INil extends INil
sealed trait SNil extends Shape; case object SNil extends SNil

object Ts:
type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match
case None.type => SNil
case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0]

protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match
case head #: tail => ReduceKeepDimsLoop[tail, ToReplace, S[I]]
case SNil => ToReplace match { case INil => SNil }

0 comments on commit 538a8bd

Please sign in to comment.