Skip to content

Commit

Permalink
Minimal support for dependent case classes
Browse files Browse the repository at this point in the history
This lets us write:

    trait A:
      type B

    case class CC(a: A, b: a.B)

Pattern matching works but isn't dependent yet:

    x match
      case CC(a, b) =>
        val a1: A = a
        // Dependent pattern matching is not currently supported
        // val b1: a1.B = b
        val b1 = b // Type is CC#a.B

(for my usecase this isn't a problem, I'm working on a type constraint API
which lets me write things like `case class CC(a: Int, b: Int
GreaterThan[a.type])`)

Because case class pattern matching relies on the product selectors `_N`, making
it dependent is a bit tricky, currently we generate:

    case class CC(a: A, b: a.B):
      def _1: A = a
      def _2: a.B = b

So the type of `_2` is not obviously related to the type of `_1`, we probably
need to change what we generate into:

    case class CC(a: A, b: a.B):
      @uncheckedStable def _1: a.type = a
      def _2: _1.B = b

But this can be done in a separate PR.

Fixes #8073.
  • Loading branch information
smarter committed Oct 5, 2024
1 parent db81e44 commit e5327ac
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 53 deletions.
73 changes: 42 additions & 31 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -498,53 +498,64 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
/** The class
*
* ```
* case class C[T <: U](x: T, y: String*)
* trait U:
* type Elem
*
* case class C[T <: U](a: T, b: a.Elem, c: String*)
* ```
*
* gets the `fromProduct` method:
*
* ```
* def fromProduct(x$0: Product): MirroredMonoType =
* new C[U](
* x$0.productElement(0).asInstanceOf[U],
* x$0.productElement(1).asInstanceOf[Seq[String]]: _*)
* val a$1 = x$0.productElement(0).asInstanceOf[U]
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
* val c$1 = x$0.productElement(2).asInstanceOf[Seq[String]]
* new C[U](a$1, b$1, c$1*)
* ```
* where
* ```
* type MirroredMonoType = C[?]
* ```
*/
def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
def extractParams(tpe: Type): List[Type] =
tpe.asInstanceOf[MethodType].paramInfos

def computeFromCaseClass: (Type, List[Type]) =
val (baseRef, baseInfo) =
val rawRef = caseClass.typeRef
val rawInfo = caseClass.primaryConstructor.info
optInfo match
case Some(info) =>
(rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner))
case _ =>
(rawRef, rawInfo)
baseInfo match
def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
val classRef = optInfo match
case Some(info) => TypeRef(info.pre, caseClass)
case _ => caseClass.typeRef
val (newPrefix, constrMeth) =
val constr = TermRef(classRef, caseClass.primaryConstructor)
(constr.info: @unchecked) match
case tl: PolyType =>
val tvars = constrained(tl)
val targs = for tvar <- tvars yield
tvar.instantiate(fromBelow = false)
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
case methTpe =>
(baseRef, extractParams(methTpe))
end computeFromCaseClass

val (classRefApplied, paramInfos) = computeFromCaseClass
val elems =
for ((formal, idx) <- paramInfos.zipWithIndex) yield
val elem =
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
.ensureConforms(formal.translateFromRepeated(toArray = false))
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
New(classRefApplied, elems)
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType])
case mt: MethodType =>
(classRef, mt)

// Create symbols for the vals corresponding to each parameter
// If there are dependent parameters, the infos won't be correct yet.
val bindingSyms = constrMeth.paramRefs.map: pref =>
newSymbol(ctx.owner, pref.paramName.freshened, Synthetic,
pref.underlying.translateFromRepeated(toArray = false), coord = ctx.owner.span.focus)
val bindingRefs = bindingSyms.map(TermRef(NoPrefix, _))
// Fix the infos for dependent parameters
if constrMeth.isParamDependent then
bindingSyms.foreach: bindingSym =>
bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs)

val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) =>
ValDef(bindingSym,
productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
.ensureConforms(bindingSym.info))

val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) =>
val refTree = ref(bindingRef)
if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree
Block(
bindingDefs,
New(newPrefix, newArgs)
)
end fromProductBody

/** For an enum T:
Expand Down
14 changes: 1 addition & 13 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1925,9 +1925,7 @@ class Namer { typer: Typer =>
if isConstructor then
// set result type tree to unit, but take the current class as result type of the symbol
typedAheadType(ddef.tpt, defn.UnitType)
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
mt
wrapMethType(effectiveResultType(sym, paramSymss))
else if sym.isAllOf(Given | Method) && Feature.enabled(modularity) then
// set every context bound evidence parameter of a given companion method
// to be tracked, provided it has a type that has an abstract type member.
Expand Down Expand Up @@ -1976,16 +1974,6 @@ class Namer { typer: Typer =>
ddef.trailingParamss.foreach(completeParams)
end completeTrailingParamss

/** Checks an implementation restriction on case classes. */
def checkCaseClassParamDependencies(mt: Type, cls: Symbol)(using Context): Unit =
mt.stripPoly match
case mt: MethodType if cls.is(Case) && mt.isParamDependent =>
// See issue #8073 for background
report.error(
em"""Implementation restriction: case classes cannot have dependencies between parameters""",
cls.srcPos)
case _ =>

/** Under x.modularity, we add `tracked` to context bound witnesses
* that have abstract type members
*/
Expand Down
8 changes: 0 additions & 8 deletions tests/neg/i8069.scala

This file was deleted.

2 changes: 1 addition & 1 deletion tests/run-macros/tasty-extractors-2.check
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit")
Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("<init>", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil)), None, List(DefDef("a", Nil, Inferred(), Some(Literal(IntConstant(0))))))), Literal(UnitConstant())))
TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit")

Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("<init>", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil), TypeSelect(Select(Ident("_root_"), "scala"), "Product"), TypeSelect(Select(Ident("_root_"), "scala"), "Serializable")), None, List(DefDef("hashCode", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_hashCode"), List(This(Some("Foo")))))), DefDef("equals", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(Apply(Select(This(Some("Foo")), "eq"), List(TypeApply(Select(Ident("x$0"), "$asInstanceOf$"), List(Inferred())))), "||"), List(Match(Ident("x$0"), List(CaseDef(Bind("x$0", Typed(Wildcard(), Inferred())), None, Apply(Select(Literal(BooleanConstant(true)), "&&"), List(Apply(Select(Ident("x$0"), "canEqual"), List(This(Some("Foo"))))))), CaseDef(Wildcard(), None, Literal(BooleanConstant(false))))))))), DefDef("toString", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_toString"), List(This(Some("Foo")))))), DefDef("canEqual", List(TermParamClause(List(ValDef("that", Inferred(), None)))), Inferred(), Some(TypeApply(Select(Ident("that"), "isInstanceOf"), List(Inferred())))), DefDef("productArity", Nil, Inferred(), Some(Literal(IntConstant(0)))), DefDef("productPrefix", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), DefDef("productElement", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), "<init>"), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), "<init>"), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("copy", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), "<init>"), Nil))))), ValDef("Foo", TypeIdent("Foo$"), Some(Apply(Select(New(TypeIdent("Foo$")), "<init>"), Nil))), ClassDef("Foo$", DefDef("<init>", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil), Inferred()), Some(ValDef("_", Singleton(Ident("Foo")), None)), List(DefDef("apply", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), "<init>"), Nil))), DefDef("unapply", List(TermParamClause(List(ValDef("x$1", Inferred(), None)))), Singleton(Literal(BooleanConstant(true))), Some(Literal(BooleanConstant(true)))), DefDef("toString", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), TypeDef("MirroredMonoType", TypeBoundsTree(Inferred(), Inferred())), DefDef("fromProduct", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(New(Inferred()), "<init>"), Nil)))))), Literal(UnitConstant())))
Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("<init>", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil), TypeSelect(Select(Ident("_root_"), "scala"), "Product"), TypeSelect(Select(Ident("_root_"), "scala"), "Serializable")), None, List(DefDef("hashCode", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_hashCode"), List(This(Some("Foo")))))), DefDef("equals", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(Apply(Select(This(Some("Foo")), "eq"), List(TypeApply(Select(Ident("x$0"), "$asInstanceOf$"), List(Inferred())))), "||"), List(Match(Ident("x$0"), List(CaseDef(Bind("x$0", Typed(Wildcard(), Inferred())), None, Apply(Select(Literal(BooleanConstant(true)), "&&"), List(Apply(Select(Ident("x$0"), "canEqual"), List(This(Some("Foo"))))))), CaseDef(Wildcard(), None, Literal(BooleanConstant(false))))))))), DefDef("toString", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_toString"), List(This(Some("Foo")))))), DefDef("canEqual", List(TermParamClause(List(ValDef("that", Inferred(), None)))), Inferred(), Some(TypeApply(Select(Ident("that"), "isInstanceOf"), List(Inferred())))), DefDef("productArity", Nil, Inferred(), Some(Literal(IntConstant(0)))), DefDef("productPrefix", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), DefDef("productElement", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), "<init>"), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), "<init>"), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("copy", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), "<init>"), Nil))))), ValDef("Foo", TypeIdent("Foo$"), Some(Apply(Select(New(TypeIdent("Foo$")), "<init>"), Nil))), ClassDef("Foo$", DefDef("<init>", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil), Inferred()), Some(ValDef("_", Singleton(Ident("Foo")), None)), List(DefDef("apply", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), "<init>"), Nil))), DefDef("unapply", List(TermParamClause(List(ValDef("x$1", Inferred(), None)))), Singleton(Literal(BooleanConstant(true))), Some(Literal(BooleanConstant(true)))), DefDef("toString", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), TypeDef("MirroredMonoType", TypeBoundsTree(Inferred(), Inferred())), DefDef("fromProduct", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Block(Nil, Apply(Select(New(Inferred()), "<init>"), Nil))))))), Literal(UnitConstant())))
TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit")

Inlined(None, Nil, Block(List(ClassDef("Foo1", DefDef("<init>", List(TermParamClause(List(ValDef("a", TypeIdent("Int"), None)))), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil)), None, List(ValDef("a", Inferred(), None)))), Literal(UnitConstant())))
Expand Down
86 changes: 86 additions & 0 deletions tests/run/i8073.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import scala.deriving.Mirror

trait A:
type B

object Test:
case class CC(a: A, b: a.B)

def test1(): Unit =
val generic = summon[Mirror.Of[CC]]
// No language syntax for type projection of a singleton type
// summon[generic.MirroredElemTypes =:= (A, CC#a.B)]

val aa: A { type B = Int } = new A { type B = Int }
val x: CC { val a: aa.type } = CC(aa, 1).asInstanceOf[CC { val a: aa.type }] // manual `tracked`

val dependent = summon[Mirror.Of[x.type]]
summon[dependent.MirroredElemTypes =:= (A, x.a.B)]

assert(CC(aa, 1) == generic.fromProduct((aa, 1)))
assert(CC(aa, 1) == dependent.fromProduct((aa, 1)))

x match
case CC(a, b) =>
val a1: A = a
// Dependent pattern matching is not currently supported
// val b1: a1.B = b
val b1 = b // Type is CC#a.B

end test1

case class CCPoly[T <: A](a: T, b: a.B)

def test2(): Unit =
val generic = summon[Mirror.Of[CCPoly[A]]]
// No language syntax for type projection of a singleton type
// summon[generic.MirroredElemTypes =:= (A, CCPoly[A]#a.B)]

val aa: A { type B = Int } = new A { type B = Int }
val x: CCPoly[aa.type] = CCPoly(aa, 1)

val dependent = summon[Mirror.Of[x.type]]
summon[dependent.MirroredElemTypes =:= (aa.type, x.a.B)]

assert(CCPoly[A](aa, 1) == generic.fromProduct((aa, 1)))
assert(CCPoly[A](aa, 1) == dependent.fromProduct((aa, 1)))

x match
case CCPoly(a, b) =>
val a1: A = a
// Dependent pattern matching is not currently supported
// val b1: a1.B = b
val b1 = b // Type is CC#a.B

end test2

enum Enum:
case EC(a: A, b: a.B)

def test3(): Unit =
val generic = summon[Mirror.Of[Enum.EC]]
// No language syntax for type projection of a singleton type
// summon[generic.MirroredElemTypes =:= (A, Enum.EC#a.B)]

val aa: A { type B = Int } = new A { type B = Int }
val x: Enum.EC { val a: aa.type } = Enum.EC(aa, 1).asInstanceOf[Enum.EC { val a: aa.type }] // manual `tracked`

val dependent = summon[Mirror.Of[x.type]]
summon[dependent.MirroredElemTypes =:= (A, x.a.B)]

assert(Enum.EC(aa, 1) == generic.fromProduct((aa, 1)))
assert(Enum.EC(aa, 1) == dependent.fromProduct((aa, 1)))

x match
case Enum.EC(a, b) =>
val a1: A = a
// Dependent pattern matching is not currently supported
// val b1: a1.B = b
val b1 = b // Type is Enum.EC#a.B

end test3

def main(args: Array[String]): Unit =
test1()
test2()
test3()
86 changes: 86 additions & 0 deletions tests/run/i8073b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import scala.deriving.Mirror

trait A:
type B

// Test local mirrors
@main def Test =
case class CC(a: A, b: a.B)

def test1(): Unit =
val generic = summon[Mirror.Of[CC]]
// No language syntax for type projection of a singleton type
// summon[generic.MirroredElemTypes =:= (A, CC#a.B)]

val aa: A { type B = Int } = new A { type B = Int }
val x: CC { val a: aa.type } = CC(aa, 1).asInstanceOf[CC { val a: aa.type }] // manual `tracked`

val dependent = summon[Mirror.Of[x.type]]
summon[dependent.MirroredElemTypes =:= (A, x.a.B)]

assert(CC(aa, 1) == generic.fromProduct((aa, 1)))
assert(CC(aa, 1) == dependent.fromProduct((aa, 1)))

x match
case CC(a, b) =>
val a1: A = a
// Dependent pattern matching is not currently supported
// val b1: a1.B = b
val b1 = b // Type is CC#a.B

end test1

case class CCPoly[T <: A](a: T, b: a.B)

def test2(): Unit =
val generic = summon[Mirror.Of[CCPoly[A]]]
// No language syntax for type projection of a singleton type
// summon[generic.MirroredElemTypes =:= (A, CCPoly[A]#a.B)]

val aa: A { type B = Int } = new A { type B = Int }
val x: CCPoly[aa.type] = CCPoly(aa, 1)

val dependent = summon[Mirror.Of[x.type]]
summon[dependent.MirroredElemTypes =:= (aa.type, x.a.B)]

assert(CCPoly[A](aa, 1) == generic.fromProduct((aa, 1)))
assert(CCPoly[A](aa, 1) == dependent.fromProduct((aa, 1)))

x match
case CCPoly(a, b) =>
val a1: A = a
// Dependent pattern matching is not currently supported
// val b1: a1.B = b
val b1 = b // Type is CC#a.B

end test2

enum Enum:
case EC(a: A, b: a.B)

def test3(): Unit =
val generic = summon[Mirror.Of[Enum.EC]]
// No language syntax for type projection of a singleton type
// summon[generic.MirroredElemTypes =:= (A, Enum.EC#a.B)]

val aa: A { type B = Int } = new A { type B = Int }
val x: Enum.EC { val a: aa.type } = Enum.EC(aa, 1).asInstanceOf[Enum.EC { val a: aa.type }] // manual `tracked`

val dependent = summon[Mirror.Of[x.type]]
summon[dependent.MirroredElemTypes =:= (A, x.a.B)]

assert(Enum.EC(aa, 1) == generic.fromProduct((aa, 1)))
assert(Enum.EC(aa, 1) == dependent.fromProduct((aa, 1)))

x match
case Enum.EC(a, b) =>
val a1: A = a
// Dependent pattern matching is not currently supported
// val b1: a1.B = b
val b1 = b // Type is Enum.EC#a.B

end test3

test1()
test2()
test3()

0 comments on commit e5327ac

Please sign in to comment.