Skip to content

Commit 32f484a

Browse files
committed
Add additional checks for type parameters and update tests
1 parent 27d4fa2 commit 32f484a

File tree

8 files changed

+97
-24
lines changed

8 files changed

+97
-24
lines changed

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

+35-10
Original file line numberDiff line numberDiff line change
@@ -2690,7 +2690,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
26902690
assert(!clsPrivateWithin.exists || clsPrivateWithin.isType, "clsPrivateWithin must be a type symbol or `Symbol.noSymbol`")
26912691
assert(!conPrivateWithin.exists || conPrivateWithin.isType, "consPrivateWithin must be a type symbol or `Symbol.noSymbol`")
26922692
checkValidFlags(clsFlags.toTypeFlags, Flags.validClassFlags)
2693-
checkValidFlags(conFlags, Flags.validClassConstructorFlags)
2693+
checkValidFlags(conFlags.toTermFlags, Flags.validClassConstructorFlags)
26942694
val cls = dotc.core.Symbols.newNormalizedClassSymbolUsingClassSymbolinParents(
26952695
owner,
26962696
name.toTypeName,
@@ -2713,33 +2713,58 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
27132713
if (conParamFlags.length <= clauseIdx) throwShapeException()
27142714
if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
27152715
checkMethodOrPolyShape(res, clauseIdx + 1)
2716-
case _ =>
2716+
case other =>
2717+
xCheckMacroAssert(
2718+
other.typeSymbol == cls,
2719+
"Incorrect type returned from the innermost PolyOrMethod."
2720+
)
2721+
(other, methodType) match
2722+
case (AppliedType(tycon, args), pt: PolyType) =>
2723+
xCheckMacroAssert(
2724+
args.length == pt.typeParams.length &&
2725+
args.zip(pt.typeParams).forall {
2726+
case (arg, param) => arg == param.paramRef
2727+
},
2728+
"Constructor result type does not correspond to the declared type parameters"
2729+
)
2730+
case _ =>
2731+
xCheckMacroAssert(
2732+
!(other.isInstanceOf[AppliedType] || methodType.isInstanceOf[PolyType]),
2733+
"AppliedType has to be the innermost resultTypeExp result if and only if conMethodType returns a PolyType"
2734+
)
27172735
checkMethodOrPolyShape(methodType, clauseIdx = 0)
2736+
27182737
cls.enter(dotc.core.Symbols.newSymbol(cls, nme.CONSTRUCTOR, Flags.Synthetic | Flags.Method | conFlags, methodType, conPrivateWithin, dotty.tools.dotc.util.Spans.NoCoord))
2719-
def getParamAccessors(methodType: TypeRepr, clauseIdx: Int): List[((String, TypeRepr, Boolean, Int), Int)] =
2738+
2739+
case class ParamSymbolData(name: String, tpe: TypeRepr, isTypeParam: Boolean, clauseIdx: Int, elementIdx: Int)
2740+
def getParamSymbolsData(methodType: TypeRepr, clauseIdx: Int): List[ParamSymbolData] =
27202741
methodType match
27212742
case MethodType(paramInfosExp, resultTypeExp, res) =>
2722-
paramInfosExp.zip(resultTypeExp).map(_ :* false :* clauseIdx).zipWithIndex ++ getParamAccessors(res, clauseIdx + 1)
2743+
paramInfosExp.zip(resultTypeExp).zipWithIndex.map { case ((name, tpe), elementIdx) =>
2744+
ParamSymbolData(name, tpe, isTypeParam = false, clauseIdx, elementIdx)
2745+
} ++ getParamSymbolsData(res, clauseIdx + 1)
27232746
case pt @ PolyType(paramNames, paramBounds, res) =>
2724-
paramNames.zip(paramBounds).map(_ :* true :* clauseIdx).zipWithIndex ++ getParamAccessors(res, clauseIdx + 1)
2747+
paramNames.zip(paramBounds).zipWithIndex.map {case ((name, tpe), elementIdx) =>
2748+
ParamSymbolData(name, tpe, isTypeParam = true, clauseIdx, elementIdx)
2749+
} ++ getParamSymbolsData(res, clauseIdx + 1)
27252750
case result =>
27262751
List()
2727-
// Maps PolyType indexes to type parameter symbols
2752+
// Maps PolyType indexes to type parameter symbol typerefs
27282753
val paramRefMap = collection.mutable.HashMap[Int, Symbol]()
27292754
val paramRefRemapper = new Types.TypeMap {
27302755
def apply(tp: Types.Type) = tp match {
27312756
case pRef: ParamRef if pRef.binder == methodType => paramRefMap(pRef.paramNum).typeRef
27322757
case _ => mapOver(tp)
27332758
}
27342759
}
2735-
for ((name, tpe, isType, clauseIdx), elementIdx) <- getParamAccessors(methodType, 0) do
2736-
if isType then
2737-
checkValidFlags(conParamFlags(clauseIdx)(elementIdx), Flags.validClassTypeParamFlags)
2760+
for case ParamSymbolData(name, tpe, isTypeParam, clauseIdx, elementIdx) <- getParamSymbolsData(methodType, 0) do
2761+
if isTypeParam then
2762+
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTypeFlags, Flags.validClassTypeParamFlags)
27382763
val symbol = dotc.core.Symbols.newSymbol(cls, name.toTypeName, Flags.Param | Flags.Deferred | Flags.Private | Flags.PrivateLocal | Flags.Local | conParamFlags(clauseIdx)(elementIdx), tpe, conParamPrivateWithins(clauseIdx)(elementIdx))
27392764
paramRefMap.addOne(elementIdx, symbol)
27402765
cls.enter(symbol)
27412766
else
2742-
checkValidFlags(conParamFlags(clauseIdx)(elementIdx), Flags.validClassTermParamFlags)
2767+
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTermFlags, Flags.validClassTermParamFlags)
27432768
val fixedType = paramRefRemapper(tpe)
27442769
cls.enter(dotc.core.Symbols.newSymbol(cls, name.toTermName, Flags.ParamAccessor | conParamFlags(clauseIdx)(elementIdx), fixedType, conParamPrivateWithins(clauseIdx)(elementIdx)))
27452770
for sym <- decls(cls) do cls.enter(sym)

library/src/scala/quoted/Quotes.scala

+38-9
Original file line numberDiff line numberDiff line change
@@ -3797,6 +3797,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
37973797
def classSymbol(fullName: String): Symbol
37983798

37993799
/** Generates a new class symbol for a class with a public parameterless constructor.
3800+
* For more settings, look to the other newClass methods.
38003801
*
38013802
* Example usage:
38023803
* ```
@@ -3841,13 +3842,41 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
38413842

38423843
/** Generates a new class symbol for a class with a public single term clause constructor.
38433844
*
3844-
* @param owner The owner of the class
3845-
* @param name The name of the class
3846-
* @param parents Function returning the parent classes of the class. The first parent must not be a trait.
3847-
* Takes the constructed class symbol as an argument. Calling `cls.typeRef.asType` as part of this function will lead to cyclic reference errors.
3848-
* @param clsFlags extra flags with which the class symbol should be constructed.
3849-
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol.
3850-
* @param conParams constructor parameter pairs of names and types.
3845+
* Example usage:
3846+
* ```
3847+
* val name = nameExpr.valueOrAbort
3848+
* def decls(cls: Symbol): List[Symbol] =
3849+
* List(Symbol.newMethod(cls, "foo", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit])))
3850+
* val parents = List(TypeTree.of[Object])
3851+
* val cls = Symbol.newClass(
3852+
* Symbol.spliceOwner,
3853+
* name,
3854+
* parents = _ => parents.map(_.tpe),
3855+
* decls,
3856+
* selfType = None,
3857+
* clsFlags = Flags.EmptyFlags,
3858+
* Symbol.noSymbol,
3859+
* List(("idx", TypeRepr.of[Int]), ("str", TypeRepr.of[String]))
3860+
* )
3861+
*
3862+
* val fooSym = cls.declaredMethod("foo").head
3863+
* val idxSym = cls.fieldMember("idx")
3864+
* val strSym = cls.fieldMember("str")
3865+
* val fooDef = DefDef(fooSym, argss =>
3866+
* Some('{println(s"Foo method call with (${${Ref(idxSym).asExpr}}, ${${Ref(strSym).asExpr}})")}.asTerm)
3867+
* )
3868+
* val clsDef = ClassDef(cls, parents, body = List(fooDef))
3869+
* val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), List(idxExpr.asTerm, strExpr.asTerm))
3870+
*
3871+
* Block(List(clsDef), Apply(Select(newCls, cls.methodMember("foo")(0)), Nil)).asExprOf[Unit]
3872+
* ```
3873+
* @param owner The owner of the class
3874+
* @param name The name of the class
3875+
* @param parents Function returning the parent classes of the class. The first parent must not be a trait.
3876+
* Takes the constructed class symbol as an argument. Calling `cls.typeRef.asType` as part of this function will lead to cyclic reference errors.
3877+
* @param clsFlags extra flags with which the class symbol should be constructed.
3878+
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol.
3879+
* @param conParams constructor parameter pairs of names and types.
38513880
*
38523881
* Parameters assigned by the constructor can be obtained via `classSymbol.memberField`.
38533882
* This symbol starts without an accompanying definition.
@@ -3878,7 +3907,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
38783907
* val conMethodType =
38793908
* (classType: TypeRepr) => PolyType(List("T"))(_ => List(TypeBounds.empty), polyType =>
38803909
* MethodType(List("param"))((_: MethodType) => List(polyType.param(0)), (_: MethodType) =>
3881-
* classType
3910+
* AppliedType(classType, List(polyType.param(0)))
38823911
* )
38833912
* )
38843913
* val cls = Symbol.newClass(
@@ -3940,7 +3969,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
39403969
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol
39413970
* @param clsAnnotations annotations of the class
39423971
* @param conMethodType Function returning MethodOrPoly type representing the type of the constructor.
3943-
* Takes the result type as parameter which must be returned from the innermost MethodOrPoly.
3972+
* Takes the result type as parameter which must be returned from the innermost MethodOrPoly and have type parameters applied if those are used.
39443973
* PolyType may only represent the first clause of the constructor.
39453974
* @param conFlags extra flags with which the constructor symbol should be constructed. Can be `Synthetic` | `Method` | `Private` | `Protected` | `PrivateLocal` | `Local`
39463975
* @param conPrivateWithin the symbol within which the constructor for this new class symbol should be private. May be noSymbol.

tests/run-macros/newClassParams/Macro_1.scala

+18-3
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,26 @@ private def makeClassAndCallExpr(nameExpr: Expr[String], idxExpr: Expr[Int], str
88

99
val name = nameExpr.valueOrAbort
1010

11-
def decls(cls: Symbol): List[Symbol] = List(Symbol.newMethod(cls, "foo", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit])))
11+
def decls(cls: Symbol): List[Symbol] =
12+
List(Symbol.newMethod(cls, "foo", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit])))
1213
val parents = List(TypeTree.of[Object])
13-
val cls = Symbol.newClass(Symbol.spliceOwner, name, parents = _ => parents.map(_.tpe), decls, selfType = None, Flags.EmptyFlags, Symbol.noSymbol, List(("idx", TypeRepr.of[Int]), ("str", TypeRepr.of[String])))
14+
val cls = Symbol.newClass(
15+
Symbol.spliceOwner,
16+
name,
17+
parents = _ => parents.map(_.tpe),
18+
decls,
19+
selfType = None,
20+
clsFlags = Flags.EmptyFlags,
21+
Symbol.noSymbol,
22+
List(("idx", TypeRepr.of[Int]), ("str", TypeRepr.of[String]))
23+
)
1424

15-
val fooDef = DefDef(cls.methodMember("foo")(0), argss => Some('{println(s"Foo method call with (${${Ref(cls.fieldMember("idx")).asExpr}}, ${${Ref(cls.fieldMember("str")).asExpr}})")}.asTerm))
25+
val fooSym = cls.declaredMethod("foo").head
26+
val idxSym = cls.fieldMember("idx")
27+
val strSym = cls.fieldMember("str")
28+
val fooDef = DefDef(fooSym, argss =>
29+
Some('{println(s"Foo method call with (${${Ref(idxSym).asExpr}}, ${${Ref(strSym).asExpr}})")}.asTerm)
30+
)
1631
val clsDef = ClassDef(cls, parents, body = List(fooDef))
1732
val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), List(idxExpr.asTerm, strExpr.asTerm))
1833

tests/run-macros/newClassTraitAndAbstract/Macro_1.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ private def makeClassExpr(using Quotes)(
3333
val conMethodType =
3434
(classType: TypeRepr) => PolyType(List("A", "B"))(
3535
_ => List(TypeBounds.empty, TypeBounds.upper(TypeRepr.of[Int])),
36-
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) => classType)
36+
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) =>
37+
AppliedType(classType, List(polyType.param(0), polyType.param(1)))
38+
)
3739
)
3840

3941
val traitSymbol = Symbol.newClass(

tests/run-macros/newClassTypeParams/Macro_1.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ private def makeClassExpr(nameExpr: Expr[String])(using Quotes): Expr[Any] = {
1111
val conMethodType =
1212
(classType: TypeRepr) => PolyType(List("A", "B"))(
1313
_ => List(TypeBounds.empty, TypeBounds.upper(TypeRepr.of[Int])),
14-
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) => classType)
14+
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) =>
15+
AppliedType(classType, List(polyType.param(0), polyType.param(1)))
16+
)
1517
)
1618

1719
val cls = Symbol.newClass(

0 commit comments

Comments
 (0)