Skip to content

Commit

Permalink
Merge pull request #737 from ahjohannessen/wip-varargs-syntax
Browse files Browse the repository at this point in the history
compiler: account for Scala 2 and 3 differences in generated varargs code
  • Loading branch information
mergify[bot] authored Feb 27, 2024
2 parents ace8440 + c38c2e9 commit 5fbb1ec
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 17 deletions.
10 changes: 8 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ project/plugins/project/
.scala_dependencies
.idea

.bloop/

compiler/version.properties

.vscode/

# Metals
.metals/
.bloop/
project/metals.sbt
metals.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public static Optional<File> compile(
List<String> constructorAnnotations,
Codec codec,
boolean inclusiveDot) {
String scalaVersion = play.twirl.compiler.BuildInfo$.MODULE$.scalaVersion();
Seq<String> scalaAdditionalImports = toScalaSeq(additionalImports);
Seq<String> scalaConstructorAnnotations = toScalaSeq(constructorAnnotations);

Expand All @@ -65,6 +66,7 @@ public static Optional<File> compile(
sourceDirectory,
generatedDirectory,
formatterType,
scala.Option.apply(scalaVersion),
scalaAdditionalImports,
scalaConstructorAnnotations,
codec,
Expand Down
167 changes: 159 additions & 8 deletions compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,17 @@ case class GeneratedSourceVirtual(path: String) extends AbstractGeneratedSource

object TwirlCompiler {

// For constants that depend on Scala 2 or 3 mode.
private[compiler] class ScalaCompat(emitScala3Sources: Boolean) {
val varargSplicesSyntax: String =
if (emitScala3Sources) "*" else ": _*"
}

private[compiler] object ScalaCompat {
def apply(scalaVersion: Option[String]): ScalaCompat =
new ScalaCompat(scalaVersion.exists(_.startsWith("3.")))
}

def defaultImports(scalaVersion: String) = {
val implicits = if (scalaVersion.startsWith("3.")) {
Seq(
Expand Down Expand Up @@ -199,7 +210,30 @@ object TwirlCompiler {
constructorAnnotations: collection.Seq[String] = Nil,
codec: Codec = TwirlIO.defaultCodec,
inclusiveDot: Boolean = false
) = {
): Option[File] =
compile(
source,
sourceDirectory,
generatedDirectory,
formatterType,
None,
additionalImports,
constructorAnnotations,
codec,
inclusiveDot
)

def compile(
source: File,
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
scalaVersion: Option[String],
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String],
codec: Codec,
inclusiveDot: Boolean
): Option[File] = {
val resultType = formatterType + ".Appendable"
val (templateName, generatedSource) =
generatedFile(source, codec, sourceDirectory, generatedDirectory, inclusiveDot)
Expand All @@ -211,6 +245,7 @@ object TwirlCompiler {
relativePath(source),
resultType,
formatterType,
scalaVersion,
additionalImports,
constructorAnnotations,
inclusiveDot
Expand All @@ -232,7 +267,33 @@ object TwirlCompiler {
constructorAnnotations: collection.Seq[String] = Nil,
codec: Codec = TwirlIO.defaultCodec,
inclusiveDot: Boolean = false
) = {
): GeneratedSourceVirtual =
compileVirtual(
content,
source,
sourceDirectory,
resultType,
formatterType,
None,
additionalImports,
constructorAnnotations,
codec,
inclusiveDot
)

def compileVirtual(
content: String,
source: File,
sourceDirectory: File,
resultType: String,
formatterType: String,
scalaVersion: Option[String],
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String],
codec: Codec,
inclusiveDot: Boolean
): GeneratedSourceVirtual = {

val (templateName, generatedSource) = generatedFileVirtual(source, sourceDirectory, inclusiveDot)
val generated = parseAndGenerateCode(
templateName,
Expand All @@ -241,6 +302,7 @@ object TwirlCompiler {
relativePath(source),
resultType,
formatterType,
scalaVersion,
additionalImports,
constructorAnnotations,
inclusiveDot
Expand All @@ -262,7 +324,31 @@ object TwirlCompiler {
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String],
inclusiveDot: Boolean
) = {
): String = parseAndGenerateCode(
templateName,
content,
codec,
relativePath,
resultType,
formatterType,
None,
additionalImports,
constructorAnnotations,
inclusiveDot
)

private def parseAndGenerateCode(
templateName: Array[String],
content: Array[Byte],
codec: Codec,
relativePath: String,
resultType: String,
formatterType: String,
scalaVersion: Option[String],
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String],
inclusiveDot: Boolean
): String = {
val templateParser = new TwirlParser(inclusiveDot)
templateParser.parse(new String(content, codec.charSet)) match {
case templateParser.Success(parsed: Template, rest) if rest.atEnd() => {
Expand All @@ -274,6 +360,7 @@ object TwirlCompiler {
parsed,
resultType,
formatterType,
ScalaCompat(scalaVersion),
additionalImports,
constructorAnnotations
)
Expand Down Expand Up @@ -448,8 +535,29 @@ object TwirlCompiler {
formatterType: String,
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String]
): collection.Seq[Any] = generateCode(
packageName,
name,
root,
resultType,
formatterType,
ScalaCompat(None),
additionalImports,
constructorAnnotations
)

private def generateCode(
packageName: String,
name: String,
root: Template,
resultType: String,
formatterType: String,
scalaCompat: ScalaCompat,
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String]
): collection.Seq[Any] = {
val (renderCall, f, templateType) = TemplateAsFunctionCompiler.getFunctionMapping(root.params.str, resultType)
val (renderCall, f, templateType) =
TemplateAsFunctionCompiler.getFunctionMapping(root.params.str, resultType, scalaCompat)

// Get the imports that we need to include, filtering out empty imports
val imports: Seq[Any] = Seq(additionalImports.map(i => Seq("import ", i, "\n")), formatImports(root.topImports))
Expand Down Expand Up @@ -509,9 +617,42 @@ package """ :+ packageName :+ """
formatterType: String,
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String]
): String = generateFinalTemplate(
relativePath,
contents,
packageName,
name,
root,
resultType,
formatterType,
ScalaCompat(None),
additionalImports,
constructorAnnotations
)

private def generateFinalTemplate(
relativePath: String,
contents: Array[Byte],
packageName: String,
name: String,
root: Template,
resultType: String,
formatterType: String,
scalaCompat: ScalaCompat,
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String]
): String = {
val generated =
generateCode(packageName, name, root, resultType, formatterType, additionalImports, constructorAnnotations)
generateCode(
packageName,
name,
root,
resultType,
formatterType,
scalaCompat,
additionalImports,
constructorAnnotations
)

Source.finalSource(relativePath, contents, generated, Hash(contents, additionalImports))
}
Expand All @@ -531,7 +672,17 @@ package """ :+ packageName :+ """
}
}

def getFunctionMapping(signature: String, returnType: String): (String, String, String) = {
def getFunctionMapping(
signature: String,
returnType: String,
): (String, String, String) =
getFunctionMapping(signature, returnType, ScalaCompat(None))

private[compiler] def getFunctionMapping(
signature: String,
returnType: String,
sc: ScalaCompat
): (String, String, String) = {

val params: List[List[Term.Param]] =
try {
Expand Down Expand Up @@ -573,7 +724,7 @@ package """ :+ packageName :+ """
.map { p =>
p.name.toString + Option(p.decltpe.get.toString)
.filter(_.endsWith("*"))
.map(_ => ".toIndexedSeq:_*")
.map(_ => s".toIndexedSeq${sc.varargSplicesSyntax}")
.getOrElse("")
}
.mkString(",") + ")"
Expand Down Expand Up @@ -601,7 +752,7 @@ package """ :+ packageName :+ """
.map { p =>
p.name.toString + Option(p.decltpe.get.toString)
.filter(_.endsWith("*"))
.map(_ => ".toIndexedSeq:_*")
.map(_ => s".toIndexedSeq${sc.varargSplicesSyntax}")
.getOrElse("")
}
.mkString(",") + ")"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ package play.twirl.compiler
package test

import java.io._
import play.twirl.parser.TwirlIO

object Helper {
case class CompilationError(message: String, line: Int, column: Int) extends RuntimeException(message)

class CompilerHelper(sourceDir: File, generatedDir: File, generatedClasses: File) {
class CompilerHelper(sourceDir: File, val generatedDir: File, generatedClasses: File) {
import java.net._
import scala.collection.mutable
import scala.reflect.internal.util.Position
Expand Down Expand Up @@ -99,7 +100,11 @@ object Helper {
sourceDir,
generatedDir,
"play.twirl.api.HtmlFormat",
additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports
Option(scalaVersion),
additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports,
constructorAnnotations = Nil,
codec = TwirlIO.defaultCodec,
inclusiveDot = false
)

val mapper = GeneratedSource(generated)
Expand Down
20 changes: 16 additions & 4 deletions compiler/src/test/scala-3/play/twirl/compiler/test/Helper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import dotty.tools.io.PlainDirectory
import dotty.tools.io.Directory
import dotty.tools.io.ClassPath
import scala.jdk.CollectionConverters._
import play.twirl.parser.TwirlIO

object Helper {
case class CompilationError(message: String, line: Int, column: Int) extends RuntimeException(message)

class CompilerHelper(sourceDir: File, generatedDir: File, generatedClasses: File) {
class CompilerHelper(sourceDir: File, val generatedDir: File, generatedClasses: File) {
import java.net._
import scala.collection.mutable

Expand Down Expand Up @@ -60,14 +61,22 @@ object Helper {
): CompiledTemplate[T] = {
val scalaVersion = play.twirl.compiler.BuildInfo.scalaVersion
val templateFile = new File(sourceDir, templateName)
val Some(generated) = twirlCompiler.compile(
val generatedOpt: Option[File] = twirlCompiler.compile(
templateFile,
sourceDir,
generatedDir,
"play.twirl.api.HtmlFormat",
additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports
Option(scalaVersion),
additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports,
constructorAnnotations = Nil,
codec = TwirlIO.defaultCodec,
inclusiveDot = false
)

val generated = generatedOpt.getOrElse {
throw new FileNotFoundException(s"Could not find generated file for $templateName")
}

val mapper = GeneratedSource(generated)

val compilerArgs = Array(
Expand All @@ -94,7 +103,10 @@ object Helper {

class TestDriver(outDir: Path, compilerArgs: Array[String], path: Path) extends Driver {
def compile(): Reporter = {
val Some((toCompile, rootCtx)) = setup(compilerArgs :+ path.toAbsolutePath.toString, initCtx.fresh)
val setupOpt = setup(compilerArgs :+ path.toAbsolutePath.toString, initCtx.fresh)
val (toCompile, rootCtx) = setupOpt.getOrElse {
throw new Exception("Failed to initialize compiler")
}

val silentReporter = new ConsoleReporter.AbstractConsoleReporter {
def printMessage(msg: String): Unit = {
Expand Down
Loading

0 comments on commit 5fbb1ec

Please sign in to comment.