diff --git a/.gitignore b/.gitignore index 15a05753a0..57fa41912d 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,12 @@ project/plugins/lib_managed/ project/plugins/src_managed/ /.idea/ /.idea_modules/ +.project +.classpath +.cache-main +.cache-tests +.tmpBin +bin *.iml sonatype.sbt tutorial/data/cofollows.tsv diff --git a/.travis.yml b/.travis.yml index 92e1fbb6ca..b98db2022d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,11 +28,11 @@ matrix: include: #BASE TESTS - scala: 2.11.11 - env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple" + env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple scalding-quotation" script: "scripts/run_test.sh" - scala: 2.12.3 - env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple" + env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple scalding-quotation" script: "scripts/run_test.sh" - scala: 2.11.11 diff --git a/build.sbt b/build.sbt index 0a8844be5a..451c034300 100644 --- a/build.sbt +++ b/build.sbt @@ -216,6 +216,7 @@ lazy val scalding = Project( .aggregate( scaldingArgs, scaldingDate, + scaldingQuotation, scaldingCore, scaldingCommons, scaldingAvro, @@ -242,6 +243,7 @@ lazy val scaldingAssembly = Project( .aggregate( scaldingArgs, scaldingDate, + scaldingQuotation, scaldingCore, scaldingCommons, scaldingAvro, @@ -312,6 +314,13 @@ lazy val scaldingBenchmarks = module("benchmarks") parallelExecution in Test := false ).dependsOn(scaldingCore) +lazy val scaldingQuotation = module("quotation").settings( + libraryDependencies ++= Seq( + "org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided", + "org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided" + ) +) + lazy val scaldingCore = module("core").settings( libraryDependencies ++= Seq( "cascading" % "cascading-core" % cascadingVersion, @@ -333,7 +342,7 @@ lazy val scaldingCore = module("core").settings( "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion % "provided"), addCompilerPlugin("org.scalamacros" % "paradise" % paradiseVersion cross CrossVersion.full) -).dependsOn(scaldingArgs, scaldingDate, scaldingSerialization, maple) +).dependsOn(scaldingArgs, scaldingDate, scaldingSerialization, maple, scaldingQuotation) lazy val scaldingCommons = module("commons").settings( libraryDependencies ++= Seq( diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala new file mode 100644 index 0000000000..b998fec17c --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala @@ -0,0 +1,41 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +trait Liftables { + val c: Context + import c.universe.{ TypeName => _, _ } + + protected implicit val sourceLiftable: Liftable[Source] = Liftable { + case Source(path, line) => q"com.twitter.scalding.quotation.Source($path, $line)" + } + + protected implicit val projectionsLiftable: Liftable[Projections] = Liftable { + case p => q"com.twitter.scalding.quotation.Projections(${p.set})" + } + + protected implicit val typeNameLiftable: Liftable[TypeName] = Liftable { + case TypeName(name) => q"com.twitter.scalding.quotation.TypeName($name)" + } + + protected implicit val accessorLiftable: Liftable[Accessor] = Liftable { + case Accessor(name) => q"com.twitter.scalding.quotation.Accessor($name)" + } + + protected implicit val quotedLiftable: Liftable[Quoted] = Liftable { + case Quoted(source, call, fa) => q"com.twitter.scalding.quotation.Quoted($source, $call, $fa)" + } + + protected implicit val projectionLiftable: Liftable[Projection] = Liftable { + case p: Property => q"$p" + case p: TypeReference => q"$p" + } + + protected implicit val propertyLiftable: Liftable[Property] = Liftable { + case Property(path, accessor, tpe) => q"com.twitter.scalding.quotation.Property($path, $accessor, $tpe)" + } + + protected implicit val typeReferenceLiftable: Liftable[TypeReference] = Liftable { + case TypeReference(name) => q"com.twitter.scalding.quotation.TypeReference($name)" + } +} diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala new file mode 100644 index 0000000000..9adc475e56 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala @@ -0,0 +1,146 @@ +package com.twitter.scalding.quotation + +import scala.annotation.tailrec + +case class Accessor(asString: String) extends AnyVal +case class TypeName(asString: String) extends AnyVal + +sealed trait Projection { + def andThen(accessor: Accessor, typeName: TypeName): Projection = + Property(this, accessor, typeName) +} + +/** + * A reference of a type. If not nested within a `Property`, it means that all fields are used. + */ +final case class TypeReference(typeName: TypeName) extends Projection { + override def toString = typeName.asString.split('.').last +} + +/** + * A projection property (e.g. `Person.name`) + */ +final case class Property(path: Projection, accessor: Accessor, typeName: TypeName) extends Projection { + override def toString = s"$path.${accessor.asString}" +} + +/** + * Utility class to deal with a collection of projections. + */ +final class Projections private (val set: Set[Projection]) extends Serializable { + + /** + * Returns the projections that are based on `tpe` and limits projections + * to only properties that extend from `superClass`. + */ + def of(typeName: TypeName, superClass: Class[_]): Projections = { + + def byType(p: Projection) = { + @tailrec def loop(p: Projection): Boolean = + p match { + case TypeReference(`typeName`) => true + case TypeReference(_) => false + case Property(p, _, _) => loop(p) + } + loop(p) + } + + def bySuperClass(p: Projection): Option[Projection] = { + + def isSubclass(c: TypeName) = + try + superClass.isAssignableFrom(Class.forName(c.asString)) + catch { + case _: ClassNotFoundException => + false + } + + def loop(p: Projection): Either[Projection, Option[Projection]] = + p match { + case TypeReference(tpe) => + Either.cond(!isSubclass(tpe), None, p) + case p @ Property(path, name, tpe) => + loop(path) match { + case Left(_) => + Either.cond(!isSubclass(tpe), Some(p), p) + case Right(path) => + Right(path) + } + } + + loop(p) match { + case Left(path) => Some(path) + case Right(opt) => opt + } + } + + Projections(set.filter(byType).flatMap(bySuperClass)) + } + + /** + * Given a set of base projections, returns the projections based on them. + * + * For instance, given a quoted function + * `val contact = Quoted.function { (c: Contact) => c.contact }` + * and a call + * `(p: Person) => contact(p.name)` + * returns the projection + * `Person.name.contact` + */ + def basedOn(base: Set[Projection]): Projections = { + def loop(base: Projection, p: Projection): Option[Projection] = + p match { + case TypeReference(tpe) => + base match { + case TypeReference(`tpe`) => Some(p) + case Property(_, _, `tpe`) => Some(base) + case other => None + } + case Property(path, name, tpe) => + loop(base, path).map(Property(_, name, tpe)) + } + Projections { + set.flatMap { p => + base.flatMap(loop(_, p)) + } + } + } + + def ++(p: Projections) = + Projections(set ++ p.set) + + override def toString = + s"Projections(${set.mkString(", ")})" + + override def equals(other: Any) = + other match { + case other: Projections => set == other.set + case other => false + } + + override def hashCode = + 31 * set.hashCode +} + +object Projections { + val empty = apply(Set.empty) + + /** + * Creates a normalized projections collection. For instance, + * given two projections `Person.contact` and `Person.contact.phone`, + * creates a collection with only `Person.contact`. + */ + def apply(set: Set[Projection]) = { + @tailrec def isNested(p: Projection): Boolean = + p match { + case Property(path, acessor, property) => + set.contains(path) || isNested(path) + case _ => + false + } + new Projections(set.filter(!isNested(_))) + } + + def flatten(list: Iterable[Projections]): Projections = + list.foldLeft(empty)(_ ++ _) +} \ No newline at end of file diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala new file mode 100644 index 0000000000..d0e3a490f4 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala @@ -0,0 +1,106 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +trait ProjectionMacro extends TreeOps with Liftables { + val c: Context + import c.universe.{ TypeName => _, _ } + + def projections(params: List[Tree]): Tree = { + + def typeName(t: Tree) = + TypeName(t.symbol.typeSignature.typeSymbol.fullName) + + def accessor(m: TermName) = + Accessor(m.decodedName.toString) + + def typeReference(tpe: Type) = + TypeReference(TypeName(tpe.typeSymbol.fullName)) + + def isFunction(t: Tree) = + Option(t.symbol).map { + _.typeSignature + .erasure + .typeSymbol + .fullName + .contains("scala.Function") + }.getOrElse(false) + + val nestedList = + params.flatMap { + case param @ q"(..$inputs) => $body" => + + val inputSymbols = inputs.map(_.symbol).toSet + + object Projection { + def unapply(t: Tree): Option[Tree] = + t match { + + case q"$v.$m(..$params)" => unapply(v) + + case q"$v.$m" if t.symbol.isMethod => + + if (inputSymbols.contains(v.symbol)) { + val p = + TypeReference(typeName(v)) + .andThen(accessor(m), typeName(t)) + Some(q"$p") + } else + unapply(v).map { n => + q"$n.andThen(${accessor(m)}, ${typeName(t)})" + } + + case t if inputSymbols.contains(t.symbol) => + Some(q"${TypeReference(typeName(t))}") + + case _ => None + } + } + + def functionCall(func: Tree, params: List[Tree]) = { + val paramsProjecttions = params.flatMap(Projection.unapply) + q""" + $func match { + case f: com.twitter.scalding.quotation.QuotedFunction => + f.quoted.projections.basedOn($paramsProjecttions.toSet) + case _ => + com.twitter.scalding.quotation.Projections(Set(..$paramsProjecttions)) + } + """ + } + + collect(body) { + case q"$func.apply[..$t](..$params)" => + functionCall(func, params) + case q"$func(..$params)" if isFunction(func) => + functionCall(func, params) + case t @ Projection(p) => + q"com.twitter.scalding.quotation.Projections(Set($p))" + } + + case func if isFunction(func) => + val paramProjections = + func.symbol.typeSignature.typeArgs.dropRight(1) + .map(typeReference) + q""" + $func match { + case f: com.twitter.scalding.quotation.QuotedFunction => + f.quoted.projections + case _ => + com.twitter.scalding.quotation.Projections(Set(..$paramProjections)) + } + """ :: Nil + + case method if method.symbol != null && method.symbol.isMethod => + val paramRefs = + method.symbol.asMethod.paramLists.flatten + .map(param => typeReference(param.typeSignature)) + q"${Projections(paramRefs.toSet)}" :: Nil + + case other => + Nil + } + + q"com.twitter.scalding.quotation.Projections.flatten($nestedList)" + } +} diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Quoted.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Quoted.scala new file mode 100644 index 0000000000..805c174b5f --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Quoted.scala @@ -0,0 +1,32 @@ +package com.twitter.scalding.quotation + +import java.io.File + +/** + * Meta information about a method call. + */ +case class Quoted(position: Source, text: Option[String], projections: Projections) { + override def toString = s"$position ${text.getOrElse("")}" +} + +object Quoted { + import language.experimental.macros + implicit def method: Quoted = macro QuotedMacro.method + + private[scalding] def internal: Quoted = macro QuotedMacro.internal + + def function[T1, U](f: T1 => U): Function1[T1, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, U](f: (T1, T2) => U): Function2[T1, T2, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, T3, U](f: (T1, T2, T3) => U): Function3[T1, T2, T3, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, T3, T4, U](f: (T1, T2, T3, T4) => U): Function4[T1, T2, T3, T4, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, T3, T4, T5, U](f: (T1, T2, T3, T4, T5) => U): Function5[T1, T2, T3, T4, T5, U] with QuotedFunction = macro QuotedMacro.function +} + +case class Source(path: String, line: Int) { + def classFile = path.split(File.separator).last + override def toString = s"$classFile:$line" +} + +trait QuotedFunction { + def quoted: Quoted +} diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/QuotedMacro.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/QuotedMacro.scala new file mode 100644 index 0000000000..f9640032bf --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/QuotedMacro.scala @@ -0,0 +1,111 @@ +package com.twitter.scalding.quotation + +import language.experimental.macros +import scala.reflect.macros.blackbox.Context +import scala.reflect.internal.util.RangePosition +import scala.reflect.internal.util.OffsetPosition +import scala.reflect.macros.runtime.{ Context => ReflectContext } +import java.io.File + +class QuotedMacro(val c: Context) + extends TreeOps + with TextMacro + with ProjectionMacro + with Liftables { + import c.universe._ + + def internal: Tree = quoted + + def method: Tree = { + rejectScaldingSources + quoted + } + + private def quoted: Tree = + quoted( + c.asInstanceOf[ReflectContext] + .callsiteTyper + .context + .tree + .asInstanceOf[Tree]) + + val QuotedCompanion = q"com.twitter.scalding.quotation.Quoted" + + private def quoted(tree: Tree): Tree = { + val source = Source(tree.pos.source.path, tree.pos.line) + + find(tree) { t => + t.pos != NoPosition && t.pos.start <= c.enclosingPosition.start + }.flatMap { t => + collect(t) { + + // the start position of vals is wrong, so we workaround + case q"val $name = $body" => quoted(body) + + case q"$m.method" if m.symbol.fullName == classOf[Quoted].getName => + c.abort( + c.enclosingPosition, + "Quoted.method can be invoked only as an implicit parameter") + + case tree @ q"$instance.$method[..$t]" => + q"${Quoted(source, Some(callText(method, t)), Projections.empty)}" + + case tree @ q"$instance.$method[..$t](...$params)" => + q""" + $QuotedCompanion( + $source, + Some(${callText(method, t ++ params.flatten)}), + ${projections(params.flatten)}) + """ + + }.headOption + }.getOrElse { + q"${Quoted(source, None, Projections.empty)}" + } + } + + def function(f: Tree): Tree = { + val source = Source(f.pos.source.path, f.pos.line) + val text = paramsText(TermName("function"), f) + f match { + case q"(..$params) => $body" => + c.untypecheck { + q""" + new ${f.tpe.finalResultType} with ${c.symbolOf[QuotedFunction]} { + override def apply(..$params) = $body + override def quoted = + $QuotedCompanion( + $source, + Some($text), + ${projections(f :: Nil)} + ) + } + """ + } + case _ => + c.abort(f.pos, "Expected a function") + } + } + + private def rejectScaldingSources = { + + def whitelist = + Set("test", "example", "tutorial") + .exists(c.enclosingPosition.source.path.contains) + + def isScalding(sym: Symbol): Boolean = + sym.fullName.contains("com.twitter.scalding") || { + sym.owner match { + case NoSymbol => false + case owner => isScalding(owner) + } + } + + if (!whitelist && isScalding(c.internal.enclosingOwner)) + c.abort( + c.enclosingPosition, + "The quotation must happen at the level of the user-facing API. Add an `implicit q: Quoted` to the enclosing method. " + + "If that's not possible and the transformation doesn't introduce projections, use Quoted.internal.") + } +} + diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TextMacro.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TextMacro.scala new file mode 100644 index 0000000000..f5538c9969 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TextMacro.scala @@ -0,0 +1,88 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +trait TextMacro { + val c: Context + import c.universe._ + + def callText(method: TermName, params: List[Tree]): String = + params.headOption.map(callText(method, _)).getOrElse(s"$method") + + def callText(method: TermName, firstParam: Tree): String = + s"$method${paramsText(method, firstParam)}" + + /* + * This should be something simple since Scala trees have the start and + * end positions. However, there's a bug that makes the positions unreliable. + * This method uses an ad-hoc parsing to get the text from the source file. + */ + def paramsText(method: TermName, firstParam: Tree): String = { + import c.universe._ + + val fileContent = c.enclosingPosition.source.content.mkString + + /* + * The start position of a tree isn't its actual start. It's necessary + * to find the minimum start of the nested trees, which is reliable. + */ + def start(t: Tree) = { + def loop(t: List[Tree]): List[Position] = + t.map(_.pos) ++ t.flatMap(t => loop(t.children)) + + loop(List(t)).filter(_ != NoPosition).map(_.start).min + } + + /* + * From the first parameter start position, walk back until the method + * call start and return the position immediately after the method name. + */ + val content = { + val reverseMethodName = + method.decodedName.toString.reverse + + def paramsStartPosition(content: String, pos: Int): Int = + if (content.startsWith(reverseMethodName) || content.isEmpty) + pos + else + paramsStartPosition(content.drop(1), pos - 1) + + val firstParamStart = start(firstParam) + + val newStart = + paramsStartPosition( + fileContent.take(firstParamStart).reverse, + firstParamStart) + + fileContent.drop(newStart).toList + } + + val blockDelimiters = + Map( + '(' -> ')', + '{' -> '}', + '[' -> ']') + + /* + * Reads the parameters block. It takes in consideration nested blocks like `map(v => { ... })` + */ + def readParams(chars: List[Char], open: List[Char], acc: List[Char] = Nil): (List[Char], List[Char]) = + chars match { + case Nil => + (acc, Nil) + case head :: tail => + blockDelimiters.get(head) match { + case Some(closing) => + val (block, rest) = readParams(tail, open :+ closing) + readParams(rest, open, acc ++ (head +: block :+ closing)) + case None => + if (head != ' ' && (open.isEmpty || head == open.last)) + (acc, tail) + else + readParams(tail, open, acc :+ head) + } + } + + readParams(content, Nil)._1.mkString + } +} \ No newline at end of file diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TreeOps.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TreeOps.scala new file mode 100644 index 0000000000..2c5a9b3c81 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TreeOps.scala @@ -0,0 +1,46 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +trait TreeOps { + val c: Context + import c.universe._ + + /** + * Finds the first tree that satisfies the condition. + */ + def find(tree: Tree)(f: Tree => Boolean): Option[Tree] = { + var res: Option[Tree] = None + val t = new Traverser { + override def traverse(t: Tree) = { + if (res.isEmpty) + if (f(t)) + res = Some(t) + else + super.traverse(t) + } + } + t.traverse(tree) + res + } + + /** + * Similar to tree.collect but it doesn't collect the children of a + * collected tree. + */ + def collect[T](tree: Tree)(f: PartialFunction[Tree, T]): List[T] = { + var res = List[T]() + val t = new Traverser { + override def traverse(t: Tree) = { + f.lift(t) match { + case Some(v) => + res :+= v + case None => + super.traverse(t) + } + } + } + t.traverse(tree) + res + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/LimitationsTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/LimitationsTest.scala new file mode 100644 index 0000000000..0320bf8f7f --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/LimitationsTest.scala @@ -0,0 +1,25 @@ +package com.twitter.scalding.quotation + +class LimitationsTest extends Test { + + class TestClass { + def function[T, U](f: T => U)(implicit q: Quoted) = (q, f) + } + + val test = new TestClass + + "nested transitive projection" in pendingUntilFixed { + test.function[Person, Option[String]](_.alternativeContact.map(_.phone))._1.projections.set mustEqual + Set(Person.typeReference.andThen(Accessor("alternativeContact"), typeName[Option[Contact]]).andThen(Accessor("phone"), typeName[String])) + } + + "nested quoted function projection" in pendingUntilFixed { + val contactFunction = Quoted.function { + (p: Person) => p.contact + } + val phoneFunction = Quoted.function { + (p: Person) => contactFunction(p).phone + } + phoneFunction.quoted.projections.set mustEqual Set(Person.phoneProjection) + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/Person.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/Person.scala new file mode 100644 index 0000000000..f578c407ec --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/Person.scala @@ -0,0 +1,11 @@ +package com.twitter.scalding.quotation + +case class Contact(phone: String) +case class Person(name: String, contact: Contact, alternativeContact: Option[Contact]) + +object Person { + val typeReference = TypeReference(typeName[Person]) + val nameProjection = typeReference.andThen(Accessor("name"), typeName[String]) + val contactProjection = typeReference.andThen(Accessor("contact"), typeName[Contact]) + val phoneProjection = contactProjection.andThen(Accessor("phone"), typeName[String]) +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionMacroTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionMacroTest.scala new file mode 100644 index 0000000000..fab42a2242 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionMacroTest.scala @@ -0,0 +1,105 @@ +package com.twitter.scalding.quotation + +import org.scalatest.Matchers +import org.scalatest.WordSpec +import org.scalatest.FreeSpec +import org.scalatest.MustMatchers + +class ProjectionMacroTest extends Test { + + class TestClass { + def function[T, U](f: T => U)(implicit m: Quoted) = (m, f) + def noProjection(i: Int)(implicit m: Quoted) = (m, i) + } + + val test = new TestClass + + "no projection" in { + test.noProjection(42)._1.projections.set mustEqual Set.empty + } + + "method with params isn't considered as projection" in { + test + .function[Person, String](_.name.substring(1))._1 + .projections.set mustEqual Set(Person.nameProjection) + } + + "simple" in { + test.function[Person, String](_.name)._1 + .projections.set mustEqual Set(Person.nameProjection) + } + + "nested" in { + test.function[Person, String](_.contact.phone)._1 + .projections.set mustEqual Set(Person.phoneProjection) + } + + "all properties" in { + test.function[Person, Person](p => p)._1 + .projections.set mustEqual Set(Person.typeReference) + } + + "empty projection" in { + test.function[Person, Int](p => 1)._1 + .projections.set mustEqual Set.empty + } + + "function call" - { + "implicit apply" - { + "non-quoted" in { + val function = (p: Person) => p.name + test.function[Person, String](p => function(p))._1 + .projections.set mustEqual Set(Person.typeReference) + } + "quoted" in { + val function = Quoted.function { + (p: Person) => p.name + } + test.function[Person, String](p => function(p))._1 + .projections.set mustEqual Set(Person.nameProjection) + } + } + "explicit apply" - { + "non-quoted" in { + val function = (p: Person) => p.name + test.function[Person, String](p => function.apply(p))._1 + .projections.set mustEqual Set(Person.typeReference) + } + "quoted" in { + val function = Quoted.function { + (p: Person) => p.name + } + test.function[Person, String](p => function.apply(p))._1 + .projections.set mustEqual Set(Person.nameProjection) + } + } + } + + "function instance" - { + "non-quoted" in { + val function = (p: Person) => p.name + test.function[Person, String](function)._1 + .projections.set mustEqual Set(Person.typeReference) + } + "quoted" in { + val function = Quoted.function { + (p: Person) => p.name + } + test.function[Person, String](function)._1 + .projections.set mustEqual Set(Person.nameProjection) + } + } + + "method call" - { + "in the function body" in { + def method(p: Person) = p.name + test.function[Person, String](p => method(p))._1 + .projections.set mustEqual Set(Person.typeReference) + } + "as function" in { + def method(p: Person) = p.name + test.function[Person, String](method)._1 + .projections.set mustEqual Set(Person.typeReference) + } + } +} diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionTest.scala new file mode 100644 index 0000000000..690da72ffa --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionTest.scala @@ -0,0 +1,168 @@ +package com.twitter.scalding.quotation + +import org.scalatest.MustMatchers +import org.scalatest.FreeSpec + +trait S + +trait T1 extends S +trait T2 + +trait P1 extends S +trait P2 + +class ProjectionTest extends Test { + + val t1 = TypeReference(typeName[T1]) + val p1 = Property(t1, Accessor("p1"), typeName[P1]) + + val t2 = TypeReference(TypeName(classOf[T2].getName)) + val p2 = Property(t2, Accessor("p2"), typeName[P2]) + + "Projection" - { + "andThen" - { + "TypeReference" in { + t1.andThen(p1.accessor, p1.typeName) mustEqual p1 + } + "Property" in { + p1.andThen(Accessor("p2"), TypeName("p2t")) mustEqual + Property(p1, Accessor("p2"), TypeName("p2t")) + } + } + + "toString" - { + "TypeReference" - { + "simple" in { + t1.toString mustEqual "T1" + } + "ignores package" in { + TypeReference(TypeName("com.twitter.Test1")).toString mustEqual "Test1" + } + } + "Property" in { + p1.toString() mustEqual "T1.p1" + } + } + } + + "Projections" - { + "empty" in { + Projections.empty.set mustEqual Set() + } + "apply" - { + "simple" in { + val set = Set[Projection](p1) + Projections(set).set mustEqual set + } + "paths merge" - { + "simple" in { + val set = Set[Projection](p1, t1) + Projections(set).set mustEqual Set(t1) + } + "nested" in { + val px = p1.andThen(Accessor("x"), TypeName("X")) + val set = Set[Projection](px, t1) + Projections(set).set mustEqual Set(t1) + } + } + } + "flatten" - { + "empty" in { + Projections.flatten(Nil).set mustEqual Set() + } + "non-empty" in { + val list = List( + Projections(Set(p1)), + Projections(Set(p2))) + Projections.flatten(list).set mustEqual Set(p1, p2) + } + "non-empty with merge" in { + val list = List( + Projections(Set(t1)), + Projections(Set(p1))) + Projections.flatten(list).set mustEqual Set(t1) + } + } + + "++" - { + "simple" in { + val p = Projections(Set(p1)) ++ Projections(Set(p2)) + p.set mustEqual Set(p1, p2) + } + "with merge" in { + val list = List( + Projections(Set(p1)), + Projections(Set(t1))) + Projections.flatten(list).set mustEqual Set(t1) + } + } + + "toString" - { + "empty" in { + Projections.empty.toString mustEqual "Projections()" + } + "non-empty" in { + Projections(Set(p1, p2)).toString mustEqual "Projections(T1.p1, T2.p2)" + } + } + + "basedOn" - { + "empty base" in { + Projections(Set(p1, p2)).basedOn(Set.empty) mustEqual + Projections.empty + } + "no match" in { + Projections(Set(p1, p2)).basedOn(Set(TypeReference(TypeName("X")))) mustEqual + Projections.empty + } + "one match" in { + val px1 = Property(TypeReference(TypeName("X")), Accessor("px"), typeName[T1]) + Projections(Set(p1, p2)).basedOn(Set(px1)).set mustEqual + Set(p1.copy(path = px1)) + } + "multiple matches" in { + val px1 = Property(TypeReference(TypeName("X1")), Accessor("px1"), typeName[T1]) + val px2 = Property(TypeReference(TypeName("X1")), Accessor("px2"), typeName[T2]) + Projections(Set(p1, p2)).basedOn(Set(px1, px2)).set mustEqual + Set(p1.copy(path = px1), p2.copy(path = px2)) + } + "partial match" in { + val px1 = Property(TypeReference(TypeName("X1")), Accessor("px1"), typeName[T1]) + val px2 = Property(TypeReference(TypeName("X1")), Accessor("px2"), TypeName("TX")) + Projections(Set(p1, p2)).basedOn(Set(px1, px2)).set mustEqual + Set(p1.copy(path = px1)) + } + } + + "of" - { + "byType" - { + "matches" in { + Projections(Set(t1)).of(t1.typeName, classOf[Any]).set mustEqual + Set(t1) + } + "doesn't match" in { + Projections(Set(t1)).of(TypeName("X"), classOf[Any]).set mustEqual + Set.empty + } + "nested" in { + val px = Property(p1, Accessor("px"), TypeName("PX")) + Projections(Set(px)).of(typeName[T1], classOf[Any]).set mustEqual + Set(px) + } + } + "bySuperClass" - { + "filters only projections of the super class type" in { + val px = p1.andThen(Accessor("px"), typeName[String]) + val py = px.andThen(Accessor("isEmpty"), typeName[Boolean]) + Projections(Set(py)).of(t1.typeName, classOf[S]).set mustEqual Set(px) + } + "ignores if class can't be loaded" in { + val tx = TypeReference(TypeName("TX")) + Projections(Set(tx)).of(tx.typeName, classOf[Any]).set mustEqual + Set.empty + } + } + } + + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/QuotedMacroTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/QuotedMacroTest.scala new file mode 100644 index 0000000000..a6dcac7532 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/QuotedMacroTest.scala @@ -0,0 +1,71 @@ +package com.twitter.scalding.quotation + +import org.scalatest.MustMatchers +import org.scalatest.FreeSpec + +class QuotedMacroTest extends Test { + + val test = new TestClass + + val nullary = test.nullary + val parametrizedNullary = test.parametrizedNullary[Int] + val withParam = test.withParam[Person, String](_.name)._1 + + val quotedFunction = + Quoted.function[Person, Contact](_.contact) + + val nestedQuotedFuction = + Quoted.function[Person, Contact](p => quotedFunction(p)) + + val person = Person("John", Contact("33223"), None) + + class TestClass { + def nullary(implicit q: Quoted) = q + def parametrizedNullary[T](implicit q: Quoted) = q + def withParam[T, U](f: T => U)(implicit q: Quoted) = (q, f) + } + + "quoted method" - { + + "nullary" in { + nullary.position.toString mustEqual "QuotedMacroTest.scala:10" + nullary.projections.set mustEqual Set.empty + nullary.text mustEqual Some("nullary") + } + + "parametrizedNullary" in { + parametrizedNullary.position.toString mustEqual "QuotedMacroTest.scala:11" + parametrizedNullary.projections.set mustEqual Set.empty + parametrizedNullary.text mustEqual Some("parametrizedNullary[Int]") + } + + "withParam" in { + withParam.position.toString mustEqual "QuotedMacroTest.scala:12" + withParam.projections.set mustEqual Set(Person.nameProjection) + withParam.text mustEqual Some("withParam[Person, String](_.name)") + } + } + + "quoted function" - { + "simple" in { + val q = quotedFunction.quoted + q.position.toString mustEqual "QuotedMacroTest.scala:15" + q.projections.set mustEqual Set(Person.contactProjection) + q.text mustEqual Some("[Person, Contact](_.contact)") + + quotedFunction(person) mustEqual person.contact + } + "nested" in { + val q = nestedQuotedFuction.quoted + q.position.toString mustEqual "QuotedMacroTest.scala:18" + q.projections.set mustEqual Set(Person.contactProjection) + q.text mustEqual Some("[Person, Contact](p => quotedFunction(p))") + + nestedQuotedFuction(person) mustEqual person.contact + } + } + + "invalid quoted method call" in { + "Quoted.method" mustNot compile + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/TextMacroTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/TextMacroTest.scala new file mode 100644 index 0000000000..eea5e65f89 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/TextMacroTest.scala @@ -0,0 +1,114 @@ +package com.twitter.scalding.quotation + +import org.scalatest.Matchers +import org.scalatest.WordSpec +import org.scalatest.FreeSpec +import org.scalatest.MustMatchers + +class TextMacroTest extends Test { + + class TestClass { + def nullary(implicit m: Quoted) = m + def parametrizedNullary[T](implicit m: Quoted) = m + def primitiveParam(v: Int)(implicit m: Quoted) = (m, v) + def parametrized[T](v: T)(implicit m: Quoted) = (m, v) + def paramGroups(a: Int, b: Int)(c: Int)(implicit m: Quoted) = (m, a, b, c) + def parametrizedParamGroups[T](a: T, b: Int)(c: T)(implicit m: Quoted) = (m, a, b, c) + def paramGroupsWithFunction(a: Int)(b: Int => Int)(implicit m: Quoted) = (m, a, b) + def function(f: Int => Int)(implicit m: Quoted) = (m, f) + def multipleFunctions[T, U, V](f1: T => U, f2: U => V)(implicit m: Quoted) = (m, f1, f2) + def tupleParam(t: (Int, Int))(implicit m: Quoted) = (m, t) + } + + val test = new TestClass + + "nullary" in { + test.nullary.text mustEqual + Some("nullary") + } + + "parametrizedNullary" - { + "inferred type param" in { + test.parametrizedNullary.text mustEqual + Some("parametrizedNullary") + } + "explicit type param" in { + test.parametrizedNullary[Int].text mustEqual + Some("parametrizedNullary[Int]") + } + } + + "primitiveParam" in { + test.primitiveParam(22)._1.text mustEqual + Some("primitiveParam(22)") + } + + "parametrized" - { + "inferred type param" in { + test.parametrized(42)._1.text mustEqual + Some("parametrized(42)") + } + "explicit type param" in { + test.parametrized[Int](42)._1.text mustEqual + Some("parametrized[Int](42)") + } + } + + "paramGroups" - { + "primitives" in { + test.paramGroups(1, 2)(3)._1.text mustEqual + Some("paramGroups(1, 2)(3)") + } + "parametrized" - { + "explicit type param" in { + test.parametrizedParamGroups[Int](1, 2)(3)._1.text mustEqual + Some("parametrizedParamGroups[Int](1, 2)(3)") + } + "inferred type param" in { + test.parametrizedParamGroups(1, 2)(3)._1.text mustEqual + Some("parametrizedParamGroups(1, 2)(3)") + } + } + "with function" in { + (test.paramGroupsWithFunction(1) { + case 1 => 2 + case _ => 3 + })._1.text mustEqual + Some("""paramGroupsWithFunction(1) { + case 1 => 2 + case _ => 3 + }""") + } + } + + "function" - { + "underscore" in { + test.function(_ + 1)._1.text mustEqual + Some("function(_ + 1)") + } + "pattern matching" in { + test.function { case _ => 4 }._1.text mustEqual Some("function { case _ => 4 }") + } + "curly braces" in { + test.function { _ + 1 }._1.text mustEqual Some("function { _ + 1 }") + } + } + + "complex tree" in { + val c = test.function { + def test = 1 + _ + 1 + } + c._1.text mustEqual + Some( + """function { + def test = 1 + _ + 1 + }""") + } + + "tuple param" in { + test.tupleParam((1, 2))._1.text mustEqual + Some("tupleParam((1, 2))") + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/package.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/package.scala new file mode 100644 index 0000000000..401b523788 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/package.scala @@ -0,0 +1,9 @@ +package com.twitter.scalding + +import org.scalatest.MustMatchers +import org.scalatest.FreeSpec + +package object quotation { + def typeName[T](implicit ct: reflect.ClassTag[T]) = TypeName(ct.runtimeClass.getName) + trait Test extends FreeSpec with MustMatchers +}