Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add class parameters, flags, and privateWithin and annotations to newClass in reflect API #21880

Merged
54 changes: 54 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,32 @@ object Symbols extends SymUtils {
newClassSymbol(owner, name, flags, completer, privateWithin, coord, compUnitInfo)
}

/** Same as the other `newNormalizedClassSymbol` except that `parents` can be a function returning a list of arbitrary
* types which get normalized into type refs and parameter bindings and annotations can be assigned in the completer.
*/
def newNormalizedClassSymbol(
owner: Symbol,
name: TypeName,
flags: FlagSet,
parentTypes: Symbol => List[Type],
selfInfo: Type,
privateWithin: Symbol,
annotations: List[Tree],
coord: Coord,
compUnitInfo: CompilationUnitInfo | Null)(using Context): ClassSymbol = {
def completer = new LazyType {
def complete(denot: SymDenotation)(using Context): Unit = {
val cls = denot.asClass.classSymbol
val decls = newScope
val parents = parentTypes(cls).map(_.dealias)
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
denot.info = ClassInfo(owner.thisType, cls, parents, decls, selfInfo)
denot.annotations = annotations.map(Annotations.Annotation(_))
}
}
newClassSymbol(owner, name, flags, completer, privateWithin, coord, compUnitInfo)
}

def newRefinedClassSymbol(coord: Coord = NoCoord)(using Context): ClassSymbol =
newCompleteClassSymbol(ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil, newScope, coord = coord)

Expand Down Expand Up @@ -706,6 +732,34 @@ object Symbols extends SymUtils {
privateWithin, coord, compUnitInfo)
}

/** Same as `newNormalizedModuleSymbol` except that `parents` can be a function returning a list of arbitrary
* types which get normalized into type refs and parameter bindings.
*/
def newNormalizedModuleSymbol(
owner: Symbol,
name: TermName,
modFlags: FlagSet,
clsFlags: FlagSet,
parentTypes: ClassSymbol => List[Type],
decls: Scope,
privateWithin: Symbol,
coord: Coord,
compUnitInfo: CompilationUnitInfo | Null)(using Context): TermSymbol = {
def completer(module: Symbol) = new LazyType {
def complete(denot: SymDenotation)(using Context): Unit = {
val cls = denot.asClass.classSymbol
val decls = newScope
val parents = parentTypes(cls).map(_.dealias)
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
denot.info = ClassInfo(owner.thisType, cls, parents, decls, TermRef(owner.thisType, module))
}
}
newModuleSymbol(
owner, name, modFlags, clsFlags,
(module, modcls) => completer(module),
privateWithin, coord, compUnitInfo)
}

/** Create a package symbol with associated package class
* from its non-info fields and a lazy type for loading the package's members.
*/
Expand Down
46 changes: 25 additions & 21 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -854,30 +854,34 @@ object TreeChecker {
val phases = ctx.base.allPhases.toList
val treeChecker = new LocalChecker(previousPhases(phases))

def reportMalformedMacroTree(msg: String | Null, err: Throwable) =
val stack =
if !ctx.settings.Ydebug.value then "\nstacktrace available when compiling with `-Ydebug`"
else if err.getStackTrace == null then " no stacktrace"
else err.getStackTrace.nn.mkString(" ", " \n", "")
report.error(
em"""Malformed tree was found while expanding macro with -Xcheck-macros.
|The tree does not conform to the compiler's tree invariants.
|
|Macro was:
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(original)}
|
|The macro returned:
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(expansion)}
|
|Error:
|$msg
|$stack
|""",
original
)

try treeChecker.typed(expansion)(using checkingCtx)
catch
case err: java.lang.AssertionError =>
val stack =
if !ctx.settings.Ydebug.value then "\nstacktrace available when compiling with `-Ydebug`"
else if err.getStackTrace == null then " no stacktrace"
else err.getStackTrace.nn.mkString(" ", " \n", "")

report.error(
em"""Malformed tree was found while expanding macro with -Xcheck-macros.
|The tree does not conform to the compiler's tree invariants.
|
|Macro was:
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(original)}
|
|The macro returned:
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(expansion)}
|
|Error:
|${err.getMessage}
|$stack
|""",
original
)
reportMalformedMacroTree(err.getMessage(), err)
case err: UnhandledError =>
reportMalformedMacroTree(err.diagnostic.message, err)

private[TreeChecker] def previousPhases(phases: List[Phase])(using Context): List[Phase] = phases match {
case (phase: MegaPhase) :: phases1 =>
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import collection.mutable
import reporting.*
import Checking.{checkNoPrivateLeaks, checkNoWildcard}
import cc.CaptureSet
import transform.Splicer

trait TypeAssigner {
import tpd.*
Expand Down Expand Up @@ -301,7 +302,10 @@ trait TypeAssigner {
if fntpe.isResultDependent then safeSubstMethodParams(fntpe, args.tpes)
else fntpe.resultType // fast path optimization
else
errorType(em"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos)
val erroringPhase =
if Splicer.inMacroExpansion then i"${ctx.phase} (while expanding macro)"
else ctx.phase.prev.toString
errorType(em"wrong number of arguments at $erroringPhase for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos)
case err: ErrorType =>
err
case t =>
Expand Down
179 changes: 174 additions & 5 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import scala.quoted.runtime.impl.printers.*
import scala.reflect.TypeTest
import dotty.tools.dotc.core.NameKinds.ExceptionBinderName
import dotty.tools.dotc.transform.TreeChecker
import dotty.tools.dotc.core.Names
import dotty.tools.dotc.util.Spans.NoCoord

object QuotesImpl {

Expand Down Expand Up @@ -241,9 +243,35 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object ClassDef extends ClassDefModule:
def apply(cls: Symbol, parents: List[Tree], body: List[Statement]): ClassDef =
val untpdCtr = untpd.DefDef(nme.CONSTRUCTOR, Nil, tpd.TypeTree(dotc.core.Symbols.defn.UnitClass.typeRef), tpd.EmptyTree)
val paramsDefs: List[untpd.ParamClause] =
cls.primaryConstructor.paramSymss.map { paramSym =>
if paramSym.headOption.map(_.isType).getOrElse(false) then
paramSym.map(sym => TypeDef(sym))
else
paramSym.map(ValDef(_, None))
}
def throwError() =
throw new RuntimeException(
"Symbols necessary for creation of the ClassDef tree could not be found."
)
val paramsAccessDefs: List[untpd.ParamClause] =
cls.primaryConstructor.paramSymss.map { paramSym =>
if paramSym.headOption.map(_.isType).getOrElse(false) then
paramSym.map { symm =>
def isParamAccessor(memberSym: Symbol) = memberSym.flags.is(Flags.Param) && memberSym.name == symm.name
TypeDef(cls.typeMembers.find(isParamAccessor).getOrElse(throwError()))
}
else
paramSym.map { symm =>
def isParam(memberSym: Symbol) = memberSym.flags.is(Flags.ParamAccessor) && memberSym.name == symm.name
ValDef(cls.fieldMembers.find(isParam).getOrElse(throwError()), None)
}
}

val termSymbol: dotc.core.Symbols.TermSymbol = cls.primaryConstructor.asTerm
val untpdCtr = untpd.DefDef(nme.CONSTRUCTOR, paramsDefs, tpd.TypeTree(dotc.core.Symbols.defn.UnitClass.typeRef), tpd.EmptyTree)
val ctr = ctx.typeAssigner.assignType(untpdCtr, cls.primaryConstructor)
tpd.ClassDefWithParents(cls.asClass, ctr, parents, body)
tpd.ClassDefWithParents(cls.asClass, ctr, parents, paramsAccessDefs.flatten ++ body)

def copy(original: Tree)(name: String, constr: DefDef, parents: List[Tree], selfOpt: Option[ValDef], body: List[Statement]): ClassDef = {
val dotc.ast.Trees.TypeDef(_, originalImpl: tpd.Template) = original: @unchecked
Expand Down Expand Up @@ -2655,8 +2683,134 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
for sym <- decls(cls) do cls.enter(sym)
cls

def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol =
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
def newClass(
owner: Symbol,
name: String,
parents: Symbol => List[TypeRepr],
decls: Symbol => List[Symbol],
selfType: Option[TypeRepr],
clsFlags: Flags,
clsPrivateWithin: Symbol,
conParams: List[(String, TypeRepr)]
): Symbol =
val (conParamNames, conParamTypes) = conParams.unzip
newClass(
owner,
name,
parents,
decls,
selfType,
clsFlags,
clsPrivateWithin,
Nil,
conMethodType = res => MethodType(conParamNames)(_ => conParamTypes, _ => res),
conFlags = Flags.EmptyFlags,
conPrivateWithin = Symbol.noSymbol,
conParamFlags = List(for i <- conParamNames yield Flags.EmptyFlags),
conParamPrivateWithins = List(for i <- conParamNames yield Symbol.noSymbol)
)

def newClass(
owner: Symbol,
name: String,
parents: Symbol => List[TypeRepr],
decls: Symbol => List[Symbol],
selfType: Option[TypeRepr],
clsFlags: Flags,
clsPrivateWithin: Symbol,
clsAnnotations: List[Term],
conMethodType: TypeRepr => MethodOrPoly,
conFlags: Flags,
conPrivateWithin: Symbol,
conParamFlags: List[List[Flags]],
conParamPrivateWithins: List[List[Symbol]]
) =
assert(!clsPrivateWithin.exists || clsPrivateWithin.isType, "clsPrivateWithin must be a type symbol or `Symbol.noSymbol`")
assert(!conPrivateWithin.exists || conPrivateWithin.isType, "consPrivateWithin must be a type symbol or `Symbol.noSymbol`")
checkValidFlags(clsFlags.toTypeFlags, Flags.validClassFlags)
checkValidFlags(conFlags.toTermFlags, Flags.validClassConstructorFlags)
val cls = dotc.core.Symbols.newNormalizedClassSymbol(
owner,
name.toTypeName,
clsFlags,
parents,
selfType.getOrElse(Types.NoType),
clsPrivateWithin,
clsAnnotations,
NoCoord,
compUnitInfo = null
)
val methodType: MethodOrPoly = conMethodType(cls.typeRef)
def throwShapeException() = throw new Exception("Shapes of conMethodType and conParamFlags differ.")
def checkMethodOrPolyShape(checkedMethodType: TypeRepr, clauseIdx: Int): Unit =
checkedMethodType match
case PolyType(params, _, res) if clauseIdx == 0 =>
if (conParamFlags.length < clauseIdx) throwShapeException()
if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
checkMethodOrPolyShape(res, clauseIdx + 1)
case PolyType(_, _, _) => throw new Exception("Clause interleaving not supported for constructors")
case MethodType(params, _, res) =>
if (conParamFlags.length <= clauseIdx) throwShapeException()
if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
checkMethodOrPolyShape(res, clauseIdx + 1)
case other =>
xCheckMacroAssert(
other.typeSymbol == cls,
"Incorrect type returned from the innermost PolyOrMethod."
)
(other, methodType) match
case (AppliedType(tycon, args), pt: PolyType) =>
xCheckMacroAssert(
args.length == pt.typeParams.length &&
args.zip(pt.typeParams).forall {
case (arg, param) => arg == param.paramRef
},
"Constructor result type does not correspond to the declared type parameters"
)
case _ =>
xCheckMacroAssert(
!(other.isInstanceOf[AppliedType] || methodType.isInstanceOf[PolyType]),
"AppliedType has to be the innermost resultTypeExp result if and only if conMethodType returns a PolyType"
)
checkMethodOrPolyShape(methodType, clauseIdx = 0)

cls.enter(dotc.core.Symbols.newSymbol(cls, nme.CONSTRUCTOR, Flags.Synthetic | Flags.Method | conFlags, methodType, conPrivateWithin, dotty.tools.dotc.util.Spans.NoCoord))

case class ParamSymbolData(name: String, tpe: TypeRepr, isTypeParam: Boolean, clauseIdx: Int, elementIdx: Int)
def getParamSymbolsData(methodType: TypeRepr, clauseIdx: Int): List[ParamSymbolData] =
methodType match
case MethodType(paramInfosExp, resultTypeExp, res) =>
paramInfosExp.zip(resultTypeExp).zipWithIndex.map { case ((name, tpe), elementIdx) =>
ParamSymbolData(name, tpe, isTypeParam = false, clauseIdx, elementIdx)
} ++ getParamSymbolsData(res, clauseIdx + 1)
case pt @ PolyType(paramNames, paramBounds, res) =>
paramNames.zip(paramBounds).zipWithIndex.map {case ((name, tpe), elementIdx) =>
ParamSymbolData(name, tpe, isTypeParam = true, clauseIdx, elementIdx)
} ++ getParamSymbolsData(res, clauseIdx + 1)
case result =>
List()
// Maps PolyType indexes to type parameter symbol typerefs
val paramRefMap = collection.mutable.HashMap[Int, Symbol]()
val paramRefRemapper = new Types.TypeMap {
def apply(tp: Types.Type) = tp match {
case pRef: ParamRef if pRef.binder == methodType => paramRefMap(pRef.paramNum).typeRef
case _ => mapOver(tp)
}
}
for case ParamSymbolData(name, tpe, isTypeParam, clauseIdx, elementIdx) <- getParamSymbolsData(methodType, 0) do
if isTypeParam then
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTypeFlags, Flags.validClassTypeParamFlags)
val symbol = dotc.core.Symbols.newSymbol(cls, name.toTypeName, Flags.Param | Flags.Deferred | Flags.Private | Flags.PrivateLocal | Flags.Local | conParamFlags(clauseIdx)(elementIdx), tpe, conParamPrivateWithins(clauseIdx)(elementIdx))
paramRefMap.addOne(elementIdx, symbol)
cls.enter(symbol)
else
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTermFlags, Flags.validClassTermParamFlags)
val fixedType = paramRefRemapper(tpe)
cls.enter(dotc.core.Symbols.newSymbol(cls, name.toTermName, Flags.ParamAccessor | conParamFlags(clauseIdx)(elementIdx), fixedType, conParamPrivateWithins(clauseIdx)(elementIdx)))
for sym <- decls(cls) do cls.enter(sym)
cls

def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: Symbol => List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol =
assert(!privateWithin.exists || privateWithin.isType, "privateWithin must be a type symbol or `Symbol.noSymbol`")
val mod = dotc.core.Symbols.newNormalizedModuleSymbol(
owner,
Expand All @@ -2665,7 +2819,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
clsFlags | dotc.core.Flags.ModuleClassCreationFlags,
parents,
dotc.core.Scopes.newScope,
privateWithin)
privateWithin,
NoCoord,
compUnitInfo = null
)
val cls = mod.moduleClass.asClass
cls.enter(dotc.core.Symbols.newConstructor(cls, dotc.core.Flags.Synthetic, Nil, Nil))
for sym <- decls(cls) do cls.enter(sym)
Expand Down Expand Up @@ -3063,6 +3220,18 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
// Keep: aligned with Quotes's `newTypeAlias` doc
private[QuotesImpl] def validTypeAliasFlags: Flags = Private | Protected | Override | Final | Infix | Local

// Keep: aligned with Quotes's `newClass`
private[QuotesImpl] def validClassFlags: Flags = Private | Protected | PrivateLocal | Local | Final | Trait | Abstract | Open

// Keep: aligned with Quote's 'newClass'
private[QuotesImpl] def validClassConstructorFlags: Flags = Synthetic | Method | Private | Protected | PrivateLocal | Local

// Keep: aligned with Quotes's `newClass`
private[QuotesImpl] def validClassTypeParamFlags: Flags = Param | Deferred | Private | PrivateLocal | Local

// Keep: aligned with Quotes's `newClass`
private[QuotesImpl] def validClassTermParamFlags: Flags = ParamAccessor | Private | Protected | PrivateLocal | Local

end Flags

given FlagsMethods: FlagsMethods with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1379,13 +1379,13 @@ object SourceCode {
printTypeTree(bounds.low)
else
bounds.low match {
case Inferred() =>
case Inferred() if bounds.low.tpe.typeSymbol == TypeRepr.of[Nothing].typeSymbol =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a cosmetic change and it only hides when we infer Nothing (or Any for the second change). What was the motivation you had behind it?

Copy link
Contributor Author

@jchyb jchyb Mar 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had an issue that when I built a tree like class SomeClass[T <: Int] and would call .show on it, it would only render class SomeClass[T]. It turned out that trees created with tpd.TypeTree are treated as Inferred in the Quotes reflect api, and if I wanted to actually render something, I needed this change here (I figure it might be better to have a printout a bit too big, than not big enough). I recall trying to test if it was possible to adjust the reflect Inferred tree definition instead, but I do not remember what the problem there was. Test case where this matters is newClassTypeParams.

case low =>
this += " >: "
printTypeTree(low)
}
bounds.hi match {
case Inferred() => this
case Inferred() if bounds.hi.tpe.typeSymbol == TypeRepr.of[Any].typeSymbol => this
case hi =>
this += " <: "
printTypeTree(hi)
Expand Down
Loading
Loading