diff --git a/.travis.yml b/.travis.yml index 2fb94e5371..119a0863d7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,83 +7,100 @@ script: matrix: include: #BASE TESTS - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="base" TEST_TARGET="scalding-args scalding-date" script: "scripts/run_test.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="base" TEST_TARGET="scalding-args scalding-date" script: "scripts/run_test.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="base" TEST_TARGET="scalding-avro scalding-hraven scalding-commons" script: "scripts/run_test.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="base" TEST_TARGET="scalding-avro scalding-hraven scalding-commons" script: "scripts/run_test.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="base" TEST_TARGET="scalding-core" script: "scripts/run_test.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="base" TEST_TARGET="scalding-core" script: "scripts/run_test.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="base" TEST_TARGET="scalding-hadoop-test" script: "scripts/run_test.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="base" TEST_TARGET="scalding-hadoop-test" script: "scripts/run_test.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="base" TEST_TARGET="scalding-jdbc scalding-json" script: "scripts/run_test.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="base" TEST_TARGET="scalding-jdbc scalding-json" script: "scripts/run_test.sh" - - scala: 2.10.4 + - scala: 2.10.5 + env: BUILD="base" TEST_TARGET="scalding-macros" + script: "scripts/run_test.sh" + + - scala: 2.11.5 + env: BUILD="base" TEST_TARGET="scalding-macros" + script: "scripts/run_test.sh" + +# not committed yet + # - scala: 2.10.5 + # env: BUILD="base" TEST_TARGET="scalding-commons-macros" + # script: "scripts/run_test.sh" + + # - scala: 2.11.5 + # env: BUILD="base" TEST_TARGET="scalding-commons-macros" + # script: "scripts/run_test.sh" + + - scala: 2.10.5 env: BUILD="base" TEST_TARGET="scalding-parquet scalding-parquet-scrooge" script: "scripts/run_test.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="base" TEST_TARGET="scalding-parquet scalding-parquet-scrooge" script: "scripts/run_test.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="base" TEST_TARGET="scalding-repl" script: "scripts/run_test.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="test tutorials" script: - "scripts/build_assembly_no_test.sh scalding" - "scripts/test_tutorials.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="test tutorials" script: - "scripts/build_assembly_no_test.sh scalding" - "scripts/test_tutorials.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="test matrix tutorials" script: - "scripts/build_assembly_no_test.sh scalding" - "scripts/test_matrix_tutorials.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="test matrix tutorials" script: - "scripts/build_assembly_no_test.sh scalding" - "scripts/test_matrix_tutorials.sh" - - scala: 2.10.4 + - scala: 2.10.5 env: BUILD="test repl and typed tutorials" script: - "scripts/build_assembly_no_test.sh scalding-repl" @@ -91,11 +108,14 @@ matrix: - "scripts/build_assembly_no_test.sh scalding-core" - "scripts/test_typed_tutorials.sh" - - scala: 2.11.4 + - scala: 2.11.5 env: BUILD="test typed tutorials" script: - "scripts/build_assembly_no_test.sh scalding-core" - "scripts/test_typed_tutorials.sh" -notifications: - irc: "chat.freenode.net#scalding" + - scala: 2.10.5 + env: BUILD="test execution tutorials" + script: + - "scripts/build_assembly_no_test.sh execution-tutorial" + - "scripts/test_execution_tutorial.sh" diff --git 
a/CHANGES.md b/CHANGES.md
index a6bdf9a5d0..c585238a4b 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,57 @@
 # Scalding #
+### Version 0.15.0 ###
+* Move OrderedSerialization into zero-dep scalding-serialization module #1289
+* bump elephantbird to 4.8 #1292
+* Fix OrderedSerialization for some forked graphs #1293
+* Add serialization modules to aggregate list #1298
+
+### Version 0.14.0 ###
+* add .unit to Execution object #1189
+* Override hashCode for Args #1190
+* Put a value in an exception message #1191
+* Add an exclusiveUpper method to DateRange #1194
+* Convert LzoTextDelimited to Cascading scheme. #1179
+* Remove Travis IRC notifications #1200
+* add LookupJoin and LookupJoinTest changes from summingbird #1199
+* Add a new ExecutionApp tutorial #1196
+* Move main simple example to be the typed API, and put the .'s at the sta... #1193
+* Add Execution.withArgs #1205
+* Config/Cascading updater #1197
+* Remove algebird serializers #1206
+* remove warnings in CumulativeSum #1215
+* Implicit execution context / easier switching between modes #1113
+* add row l1 normalize #1214
+* provide Args as an implicit val #1219
+* call sourceConfInit when reading from taps in local mode #1228
+* Add distinctCount and distinctValues helper methods to KeyedList. #1232
+* import hygiene: remove unused imports and remove JavaConversions use #1239
+* Swap hash and filename for filename-extension-sensitive code #1243
+* Remove more unused imports #1240
+* Provide useHdfsLocalMode for an easy switch to mapreduce local mode #1244
+* upgrade scalacheck and scalatest #1246
+* Optimize string and (hopefully) number comparisons a bit #1241
+* Note the active FlowProcess for Joiners #1235
+* Make sure Executions are executed at most once #1253
+* Fix Config.getUniqueIDs #1254
+* Add MustHasReducers trait. #1252
+* Make sure the EvalCache thread isDaemon #1255
+* Use non-regex split function #1251
+* make InputSizeReducerEstimator work for any CompositeTap #1256
+* TimePathedSource helper methods #1257
+* Fix for reducer estimation not working correctly if withReducers is set to 1 reducer #1263
+* Add make(dest) to TypedPipe #1217
+* Fix SimpleDateFormat caching by default #1265
+* upgrade sbt and sbt launcher script #1270
+* Add TypedPipeDiff for comparing typed pipes #1266
+* Change separator from \1 to \u0001 #1271
+* Disable reducer estimation for map-only steps #1276
+* Local sources support multiple paths #1275
+* fix the spelling of the cumulativeSumTest file #1281
+* Hydrate both sides of sampledCounts in skewJoinWithSmaller #1278
+* Bijection 0.8.0, algebird 0.10.0, chill 0.6.0, scala 2.10.5 #1287
+* Remove some deprecated items #1288
+
 ### Version 0.13.1 ###
 * Back out 4 changes to be binary compatible: https://github.com/twitter/scalding/pull/1187
 * Use java.util.Random instead of scala.util.Random: https://github.com/twitter/scalding/pull/1186
diff --git a/README.md b/README.md
index 7f86144067..39681da214 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 Scalding is a Scala library that makes it easy to specify Hadoop MapReduce jobs.
 ![Scalding Logo](https://raw.github.com/twitter/scalding/develop/logo/scalding.png)
-Current version: `0.13.1`
+Current version: `0.15.0`
 ## Word Count
@@ -37,16 +37,17 @@ You can find more example code under [examples/](https://github.com/twitter/scal
 ## Documentation and Getting Started
 * [**Getting Started**](https://github.com/twitter/scalding/wiki/Getting-Started) page on the [Scalding Wiki](https://github.com/twitter/scalding/wiki)
+* [Scalding Scaladocs](http://twitter.github.com/scalding) provide details beyond the API References. Prefer these, as they are always up to date.
 * [**REPL in Wonderland**](https://gist.github.com/johnynek/a47699caa62f4f38a3e2) a hands-on tour of the scalding REPL requiring only git and java installed.
 * [**Runnable tutorials**](https://github.com/twitter/scalding/tree/master/tutorial) in the source.
 * The API Reference, including many example Scalding snippets:
 * [Type-safe API Reference](https://github.com/twitter/scalding/wiki/Type-safe-api-reference)
 * [Fields-based API Reference](https://github.com/twitter/scalding/wiki/Fields-based-API-Reference)
-* [Scalding Scaladocs](http://twitter.github.com/scalding) provide details beyond the API References
 * The Matrix Library provides a way of working with key-attribute-value scalding pipes:
 * The [Introduction to Matrix Library](https://github.com/twitter/scalding/wiki/Introduction-to-Matrix-Library) contains an overview and a "getting started" example
 * The [Matrix API Reference](https://github.com/twitter/scalding/wiki/Matrix-API-Reference) contains the Matrix Library API reference with examples
+* [**Introduction to Scalding Execution**](https://github.com/twitter/scalding/wiki/Calling-Scalding-from-inside-your-application) contains general rules and examples of calling Scalding from inside another application.
 Please feel free to use the beautiful [Scalding logo](https://drive.google.com/folderview?id=0B3i3pDi3yVgNbm9pMUdDcHFKVEk&usp=sharing) artwork anywhere.
@@ -124,6 +125,10 @@ Thanks for assistance and contributions:
 * Sam Ritchie
 * Aaron Siegel:
+* Ian O'Connell
+* Alex Levenson
+* Jonathan Coveney
+* Kevin Lin
 * Brad Greenlee:
 * Edwin Chen
 * Arkajit Dey:
@@ -133,9 +138,9 @@ Thanks for assistance and contributions:
 * Ning Liang
 * Dmitriy Ryaboy
 * Dong Wang
-* Kevin Lin
 * Josh Attenberg
 * Juliet Hougland
+* Eddie Xie
 A full list of [contributors](https://github.com/twitter/scalding/graphs/contributors) can be found on GitHub.
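The README now links to the new "Introduction to Scalding Execution" page, and this patch also adds an execution-tutorial build target. As a rough illustration of that style (not code from this patch; the object name and paths are placeholders), the README's word count can be phrased as an `Execution` and launched through `ExecutionApp`:

```scala
import com.twitter.scalding._

// Hypothetical sketch: the classic word count, written with the typed API and
// returned as an Execution[Unit] so it can be embedded in another application.
object WordCountExecution extends ExecutionApp {
  override def job: Execution[Unit] =
    TypedPipe.from(TextLine("input.txt"))   // placeholder input path
      .flatMap(_.split("\\s+"))
      .map(word => (word, 1L))
      .sumByKey
      .toTypedPipe
      .writeExecution(TypedTsv[(String, Long)]("wordcount_output")) // placeholder output path
}
```

`ExecutionApp` supplies the `main` method, so the object can be launched like any other Scalding job, while `job` itself remains a plain value that other code can compose with `flatMap` or `zip`.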
diff --git a/maple/src/main/java/com/twitter/maple/hbase/HBaseScheme.java b/maple/src/main/java/com/twitter/maple/hbase/HBaseScheme.java index 6ee34d5404..0f830ede86 100644 --- a/maple/src/main/java/com/twitter/maple/hbase/HBaseScheme.java +++ b/maple/src/main/java/com/twitter/maple/hbase/HBaseScheme.java @@ -31,7 +31,6 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.OutputCollector; import org.apache.hadoop.mapred.RecordReader; -import org.mortbay.log.Log; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/maple/src/main/java/com/twitter/maple/hbase/HBaseTap.java b/maple/src/main/java/com/twitter/maple/hbase/HBaseTap.java index f65b3db0b3..37ebfb0a8e 100644 --- a/maple/src/main/java/com/twitter/maple/hbase/HBaseTap.java +++ b/maple/src/main/java/com/twitter/maple/hbase/HBaseTap.java @@ -17,7 +17,6 @@ import cascading.flow.FlowProcess; import cascading.tap.SinkMode; import cascading.tap.Tap; -import cascading.tap.hadoop.io.HadoopTupleEntrySchemeCollector; import cascading.tap.hadoop.io.HadoopTupleEntrySchemeIterator; import cascading.tuple.TupleEntryCollector; import cascading.tuple.TupleEntryIterator; @@ -33,10 +32,8 @@ import org.apache.hadoop.mapred.RecordReader; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import sun.reflect.generics.reflectiveObjects.NotImplementedException; import java.io.IOException; -import java.util.Map.Entry; import java.util.UUID; /** diff --git a/project/Build.scala b/project/Build.scala index c349eafed7..6d40f82720 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -20,34 +20,38 @@ object ScaldingBuild extends Build { } def isScala210x(scalaVersion: String) = scalaBinaryVersion(scalaVersion) == "2.10" - val scalaTestVersion = "2.2.2" - val scalaCheckVersion = "1.11.5" - val hadoopVersion = "1.2.1" - val algebirdVersion = "0.9.0" - val bijectionVersion = "0.7.2" - val chillVersion = "0.5.2" - val slf4jVersion = "1.6.6" - val parquetVersion = "1.6.0rc4" + val algebirdVersion = "0.10.1" + val avroVersion = "1.7.4" + val bijectionVersion = "0.8.0" + val cascadingAvroVersion = "2.1.2" + val chillVersion = "0.6.0" val dfsDatastoresVersion = "1.3.4" + val elephantbirdVersion = "4.8" + val hadoopLzoVersion = "0.4.16" + val hadoopVersion = "1.2.1" val hbaseVersion = "0.94.10" val hravenVersion = "0.9.13" val jacksonVersion = "2.4.2" + val json4SVersion = "3.2.11" + val paradiseVersion = "2.0.1" + val parquetVersion = "1.6.0rc4" val protobufVersion = "2.4.1" - val elephantbirdVersion = "4.6" - val hadoopLzoVersion = "0.4.16" + val quasiquotesVersion = "2.0.1" + val scalaCheckVersion = "1.12.2" + val scalaTestVersion = "2.2.4" + val scalameterVersion = "0.6" + val scroogeVersion = "3.17.0" + val slf4jVersion = "1.6.6" val thriftVersion = "0.5.0" - val cascadingAvroVersion = "2.1.2" - val avroVersion = "1.7.4" - val json4SVersion = "3.2.11" val printDependencyClasspath = taskKey[Unit]("Prints location of the dependencies") val sharedSettings = Project.defaultSettings ++ assemblySettings ++ scalariformSettings ++ Seq( organization := "com.twitter", - scalaVersion := "2.10.4", + scalaVersion := "2.10.5", - crossScalaVersions := Seq("2.10.4", "2.11.5"), + crossScalaVersions := Seq("2.10.5", "2.11.5"), ScalariformKeys.preferences := formattingPreferences, @@ -201,7 +205,10 @@ object ScaldingBuild extends Build { scaldingJdbc, scaldingHadoopTest, scaldingMacros, - maple + maple, + executionTutorial, + scaldingSerialization, + scaldingSerializationMacros ) lazy val formattingPreferences = { @@ -221,7 +228,7 
@@ object ScaldingBuild extends Build { Some(subProj) .filterNot(unreleasedModules.contains(_)) .map { - s => "com.twitter" % ("scalding-" + s + "_2.10") % "0.13.0" + s => "com.twitter" % ("scalding-" + s + "_2.10") % "0.15.0" } def module(name: String) = { @@ -242,6 +249,15 @@ object ScaldingBuild extends Build { lazy val cascadingJDBCVersion = System.getenv.asScala.getOrElse("SCALDING_CASCADING_JDBC_VERSION", "2.6.0") + lazy val scaldingBenchmarks = module("benchmarks").settings( + libraryDependencies ++= Seq( + "com.storm-enroute" %% "scalameter" % scalameterVersion % "test", + "org.scalacheck" %% "scalacheck" % scalaCheckVersion % "test" + ), + testFrameworks += new TestFramework("org.scalameter.ScalaMeterFramework"), + parallelExecution in Test := false + ).dependsOn(scaldingCore, scaldingMacros) + lazy val scaldingCore = module("core").settings( libraryDependencies ++= Seq( "cascading" % "cascading-core" % cascadingVersion, @@ -249,6 +265,7 @@ object ScaldingBuild extends Build { "cascading" % "cascading-hadoop" % cascadingVersion, "com.twitter" %% "chill" % chillVersion, "com.twitter" % "chill-hadoop" % chillVersion, + "com.twitter" %% "chill-algebird" % chillVersion, "com.twitter" % "chill-java" % chillVersion, "com.twitter" %% "bijection-core" % bijectionVersion, "com.twitter" %% "algebird-core" % algebirdVersion, @@ -257,7 +274,7 @@ object ScaldingBuild extends Build { "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion % "provided" ) - ).dependsOn(scaldingArgs, scaldingDate, maple) + ).dependsOn(scaldingArgs, scaldingDate, scaldingSerialization, maple) lazy val scaldingCommons = module("commons").settings( libraryDependencies ++= Seq( @@ -273,10 +290,12 @@ object ScaldingBuild extends Build { "com.hadoop.gplcompression" % "hadoop-lzo" % hadoopLzoVersion, // TODO: split this out into scalding-thrift "org.apache.thrift" % "libthrift" % thriftVersion, + // TODO: split this out into a scalding-scrooge + "com.twitter" %% "scrooge-serializer" % scroogeVersion % "provided", "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion % "provided" ) - ).dependsOn(scaldingArgs, scaldingDate, scaldingCore) + ).dependsOn(scaldingArgs, scaldingDate, scaldingCore, scaldingHadoopTest % "test") lazy val scaldingAvro = module("avro").settings( libraryDependencies ++= Seq( @@ -288,7 +307,7 @@ object ScaldingBuild extends Build { ).dependsOn(scaldingCore) lazy val scaldingParquet = module("parquet").settings( - libraryDependencies ++= Seq( + libraryDependencies <++= (scalaVersion) { scalaVersion => Seq( // see https://issues.apache.org/jira/browse/PARQUET-143 for exclusions "com.twitter" % "parquet-cascading" % parquetVersion exclude("com.twitter", "parquet-pig") @@ -296,9 +315,12 @@ object ScaldingBuild extends Build { exclude("com.twitter.elephantbird", "elephant-bird-core"), "org.apache.thrift" % "libthrift" % "0.7.0", "org.slf4j" % "slf4j-api" % slf4jVersion, - "org.apache.hadoop" % "hadoop-core" % hadoopVersion % "provided" - ) - ).dependsOn(scaldingCore) + "org.apache.hadoop" % "hadoop-core" % hadoopVersion % "provided", + "org.scala-lang" % "scala-reflect" % scalaVersion, + "com.twitter" %% "bijection-macros" % bijectionVersion + ) ++ (if(isScala210x(scalaVersion)) Seq("org.scalamacros" %% "quasiquotes" % quasiquotesVersion) else Seq()) + }, addCompilerPlugin("org.scalamacros" % "paradise" % paradiseVersion cross CrossVersion.full)) + .dependsOn(scaldingCore, scaldingHadoopTest) def scaldingParquetScroogeDeps(version: String) = { 
if (isScala210x(version)) @@ -372,6 +394,17 @@ object ScaldingBuild extends Build { run <<= (run in Unprovided) ) + // zero dependency serialization module + lazy val scaldingSerialization = module("serialization") + lazy val scaldingSerializationMacros = module("serialization-macros").settings( + libraryDependencies <++= (scalaVersion) { scalaVersion => Seq( + "org.scala-lang" % "scala-library" % scalaVersion, + "org.scala-lang" % "scala-reflect" % scalaVersion + ) ++ (if(isScala210x(scalaVersion)) Seq("org.scalamacros" %% "quasiquotes" % "2.0.1") else Seq()) + }, + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full) + ).dependsOn(scaldingSerialization) + lazy val scaldingJson = module("json").settings( libraryDependencies <++= (scalaVersion) { scalaVersion => Seq( "org.apache.hadoop" % "hadoop-core" % hadoopVersion % "provided", @@ -395,22 +428,23 @@ object ScaldingBuild extends Build { libraryDependencies <++= (scalaVersion) { scalaVersion => Seq( ("org.apache.hadoop" % "hadoop-core" % hadoopVersion), ("org.apache.hadoop" % "hadoop-minicluster" % hadoopVersion), + "com.twitter" %% "chill-algebird" % chillVersion, "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "org.scalacheck" %% "scalacheck" % scalaCheckVersion, "org.scalatest" %% "scalatest" % scalaTestVersion ) } - ).dependsOn(scaldingCore) + ).dependsOn(scaldingCore, scaldingSerializationMacros % "test") lazy val scaldingMacros = module("macros").settings( libraryDependencies <++= (scalaVersion) { scalaVersion => Seq( "org.scala-lang" % "scala-library" % scalaVersion, "org.scala-lang" % "scala-reflect" % scalaVersion, "com.twitter" %% "bijection-macros" % bijectionVersion - ) ++ (if(isScala210x(scalaVersion)) Seq("org.scalamacros" %% "quasiquotes" % "2.0.1") else Seq()) + ) ++ (if(isScala210x(scalaVersion)) Seq("org.scalamacros" %% "quasiquotes" % quasiquotesVersion) else Seq()) }, - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full) + addCompilerPlugin("org.scalamacros" % "paradise" % paradiseVersion cross CrossVersion.full) ).dependsOn(scaldingCore, scaldingHadoopTest) // This one uses a different naming convention @@ -430,4 +464,21 @@ object ScaldingBuild extends Build { ) } ) + + lazy val executionTutorial = Project( + id = "execution-tutorial", + base = file("tutorial/execution-tutorial"), + settings = sharedSettings + ).settings( + name := "execution-tutorial", + libraryDependencies <++= (scalaVersion) { scalaVersion => Seq( + "org.scala-lang" % "scala-library" % scalaVersion, + "org.scala-lang" % "scala-reflect" % scalaVersion, + "org.apache.hadoop" % "hadoop-core" % hadoopVersion, + "org.slf4j" % "slf4j-api" % slf4jVersion, + "org.slf4j" % "slf4j-log4j12" % slf4jVersion, + "cascading" % "cascading-hadoop" % cascadingVersion + ) + } + ).dependsOn(scaldingCore) } diff --git a/project/build.properties b/project/build.properties index be6c454fba..a6e117b610 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=0.13.5 +sbt.version=0.13.8 diff --git a/sbt b/sbt index 25cd36d65c..422327fc31 100755 --- a/sbt +++ b/sbt @@ -1,17 +1,17 @@ #!/usr/bin/env bash # # A more capable sbt runner, coincidentally also called sbt. 
-# Author: Paul Phillips +# Author: Paul Phillips # todo - make this dynamic -declare -r sbt_release_version="0.13.5" -declare -r sbt_unreleased_version="0.13.6-MSERVER-1" +declare -r sbt_release_version="0.13.8" +declare -r sbt_unreleased_version="0.13.8" declare -r buildProps="project/build.properties" declare sbt_jar sbt_dir sbt_create sbt_version -declare scala_version java_home sbt_explicit_version +declare scala_version sbt_explicit_version declare verbose noshare batch trace_level log_level -declare sbt_saved_stty +declare sbt_saved_stty debugUs echoerr () { echo >&2 "$@"; } vlog () { [[ -n "$verbose" ]] && echoerr "$@"; } @@ -19,7 +19,7 @@ vlog () { [[ -n "$verbose" ]] && echoerr "$@"; } # spaces are possible, e.g. sbt.version = 0.13.0 build_props_sbt () { [[ -r "$buildProps" ]] && \ - grep '^sbt\.version' "$buildProps" | tr '=' ' ' | awk '{ print $2; }' + grep '^sbt\.version' "$buildProps" | tr '=\r' ' ' | awk '{ print $2; }' } update_build_props_sbt () { @@ -101,12 +101,12 @@ init_default_option_file () { declare -r cms_opts="-XX:+CMSClassUnloadingEnabled -XX:+UseConcMarkSweepGC" declare -r jit_opts="-XX:ReservedCodeCacheSize=256m -XX:+TieredCompilation" -declare -r default_jvm_opts="-XX:MaxPermSize=384m -Xms512m -Xmx1536m -Xss2m $jit_opts $cms_opts" +declare -r default_jvm_opts_common="-Xms512m -Xmx1536m -Xss2m $jit_opts $cms_opts" declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" declare -r latest_28="2.8.2" declare -r latest_29="2.9.3" -declare -r latest_210="2.10.4" -declare -r latest_211="2.11.1" +declare -r latest_210="2.10.5" +declare -r latest_211="2.11.6" declare -r script_path="$(get_script_path "$BASH_SOURCE")" declare -r script_name="${script_path##*/}" @@ -115,7 +115,7 @@ declare -r script_name="${script_path##*/}" declare java_cmd="java" declare sbt_opts_file="$(init_default_option_file SBT_OPTS .sbtopts)" declare jvm_opts_file="$(init_default_option_file JVM_OPTS .jvmopts)" -declare sbt_launch_repo="https://private-repo.typesafe.com/typesafe/ivy-releases" +declare sbt_launch_repo="http://typesafe.artifactoryonline.com/typesafe/ivy-releases" # pull -J and -D options to give to java. 
declare -a residual_args @@ -126,14 +126,79 @@ declare -a sbt_commands # args to jvm/sbt via files or environment variables declare -a extra_jvm_opts extra_sbt_opts -# if set, use JAVA_HOME over java found in path -[[ -e "$JAVA_HOME/bin/java" ]] && java_cmd="$JAVA_HOME/bin/java" +addJava () { + vlog "[addJava] arg = '$1'" + java_args+=("$1") +} +addSbt () { + vlog "[addSbt] arg = '$1'" + sbt_commands+=("$1") +} +setThisBuild () { + vlog "[addBuild] args = '$@'" + local key="$1" && shift + addSbt "set $key in ThisBuild := $@" +} +addScalac () { + vlog "[addScalac] arg = '$1'" + scalac_args+=("$1") +} +addResidual () { + vlog "[residual] arg = '$1'" + residual_args+=("$1") +} +addResolver () { + addSbt "set resolvers += $1" +} +addDebugger () { + addJava "-Xdebug" + addJava "-Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1" +} +setScalaVersion () { + [[ "$1" == *"-SNAPSHOT" ]] && addResolver 'Resolver.sonatypeRepo("snapshots")' + addSbt "++ $1" +} +setJavaHome () { + java_cmd="$1/bin/java" + setThisBuild javaHome "Some(file(\"$1\"))" + export JAVA_HOME="$1" + export JDK_HOME="$1" + export PATH="$JAVA_HOME/bin:$PATH" +} +setJavaHomeQuietly () { + addSbt warn + setJavaHome "$1" + addSbt info +} + +# if set, use JDK_HOME/JAVA_HOME over java found in path +if [[ -e "$JDK_HOME/lib/tools.jar" ]]; then + setJavaHomeQuietly "$JDK_HOME" +elif [[ -e "$JAVA_HOME/bin/java" ]]; then + setJavaHomeQuietly "$JAVA_HOME" +fi # directory to store sbt launchers declare sbt_launch_dir="$HOME/.sbt/launchers" [[ -d "$sbt_launch_dir" ]] || mkdir -p "$sbt_launch_dir" [[ -w "$sbt_launch_dir" ]] || sbt_launch_dir="$(mktemp -d -t sbt_extras_launchers.XXXXXX)" +java_version () { + local version=$("$java_cmd" -version 2>&1 | grep -E -e '(java|openjdk) version' | awk '{ print $3 }' | tr -d \") + vlog "Detected Java version: $version" + echo "${version:2:1}" +} + +# MaxPermSize critical on pre-8 jvms but incurs noisy warning on 8+ +default_jvm_opts () { + local v="$(java_version)" + if [[ $v -ge 8 ]]; then + echo "$default_jvm_opts_common" + else + echo "-XX:MaxPermSize=384m $default_jvm_opts_common" + fi +} + build_props_scala () { if [[ -r "$buildProps" ]]; then versionLine="$(grep '^build.scala.versions' "$buildProps")" @@ -157,9 +222,7 @@ execRunner () { vlog "" } - if [[ -n "$batch" ]]; then - exec display stack traces with a max of frames (default: -1, traces suppressed) + -debug-inc enable debugging log for the incremental compiler -no-colors disable ANSI color codes -sbt-create start sbt even if current directory contains no sbt project -sbt-dir path to global settings/plugins directory (default: ~/.sbt/) @@ -220,7 +291,9 @@ are not special. -prompt Set the sbt prompt; in expr, 's' is the State and 'e' is Extracted # sbt version (default: sbt.version from $buildProps if present, otherwise $sbt_release_version) + -sbt-force-latest force the use of the latest release of sbt: $sbt_release_version -sbt-version use the specified version of sbt (default: $sbt_release_version) + -sbt-dev use the latest pre-release version of sbt: $sbt_unreleased_version -sbt-jar use the specified jar as the sbt launcher -sbt-launch-dir directory to hold sbt launchers (default: ~/.sbt/launchers) -sbt-launch-repo repo url for downloading sbt launcher jar (default: $sbt_launch_repo) @@ -239,7 +312,7 @@ are not special. 
# passing options to the jvm - note it does NOT use JAVA_OPTS due to pollution # The default set is used if JVM_OPTS is unset and no -jvm-opts file is found - $default_jvm_opts + $(default_jvm_opts) JVM_OPTS environment variable holding either the jvm args directly, or the reference to a file containing jvm args if given path is prepended by '@' (e.g. '@/etc/jvmopts') Note: "@"-file is overridden by local '.jvmopts' or '-jvm-opts' argument. @@ -256,34 +329,6 @@ are not special. EOM } -addJava () { - vlog "[addJava] arg = '$1'" - java_args=( "${java_args[@]}" "$1" ) -} -addSbt () { - vlog "[addSbt] arg = '$1'" - sbt_commands=( "${sbt_commands[@]}" "$1" ) -} -addScalac () { - vlog "[addScalac] arg = '$1'" - scalac_args=( "${scalac_args[@]}" "$1" ) -} -addResidual () { - vlog "[residual] arg = '$1'" - residual_args=( "${residual_args[@]}" "$1" ) -} -addResolver () { - addSbt "set resolvers += $1" -} -addDebugger () { - addJava "-Xdebug" - addJava "-Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1" -} -setScalaVersion () { - [[ "$1" == *"-SNAPSHOT" ]] && addResolver 'Resolver.sonatypeRepo("snapshots")' - addSbt "++ $1" -} - process_args () { require_arg () { @@ -297,45 +342,50 @@ process_args () } while [[ $# -gt 0 ]]; do case "$1" in - -h|-help) usage; exit 1 ;; - -v) verbose=true && shift ;; - -d) addSbt "--debug" && shift ;; - -w) addSbt "--warn" && shift ;; - -q) addSbt "--error" && shift ;; - -trace) require_arg integer "$1" "$2" && trace_level="$2" && shift 2 ;; - -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;; - -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; - -no-share) noshare=true && shift ;; - -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; - -sbt-dir) require_arg path "$1" "$2" && sbt_dir="$2" && shift 2 ;; - -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; - -offline) addSbt "set offline := true" && shift ;; - -jvm-debug) require_arg port "$1" "$2" && addDebugger "$2" && shift 2 ;; - -batch) batch=true && shift ;; - -prompt) require_arg "expr" "$1" "$2" && addSbt "set shellPrompt in ThisBuild := (s => { val e = Project.extract(s) ; $2 })" && shift 2 ;; - - -sbt-create) sbt_create=true && shift ;; - -sbt-jar) require_arg path "$1" "$2" && sbt_jar="$2" && shift 2 ;; - -sbt-version) require_arg version "$1" "$2" && sbt_explicit_version="$2" && shift 2 ;; - -sbt-dev) sbt_explicit_version="$sbt_unreleased_version" && shift ;; --sbt-launch-dir) require_arg path "$1" "$2" && sbt_launch_dir="$2" && shift 2 ;; --sbt-launch-repo) require_arg path "$1" "$2" && sbt_launch_repo="$2" && shift 2 ;; - -scala-version) require_arg version "$1" "$2" && setScalaVersion "$2" && shift 2 ;; --binary-version) require_arg version "$1" "$2" && addSbt "set scalaBinaryVersion in ThisBuild := \"$2\"" && shift 2 ;; - -scala-home) require_arg path "$1" "$2" && addSbt "set every scalaHome := Some(file(\"$2\"))" && shift 2 ;; - -java-home) require_arg path "$1" "$2" && java_cmd="$2/bin/java" && shift 2 ;; - -sbt-opts) require_arg path "$1" "$2" && sbt_opts_file="$2" && shift 2 ;; - -jvm-opts) require_arg path "$1" "$2" && jvm_opts_file="$2" && shift 2 ;; - - -D*) addJava "$1" && shift ;; - -J*) addJava "${1:2}" && shift ;; - -S*) addScalac "${1:2}" && shift ;; - -28) setScalaVersion "$latest_28" && shift ;; - -29) setScalaVersion "$latest_29" && shift ;; - -210) setScalaVersion "$latest_210" && shift ;; - -211) setScalaVersion "$latest_211" && shift ;; - - *) addResidual "$1" && shift ;; + -h|-help) usage; exit 1 ;; 
+ -v) verbose=true && shift ;; + -d) addSbt "--debug" && addSbt debug && shift ;; + -w) addSbt "--warn" && addSbt warn && shift ;; + -q) addSbt "--error" && addSbt error && shift ;; + -x) debugUs=true && shift ;; + -trace) require_arg integer "$1" "$2" && trace_level="$2" && shift 2 ;; + -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;; + -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; + -no-share) noshare=true && shift ;; + -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; + -sbt-dir) require_arg path "$1" "$2" && sbt_dir="$2" && shift 2 ;; + -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; + -offline) addSbt "set offline := true" && shift ;; + -jvm-debug) require_arg port "$1" "$2" && addDebugger "$2" && shift 2 ;; + -batch) batch=true && shift ;; + -prompt) require_arg "expr" "$1" "$2" && setThisBuild shellPrompt "(s => { val e = Project.extract(s) ; $2 })" && shift 2 ;; + + -sbt-create) sbt_create=true && shift ;; + -sbt-jar) require_arg path "$1" "$2" && sbt_jar="$2" && shift 2 ;; + -sbt-version) require_arg version "$1" "$2" && sbt_explicit_version="$2" && shift 2 ;; + -sbt-force-latest) sbt_explicit_version="$sbt_release_version" && shift ;; + -sbt-dev) sbt_explicit_version="$sbt_unreleased_version" && shift ;; + -sbt-launch-dir) require_arg path "$1" "$2" && sbt_launch_dir="$2" && shift 2 ;; + -sbt-launch-repo) require_arg path "$1" "$2" && sbt_launch_repo="$2" && shift 2 ;; + -scala-version) require_arg version "$1" "$2" && setScalaVersion "$2" && shift 2 ;; + -binary-version) require_arg version "$1" "$2" && setThisBuild scalaBinaryVersion "\"$2\"" && shift 2 ;; + -scala-home) require_arg path "$1" "$2" && setThisBuild scalaHome "Some(file(\"$2\"))" && shift 2 ;; + -java-home) require_arg path "$1" "$2" && setJavaHome "$2" && shift 2 ;; + -sbt-opts) require_arg path "$1" "$2" && sbt_opts_file="$2" && shift 2 ;; + -jvm-opts) require_arg path "$1" "$2" && jvm_opts_file="$2" && shift 2 ;; + + -D*) addJava "$1" && shift ;; + -J*) addJava "${1:2}" && shift ;; + -S*) addScalac "${1:2}" && shift ;; + -28) setScalaVersion "$latest_28" && shift ;; + -29) setScalaVersion "$latest_29" && shift ;; + -210) setScalaVersion "$latest_210" && shift ;; + -211) setScalaVersion "$latest_211" && shift ;; + + --debug) addSbt debug && addResidual "$1" && shift ;; + --warn) addSbt warn && addResidual "$1" && shift ;; + --error) addSbt error && addResidual "$1" && shift ;; + *) addResidual "$1" && shift ;; esac done } @@ -375,7 +425,7 @@ set_sbt_version setTraceLevel() { case "$sbt_version" in "0.7."* | "0.10."* | "0.11."* ) echoerr "Cannot set trace level in sbt version $sbt_version" ;; - *) addSbt "set every traceLevel := $trace_level" ;; + *) setThisBuild traceLevel $trace_level ;; esac } @@ -442,16 +492,52 @@ elif [[ -n "$JVM_OPTS" && ! ("$JVM_OPTS" =~ ^@.*) ]]; then extra_jvm_opts=( $JVM_OPTS ) else vlog "Using default jvm options" - extra_jvm_opts=( $default_jvm_opts ) + extra_jvm_opts=( $(default_jvm_opts) ) fi # traceLevel is 0.12+ [[ -n "$trace_level" ]] && setTraceLevel +main () { + execRunner "$java_cmd" \ + "${extra_jvm_opts[@]}" \ + "${java_args[@]}" \ + -jar "$sbt_jar" \ + "${sbt_commands[@]}" \ + "${residual_args[@]}" +} + +# sbt inserts this string on certain lines when formatting is enabled: +# val OverwriteLine = "\r\u001BM\u001B[2K" +# ...in order not to spam the console with a million "Resolving" lines. 
+# Unfortunately that makes it that much harder to work with when +# we're not going to print those lines anyway. We strip that bit of +# line noise, but leave the other codes to preserve color. +mainFiltered () { + local ansiOverwrite='\r\x1BM\x1B[2K' + local excludeRegex=$(egrep -v '^#|^$' ~/.sbtignore | paste -sd'|' -) + + echoLine () { + local line="$1" + local line1="$(echo "$line" | sed -r 's/\r\x1BM\x1B\[2K//g')" # This strips the OverwriteLine code. + local line2="$(echo "$line1" | sed -r 's/\x1B\[[0-9;]*[JKmsu]//g')" # This strips all codes - we test regexes against this. + + if [[ $line2 =~ $excludeRegex ]]; then + [[ -n $debugUs ]] && echo "[X] $line1" + else + [[ -n $debugUs ]] && echo " $line1" || echo "$line1" + fi + } + + echoLine "Starting sbt with output filtering enabled." + main | while read -r line; do echoLine "$line"; done +} + +# Only filter if there's a filter file and we don't see a known interactive command. +# Obviously this is super ad hoc but I don't know how to improve on it. Testing whether +# stdin is a terminal is useless because most of my use cases for this filtering are +# exactly when I'm at a terminal, running sbt non-interactively. +shouldFilter () { [[ -f ~/.sbtignore ]] && ! egrep -q '\b(shell|console|consoleProject)\b' <<<"${residual_args[@]}"; } + # run sbt -execRunner "$java_cmd" \ - "${extra_jvm_opts[@]}" \ - "${java_args[@]}" \ - -jar "$sbt_jar" \ - "${sbt_commands[@]}" \ - "${residual_args[@]}" \ No newline at end of file +if shouldFilter; then mainFiltered; else main; fi \ No newline at end of file diff --git a/scalding-args/src/main/scala/com/twitter/scalding/Args.scala b/scalding-args/src/main/scala/com/twitter/scalding/Args.scala index 98f09c941b..6d71596b2a 100644 --- a/scalding-args/src/main/scala/com/twitter/scalding/Args.scala +++ b/scalding-args/src/main/scala/com/twitter/scalding/Args.scala @@ -112,6 +112,8 @@ class Args(val m: Map[String, List[String]]) extends java.io.Serializable { } } + override def hashCode(): Int = m.hashCode() + /** * Equivalent to .optional(key).getOrElse(default) */ diff --git a/scalding-benchmarks/src/test/scala/com/twitter/scalding/Serialization.scala b/scalding-benchmarks/src/test/scala/com/twitter/scalding/Serialization.scala new file mode 100644 index 0000000000..de4ef75323 --- /dev/null +++ b/scalding-benchmarks/src/test/scala/com/twitter/scalding/Serialization.scala @@ -0,0 +1,306 @@ +package com.twitter.scalding.benchmarks + +import com.twitter.chill.KryoPool +import com.twitter.scalding.serialization._ +import java.io.ByteArrayInputStream +import org.scalacheck.{ Gen => scGen, Arbitrary } // We use scalacheck Gens to generate random scalameter gens. 
+import org.scalameter.api._ +import scala.collection.generic.CanBuildFrom +import scala.language.experimental.macros + +trait LowerPriorityImplicit { + implicit def ordBuf[T]: OrderedSerialization[T] = macro com.twitter.scalding.macros.impl.OrderedSerializationProviderImpl[T] +} + +object SerializationBenchmark extends PerformanceTest.Quickbenchmark with LowerPriorityImplicit { + import JavaStreamEnrichments._ + + val sizes = Gen.range("size")(300000, 1500000, 300000) + val smallSizes = Gen.range("size")(30000, 150000, 30000) + + /** + * This tends to create ascii strings + */ + def asciiStringGen: scGen[String] = scGen.parameterized { p => + val thisSize = p.rng.nextInt(p.size + 1) + scGen.const(new String(Array.fill(thisSize)(p.rng.nextInt(128).toByte))) + } + def charStringGen: scGen[String] = + scGen.listOf(scGen.choose(0.toChar, Char.MaxValue)).map(_.mkString) + + // Biases to ascii 80% of the time + def stringGen: scGen[String] = scGen.frequency((4, asciiStringGen), (1, charStringGen)) + + implicit def stringArb: Arbitrary[String] = Arbitrary(stringGen) + + def collection[T, C[_]](size: Gen[Int])(implicit arbT: Arbitrary[T], cbf: CanBuildFrom[Nothing, T, C[T]]): Gen[C[T]] = + collection[T, C](size, arbT.arbitrary)(cbf) + + def collection[T, C[_]](size: Gen[Int], item: scGen[T])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]): Gen[C[T]] = + size.map { s => + val builder = cbf() + builder.sizeHint(s) + // Initialize a fixed random number generator + val rng = new scala.util.Random("scalding".hashCode) + val p = scGen.Parameters.default.withRng(rng) + + def get(attempt: Int): T = + if (attempt > 1000) sys.error("Failed to generate after 100 tries") + else { + item(p) match { + case None => get(attempt + 1) + case Some(t) => t + } + } + + (0 until s).foreach { _ => + builder += get(0) + } + builder.result() + } + + def roundTrip[T: Serialization](ts: Iterator[T]): Unit = + ts.map { t => + Serialization.fromBytes(Serialization.toBytes(t)).get + }.foreach(_ => ()) + + def kryoRoundTrip[T](k: KryoPool, ts: Iterator[T]): Unit = + ts.map { t => k.fromBytes(k.toBytesWithClass(t)) } + .foreach(_ => ()) + + def toArrayOrd[T](t: OrderedSerialization[T]): Ordering[Array[Byte]] = new Ordering[Array[Byte]] { + def compare(a: Array[Byte], b: Array[Byte]) = { + t.compareBinary(new ByteArrayInputStream(a), new ByteArrayInputStream(b)).unsafeToInt + } + } + def toArrayOrd[T](k: KryoPool, ord: Ordering[T]): Ordering[Array[Byte]] = new Ordering[Array[Byte]] { + def compare(a: Array[Byte], b: Array[Byte]) = + ord.compare(k.fromBytes(a).asInstanceOf[T], + k.fromBytes(b).asInstanceOf[T]) + } + + val longArrayByte: Gen[Array[Byte]] = + collection[Byte, Array](sizes.map(s => (s / 8) * 8)) + + // This is here to make sure the compiler cannot optimize away reads + var effectInt: Int = 0 + var effectLong: Long = 0L + + performance of "Serialization" in { + measure method "JavaStreamEnrichments.readInt" in { + using(longArrayByte) in { a => + val length = a.length + val is = new ByteArrayInputStream(a) + var ints = length / 4 + while (ints > 0) { + effectInt ^= is.readInt + ints -= 1 + } + } + } + measure method "JavaStreamEnrichments.readLong" in { + using(longArrayByte) in { a => + val length = a.length + val is = new ByteArrayInputStream(a) + var longs = length / 8 + while (longs > 0) { + effectLong ^= is.readLong + longs -= 1 + } + } + } + measure method "UnsignedComparisons.unsignedLongCompare" in { + using(collection[Long, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + 
effectInt ^= UnsignedComparisons.unsignedLongCompare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "normal long compare" in { + using(collection[Long, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + effectInt ^= java.lang.Long.compare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "UnsignedComparisons.unsignedInt" in { + using(collection[Int, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + effectInt ^= UnsignedComparisons.unsignedIntCompare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "normal int compare" in { + using(collection[Int, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + effectInt ^= java.lang.Integer.compare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "UnsignedComparisons.unsignedShort" in { + using(collection[Short, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + effectInt ^= UnsignedComparisons.unsignedShortCompare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "normal short compare" in { + using(collection[Short, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + effectInt ^= java.lang.Short.compare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "UnsignedComparisons.unsignedByte" in { + using(collection[Byte, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + effectInt ^= UnsignedComparisons.unsignedByteCompare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "normal byte compare" in { + using(collection[Byte, Array](sizes)) in { a => + val max = a.length - 1 + var pos = 0 + while (pos < max) { + effectInt ^= java.lang.Byte.compare(a(pos), a(pos + 1)) + pos += 2 + } + } + } + measure method "typeclass: Int" in { + using(collection[Int, List](sizes)) in { l => roundTrip(l.iterator) } + } + measure method "kryo: Int" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + using(collection[Int, List](sizes)) in { l => kryoRoundTrip(kryo, l.iterator) } + } + measure method "typeclass: String" in { + using(collection[String, List](smallSizes)) in { l => roundTrip(l.iterator) } + } + measure method "kryo: String" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + using(collection[String, List](smallSizes)) in { l => kryoRoundTrip(kryo, l.iterator) } + } + measure method "typeclass: (Int, (Long, String))" in { + using(collection[(Int, (Long, String)), List](smallSizes)) in { l => roundTrip(l.iterator) } + } + measure method "kryo: (Int, (Long, String))" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + using(collection[(Int, (Long, String)), List](smallSizes)) in { l => kryoRoundTrip(kryo, l.iterator) } + } + measure method "typeclass: (Int, Long, Short)" in { + using(collection[(Int, Long, Short), List](smallSizes)) in { l => roundTrip(l.iterator) } + } + measure method "kryo: (Int, Long, Short)" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + using(collection[(Int, Long, Short), List](smallSizes)) in { l => kryoRoundTrip(kryo, l.iterator) } + } + measure method "sort typeclass: Int" in { + val ordSer = implicitly[OrderedSerialization[Int]] + using(collection[Int, List](smallSizes) + .map { items => + items.map { 
Serialization.toBytes(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(ordSer)) } + } + measure method "sort kryo: Int" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + val ord = implicitly[Ordering[Int]] + using(collection[Int, List](smallSizes) + .map { items => + items.map { kryo.toBytesWithClass(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(kryo, ord)) } + } + measure method "sort typeclass: Long" in { + val ordSer = implicitly[OrderedSerialization[Long]] + using(collection[Long, List](smallSizes) + .map { items => + items.map { Serialization.toBytes(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(ordSer)) } + } + measure method "sort kryo: Long" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + val ord = implicitly[Ordering[Long]] + using(collection[Long, List](smallSizes) + .map { items => + items.map { kryo.toBytesWithClass(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(kryo, ord)) } + } + measure method "sort typeclass: String" in { + val ordSer = implicitly[OrderedSerialization[String]] + using(collection[String, List](smallSizes) + .map { items => + items.map { Serialization.toBytes(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(ordSer)) } + } + measure method "sort kryo: String" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + val ord = implicitly[Ordering[String]] + using(collection[String, List](smallSizes) + .map { items => + items.map { kryo.toBytesWithClass(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(kryo, ord)) } + } + + measure method "sort typeclass: (Int, (Long, String))" in { + val ordSer = implicitly[OrderedSerialization[(Int, (Long, String))]] + using(collection[(Int, (Long, String)), List](smallSizes) + .map { items => + items.map { Serialization.toBytes(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(ordSer)) } + } + measure method "sort kryo: (Int, (Long, String))" in { + val kryo = KryoPool.withByteArrayOutputStream(1, + com.twitter.scalding.Config.default.getKryo.get) + + val ord = implicitly[Ordering[(Int, (Long, String))]] + using(collection[(Int, (Long, String)), List](smallSizes) + .map { items => + items.map { kryo.toBytesWithClass(_) }.toArray + }) in { ary => java.util.Arrays.sort(ary, toArrayOrd(kryo, ord)) } + } + + /** + * TODO: + * 1) simple case class + * 2) case class with some nesting and collections + * 3) sorting of an Array[Array[Byte]] using OrderedSerialization vs Array[T] + * 4) fastest binary sorting for strings (byte-by-byte, longs, etc...) 
+ */ + } +} diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/extensions/Checkpoint.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/extensions/Checkpoint.scala index d27c85be57..a6d7666941 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/extensions/Checkpoint.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/extensions/Checkpoint.scala @@ -19,12 +19,9 @@ package com.twitter.scalding.commons.extensions import com.twitter.scalding._ import com.twitter.scalding.Dsl._ -import java.io.File import cascading.flow.FlowDef import cascading.pipe.Pipe -import cascading.tuple.{ Fields, TupleEntry } -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{ FileSystem, Path } +import cascading.tuple.Fields import org.slf4j.{ Logger, LoggerFactory => LogManager } /** diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/BinaryConverters.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/BinaryConverters.scala new file mode 100644 index 0000000000..24beaa7e9f --- /dev/null +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/BinaryConverters.scala @@ -0,0 +1,56 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.commons.source + +import com.twitter.elephantbird.mapreduce.io.BinaryConverter +import com.twitter.scrooge.{ BinaryThriftStructSerializer, ThriftStructCodec, ThriftStruct } +import scala.reflect.ClassTag +import scala.util.Try + +/* + * Common BinaryConverters to be used with GenericSource / GenericScheme. 
+ */ + +case object IdentityBinaryConverter extends BinaryConverter[Array[Byte]] { + override def fromBytes(messageBuffer: Array[Byte]) = messageBuffer + override def toBytes(message: Array[Byte]) = message +} + +object ScroogeBinaryConverter { + + // codec code borrowed from chill's ScroogeThriftStructSerializer class + private[this] def codecForNormal[T <: ThriftStruct](thriftStructClass: Class[T]): Try[ThriftStructCodec[T]] = + Try(Class.forName(thriftStructClass.getName + "$").getField("MODULE$").get(null)) + .map(_.asInstanceOf[ThriftStructCodec[T]]) + + private[this] def codecForUnion[T <: ThriftStruct](maybeUnion: Class[T]): Try[ThriftStructCodec[T]] = + Try(Class.forName(maybeUnion.getName.reverse.dropWhile(_ != '$').reverse).getField("MODULE$").get(null)) + .map(_.asInstanceOf[ThriftStructCodec[T]]) + + def apply[T <: ThriftStruct: ClassTag]: BinaryConverter[T] = { + val ct = implicitly[ClassTag[T]] + new BinaryConverter[T] { + val serializer = BinaryThriftStructSerializer[T] { + val clazz = ct.runtimeClass.asInstanceOf[Class[T]] + codecForNormal[T](clazz).orElse(codecForUnion[T](clazz)).get + } + override def toBytes(struct: T) = serializer.toBytes(struct) + override def fromBytes(bytes: Array[Byte]): T = serializer.fromBytes(bytes) + } + } +} + diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala index c7bee55b4e..a33cb85a55 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala @@ -19,8 +19,6 @@ package com.twitter.scalding.commons.source import com.google.protobuf.Message import com.twitter.bijection.Injection import com.twitter.chill.Externalizer -import com.twitter.elephantbird.cascading2.scheme._ -import com.twitter.elephantbird.util.{ ThriftUtils, TypeRef } import com.twitter.scalding._ import com.twitter.scalding.source._ diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/FixedPathSources.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/FixedPathSources.scala index a7aad7ee18..eaca1d9863 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/FixedPathSources.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/FixedPathSources.scala @@ -18,8 +18,6 @@ package com.twitter.scalding.commons.source import com.google.protobuf.Message import com.twitter.scalding._ -import com.twitter.scalding.Dsl._ -import java.io.Serializable import org.apache.thrift.TBase abstract class FixedPathLzoThrift[T <: TBase[_, _]: Manifest](path: String*) diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/HourlySources.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/HourlySources.scala index 64d29df656..568dce0609 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/HourlySources.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/HourlySources.scala @@ -21,9 +21,7 @@ import com.google.protobuf.Message import com.twitter.bijection.Injection import com.twitter.chill.Externalizer import com.twitter.scalding._ -import com.twitter.scalding.Dsl._ import com.twitter.scalding.source._ -import java.io.Serializable import org.apache.thrift.TBase abstract class HourlySuffixLzoCodec[T](prefix: String, dateRange: 
DateRange)(implicit @transient suppliedInjection: Injection[T, Array[Byte]]) diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LongThriftTransformer.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LongThriftTransformer.scala index 80d25fe45f..dda71a26a4 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LongThriftTransformer.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LongThriftTransformer.scala @@ -21,7 +21,6 @@ import cascading.tuple.Fields import com.twitter.elephantbird.mapreduce.io.ThriftWritable import com.twitter.elephantbird.util.{ ThriftUtils, TypeRef } import com.twitter.scalding._ -import com.twitter.scalding.Dsl._ import org.apache.hadoop.io.{ LongWritable, Writable } import org.apache.thrift.TBase diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoCodecSource.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoCodecSource.scala index c97a612644..f7b1e9b0c9 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoCodecSource.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoCodecSource.scala @@ -17,7 +17,6 @@ limitations under the License. package com.twitter.scalding.commons.source import com.twitter.chill.Externalizer -import com.twitter.scalding._ import com.twitter.bijection.Injection /** @@ -29,7 +28,7 @@ object LzoCodecSource { def apply[T](paths: String*)(implicit passedInjection: Injection[T, Array[Byte]]) = new LzoCodec[T] { val hdfsPaths = paths - val localPath = { assert(paths.size == 1, "Cannot use multiple input files on local mode"); paths(0) } + val localPaths = paths val boxed = Externalizer(passedInjection) override def injection = boxed.get } diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoGenericScheme.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoGenericScheme.scala new file mode 100644 index 0000000000..aad3b696c3 --- /dev/null +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoGenericScheme.scala @@ -0,0 +1,136 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package com.twitter.scalding.commons.source + +import scala.reflect.ClassTag + +import com.twitter.bijection._ +import com.twitter.chill.Externalizer +import com.twitter.elephantbird.cascading2.scheme.LzoBinaryScheme +import com.twitter.elephantbird.mapreduce.input.combine.DelegateCombineFileInputFormat +import com.twitter.elephantbird.mapreduce.io.{ BinaryConverter, GenericWritable } +import com.twitter.elephantbird.mapreduce.input.{ BinaryConverterProvider, MultiInputFormat } +import com.twitter.elephantbird.mapreduce.output.LzoGenericBlockOutputFormat +import com.twitter.elephantbird.mapred.output.DeprecatedOutputFormatWrapper + +import org.apache.hadoop.mapred.{ JobConf, OutputCollector, RecordReader } +import org.apache.hadoop.conf.Configuration + +import cascading.tap.Tap +import cascading.flow.FlowProcess + +/** + * Serializes BinaryConverters to JobConf. + */ +private[source] object ExternalizerSerializer { + def inj[T]: Injection[Externalizer[T], String] = { + import com.twitter.bijection.Inversion.attemptWhen + import com.twitter.bijection.codec.Base64 + + implicit val baseInj = JavaSerializationInjection[Externalizer[T]] + + implicit val unwrap: Injection[GZippedBase64String, String] = + // this does not catch cases where it's Base64 but not compressed + // but the decompression injection will, so it's safe to do this + new AbstractInjection[GZippedBase64String, String] { + override def apply(gzbs: GZippedBase64String) = gzbs.str + override def invert(str: String) = attemptWhen(str)(Base64.isBase64)(GZippedBase64String(_)) + } + + Injection.connect[Externalizer[T], Array[Byte], GZippedBase64String, String] + } +} + +private[source] object ConfigBinaryConverterProvider { + val ProviderConfKey = "com.twitter.scalding.lzo.converter.provider" +} + +/** + * Provides BinaryConverter serialized in JobConf. + */ +private[source] class ConfigBinaryConverterProvider[M] extends BinaryConverterProvider[M] { + import ConfigBinaryConverterProvider._ + + private[this] var cached: Option[(String, BinaryConverter[M])] = None + + override def getConverter(conf: Configuration): BinaryConverter[M] = { + val data = conf.get(ProviderConfKey) + require(data != null, s"$ProviderConfKey is not set in configuration") + cached match { + case Some((d, conv)) if d == data => conv + case _ => + val extern = ExternalizerSerializer.inj.invert(data).get + val conv = extern.get.asInstanceOf[BinaryConverter[M]] + cached = Some((data, conv)) + conv + } + } +} + +object LzoGenericScheme { + def apply[M: ClassTag](conv: BinaryConverter[M]): LzoGenericScheme[M] = + new LzoGenericScheme(conv, implicitly[ClassTag[M]].runtimeClass.asInstanceOf[Class[M]]) + + def apply[M](conv: BinaryConverter[M], clazz: Class[M]): LzoGenericScheme[M] = + new LzoGenericScheme(conv, clazz) +} + +/** + * Generic scheme for data stored as lzo-compressed protobuf messages. + * Serialization is performed using the supplied BinaryConverter. 
+ */ +class LzoGenericScheme[M](@transient conv: BinaryConverter[M], clazz: Class[M]) extends LzoBinaryScheme[M, GenericWritable[M]] { + + override protected def prepareBinaryWritable(): GenericWritable[M] = + new GenericWritable(conv) + + override def sourceConfInit(fp: FlowProcess[JobConf], + tap: Tap[JobConf, RecordReader[_, _], OutputCollector[_, _]], + conf: JobConf): Unit = { + + val extern = Externalizer(conv) + try { + ExternalizerSerializer.inj.invert(ExternalizerSerializer.inj(extern)).get + } catch { + case e: Exception => throw new RuntimeException("Unable to roundtrip the BinaryConverter in the Externalizer.", e) + } + + conf.set(ConfigBinaryConverterProvider.ProviderConfKey, ExternalizerSerializer.inj(extern)) + + MultiInputFormat.setClassConf(clazz, conf) + MultiInputFormat.setGenericConverterClassConf(classOf[ConfigBinaryConverterProvider[_]], conf) + + DelegateCombineFileInputFormat.setDelegateInputFormat(conf, classOf[MultiInputFormat[_]]) + } + + override def sinkConfInit(fp: FlowProcess[JobConf], + tap: Tap[JobConf, RecordReader[_, _], OutputCollector[_, _]], + conf: JobConf): Unit = { + val extern = Externalizer(conv) + try { + ExternalizerSerializer.inj.invert(ExternalizerSerializer.inj(extern)).get + } catch { + case e: Exception => throw new RuntimeException("Unable to roundtrip the BinaryConverter in the Externalizer.", e) + } + + LzoGenericBlockOutputFormat.setClassConf(clazz, conf) + conf.set(ConfigBinaryConverterProvider.ProviderConfKey, ExternalizerSerializer.inj(extern)) + LzoGenericBlockOutputFormat.setGenericConverterClassConf(classOf[ConfigBinaryConverterProvider[_]], conf) + DeprecatedOutputFormatWrapper.setOutputFormat(classOf[LzoGenericBlockOutputFormat[_]], conf) + } +} + diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoGenericSource.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoGenericSource.scala new file mode 100644 index 0000000000..c75448b90e --- /dev/null +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoGenericSource.scala @@ -0,0 +1,44 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.commons.source + +import scala.reflect.ClassTag + +import com.twitter.elephantbird.mapreduce.io.BinaryConverter +import com.twitter.scalding._ + +import cascading.scheme.Scheme + +/** + * Generic source with an underlying GenericScheme that uses the supplied BinaryConverter. 
+ */ +abstract class LzoGenericSource[T] extends FileSource with SingleMappable[T] with TypedSink[T] with LocalTapSource { + def clazz: Class[T] + def conv: BinaryConverter[T] + override def setter[U <: T] = TupleSetter.asSubSetter[T, U](TupleSetter.singleSetter[T]) + override def hdfsScheme = HadoopSchemeInstance(LzoGenericScheme[T](conv, clazz).asInstanceOf[Scheme[_, _, _, _, _]]) +} + +object LzoGenericSource { + def apply[T](passedConv: BinaryConverter[T], passedClass: Class[T], paths: String*) = + new LzoGenericSource[T] { + override val conv: BinaryConverter[T] = passedConv + override val clazz = passedClass + override val hdfsPaths = paths + override val localPaths = paths + } +} diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoTraits.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoTraits.scala index d69cfcb1ac..eeb28fc929 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoTraits.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/LzoTraits.scala @@ -16,10 +16,7 @@ limitations under the License. package com.twitter.scalding.commons.source -import collection.mutable.ListBuffer - import cascading.pipe.Pipe -import cascading.scheme.local.{ TextDelimited => CLTextDelimited, TextLine => CLTextLine } import cascading.scheme.Scheme import org.apache.thrift.TBase diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/PailSource.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/PailSource.scala index 3f700b3573..7b60eedc77 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/PailSource.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/PailSource.scala @@ -19,15 +19,11 @@ package com.twitter.scalding.commons.source import scala.reflect.ClassTag import com.backtype.cascading.tap.PailTap -import com.backtype.hadoop.pail.{ Pail, PailStructure } -import cascading.pipe.Pipe -import cascading.scheme.Scheme +import com.backtype.hadoop.pail.PailStructure import cascading.tap.Tap import com.twitter.bijection.Injection -import com.twitter.chill.Externalizer import com.twitter.scalding._ import java.util.{ List => JList } -import org.apache.hadoop.mapred.{ JobConf, OutputCollector, RecordReader } import scala.collection.JavaConverters._ /** diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/TsvWithHeader.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/TsvWithHeader.scala index 46dd53744a..7d369fd4aa 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/TsvWithHeader.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/TsvWithHeader.scala @@ -22,9 +22,7 @@ import cascading.tuple.Fields import com.google.common.base.Charsets import com.google.common.io.Files import com.twitter.scalding._ -import com.twitter.scalding.Dsl._ -import java.io.{ BufferedWriter, File, FileOutputStream, IOException, OutputStreamWriter, Serializable } -import org.apache.hadoop.conf.Configuration +import java.io.{ BufferedWriter, File, FileOutputStream, IOException, OutputStreamWriter } import org.apache.hadoop.fs.{ FileSystem, Path } /** diff --git a/scalding-commons/src/test/scala/com/twitter/scalding/commons/VersionedKeyValSourceTest.scala b/scalding-commons/src/test/scala/com/twitter/scalding/commons/VersionedKeyValSourceTest.scala index 5b51fb4f72..5a41d19bdf 100644 --- 
a/scalding-commons/src/test/scala/com/twitter/scalding/commons/VersionedKeyValSourceTest.scala +++ b/scalding-commons/src/test/scala/com/twitter/scalding/commons/VersionedKeyValSourceTest.scala @@ -23,11 +23,8 @@ import com.backtype.hadoop.datastores.VersionedStore import org.apache.hadoop.mapred.JobConf // Use the scalacheck generators -import org.scalacheck.Gen import scala.collection.mutable.Buffer -import TDsl._ - class TypedWriteIncrementalJob(args: Args) extends Job(args) { import RichPipeEx._ val pipe = TypedPipe.from(TypedTsv[Int]("input")) diff --git a/scalding-commons/src/test/scala/com/twitter/scalding/commons/source/LzoGenericSourceSpec.scala b/scalding-commons/src/test/scala/com/twitter/scalding/commons/source/LzoGenericSourceSpec.scala new file mode 100644 index 0000000000..6b87a9edc5 --- /dev/null +++ b/scalding-commons/src/test/scala/com/twitter/scalding/commons/source/LzoGenericSourceSpec.scala @@ -0,0 +1,30 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.commons.source + +import com.twitter.bijection.JavaSerializationInjection +import org.scalatest.{ Matchers, WordSpec } +import scala.util.Success + +class LzoGenericSourceSpec extends WordSpec with Matchers { + "LzoGenericScheme" should { + "be serializable" in { + val scheme = LzoGenericScheme[Array[Byte]](IdentityBinaryConverter) + val inj = JavaSerializationInjection[LzoGenericScheme[Array[Byte]]] + inj.invert(inj.apply(scheme)) shouldBe Success(scheme) + } + } +} diff --git a/scalding-core/src/main/scala/com/twitter/package.scala b/scalding-core/src/main/scala/com/twitter/package.scala index f712d2d353..a9b0009608 100644 --- a/scalding-core/src/main/scala/com/twitter/package.scala +++ b/scalding-core/src/main/scala/com/twitter/package.scala @@ -34,7 +34,7 @@ package object scalding { /** * Make sure this is in sync with version.sbt */ - val scaldingVersion: String = "0.13.1" + val scaldingVersion: String = "0.15.0" object RichPathFilter { implicit def toRichPathFilter(f: PathFilter) = new RichPathFilter(f) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/CascadingTokenUpdater.scala b/scalding-core/src/main/scala/com/twitter/scalding/CascadingTokenUpdater.scala new file mode 100644 index 0000000000..f8f0a5da82 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/CascadingTokenUpdater.scala @@ -0,0 +1,64 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +package com.twitter.scalding + +object CascadingTokenUpdater { + private final val lowestAllowed = 128 // cascading rules + + // Takes a cascading string of tokens and turns it into a map + // from token index to class + def parseTokens(tokClass: String): Map[Int, String] = + if (tokClass == null || tokClass.isEmpty) + Map[Int, String]() + else + tokClass + .split(",") + .map(_.trim) + .filter(_.size > 1) + .toIterator + .map(_.split("=")) + .filter(_.size == 2) + .map { ary => (ary(0).toInt, ary(1)) } + .toMap + + // does the inverse of the previous function, given a Map of index to class + // return the cascading token format for it + private def toksToString(m: Map[Int, String]): String = + m.map { case (tok, clazz) => s"$tok=$clazz" }.mkString(",") + + // Given the map of already assigned tokens, what is the next available one + private def firstAvailableToken(m: Map[Int, String]): Int = + if (m.isEmpty) lowestAllowed + else scala.math.max(m.keys.max + 1, lowestAllowed) + + // Given the first free token spot + // assign each of the class names given to all the subsequent + // positions + private def assignTokens(first: Int, names: Iterable[String]): Map[Int, String] = + names.foldLeft((first, Map[Int, String]())) { (idMap, clz) => + val (id, m) = idMap + (id + 1, m + (id -> clz)) + }._2 + + def update(config: Config, clazzes: Set[Class[_]]): Config = { + val toks = config.getCascadingSerializationTokens + // We don't want to assign tokens to classes already in the map + val newClasses: Iterable[String] = clazzes.map { _.getName } -- toks.values + + config + (Config.CascadingSerializationTokens -> toksToString(toks ++ assignTokens(firstAvailableToken(toks), newClasses))) + } + +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/CoGroupBuilder.scala b/scalding-core/src/main/scala/com/twitter/scalding/CoGroupBuilder.scala index 0610596d0d..aa8ee5d786 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/CoGroupBuilder.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/CoGroupBuilder.scala @@ -44,7 +44,7 @@ class CoGroupBuilder(groupFields: Fields, joinMode: JoinMode) extends GroupBuild val pipes = (pipe :: coGroups.map{ _._2 }).map{ RichPipe.assignName(_) }.toArray val joinModes = (joinMode :: coGroups.map{ _._3 }).map{ _.booleanValue }.toArray val mixedJoiner = new MixedJoin(joinModes) - val cg: Pipe = new CoGroup(pipes, fields, null, mixedJoiner) + val cg: Pipe = new CoGroup(pipes, fields, null, WrappedJoiner(mixedJoiner)) overrideReducers(cg) evs.foldRight(cg)((op: Pipe => Every, p) => op(p)) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Config.scala b/scalding-core/src/main/scala/com/twitter/scalding/Config.scala index 1a286d5085..8fc18166b3 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Config.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Config.scala @@ -18,19 +18,14 @@ package com.twitter.scalding import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.serializer.{ Serialization => HSerialization } import com.twitter.chill.KryoInstantiator -import com.twitter.chill.config.{ ScalaMapConfig, ScalaAnyRefMapConfig, ConfiguredInstantiator } -import com.twitter.scalding.reducer_estimation.ReducerEstimator +import com.twitter.chill.config.{ ScalaMapConfig, ConfiguredInstantiator } import cascading.pipe.assembly.AggregateBy -import cascading.flow.{ FlowStepStrategy, FlowProps } +import cascading.flow.FlowProps import cascading.property.AppProps import
cascading.tuple.collect.SpillableProps import java.security.MessageDigest -import java.util.UUID - -import org.apache.hadoop.mapred.JobConf -import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ import scala.util.{ Failure, Success, Try } @@ -110,6 +105,44 @@ trait Config { def setMapSideAggregationThreshold(count: Int): Config = this + (AggregateBy.AGGREGATE_BY_THRESHOLD -> count.toString) + /** + * Set this configuration option to require all grouping/cogrouping + * to use OrderedSerialization + */ + def setRequireOrderedSerialization(b: Boolean): Config = + this + (ScaldingRequireOrderedSerialization -> (b.toString)) + + def getRequireOrderedSerialization: Boolean = + get(ScaldingRequireOrderedSerialization) + .map(_.toBoolean) + .getOrElse(false) + + def getCascadingSerializationTokens: Map[Int, String] = + get(Config.CascadingSerializationTokens) + .map(CascadingTokenUpdater.parseTokens) + .getOrElse(Map.empty[Int, String]) + + /** + * This function gets the set of classes that have been registered to Kryo. + * They may or may not be used in this job, but Cascading might want to be made aware + * that these classes exist + */ + def getKryoRegisteredClasses: Set[Class[_]] = { + // Get an instance of the Kryo serializer (which is populated with registrations) + getKryo.map { kryo => + val cr = kryo.newKryo.getClassResolver + + @annotation.tailrec + def kryoClasses(idx: Int, acc: Set[Class[_]]): Set[Class[_]] = + Option(cr.getRegistration(idx)) match { + case Some(reg) => kryoClasses(idx + 1, acc + reg.getType) + case None => acc // The first null is the end of the line + } + + kryoClasses(0, Set[Class[_]]()) + }.getOrElse(Set()) + } + /* * Hadoop and Cascading serialization needs to be first, and the Kryo serialization * needs to be last and this method handles this for you: @@ -127,7 +160,8 @@ trait Config { // Hadoop and Cascading should come first val first: Seq[Class[_ <: HSerialization[_]]] = Seq(classOf[org.apache.hadoop.io.serializer.WritableSerialization], - classOf[cascading.tuple.hadoop.TupleSerialization]) + classOf[cascading.tuple.hadoop.TupleSerialization], + classOf[serialization.WrappedSerialization[_]]) // this must come last val last: Seq[Class[_ <: HSerialization[_]]] = Seq(classOf[com.twitter.chill.hadoop.KryoSerialization]) val required = (first ++ last).toSet[AnyRef] // Class is invariant, but we use it as a function @@ -142,7 +176,13 @@ trait Config { case Left((bootstrap, inst)) => ConfiguredInstantiator.setSerialized(chillConf, bootstrap, inst) case Right(refl) => ConfiguredInstantiator.setReflect(chillConf, refl) } - Config(chillConf.toMap + hadoopKV) + val withKryo = Config(chillConf.toMap + hadoopKV) + + val kryoClasses = withKryo.getKryoRegisteredClasses + .filterNot(_.isPrimitive) // Cascading handles primitives and arrays + .filterNot(_.isArray) + + withKryo.addCascadingClassSerializationTokens(kryoClasses) } /* @@ -169,8 +209,19 @@ trait Config { // This is setting a property for cascading/driven (AppProps.APP_FRAMEWORKS -> ("scalding:" + scaldingVersion.toString))) - def getUniqueId: Option[UniqueID] = - get(UniqueID.UNIQUE_JOB_ID).map(UniqueID(_)) + def getUniqueIds: Set[UniqueID] = + get(UniqueID.UNIQUE_JOB_ID) + .map { str => str.split(",").toSet[String].map(UniqueID(_)) } + .getOrElse(Set.empty) + + /** + * The serialization of your data will be smaller if any classes passed between tasks in your job + * are listed here. 
Without this, strings are used to write the types IN EACH RECORD, which + * compression probably takes care of, but compression acts AFTER the data is serialized into + * buffers and spilling has been triggered. + */ + def addCascadingClassSerializationTokens(clazzes: Set[Class[_]]): Config = + CascadingTokenUpdater.update(this, clazzes) /* * This is *required* if you are using counters. You must use @@ -179,7 +230,7 @@ trait Config { def addUniqueId(u: UniqueID): Config = update(UniqueID.UNIQUE_JOB_ID) { case None => (Some(u.get), ()) - case Some(str) => (Some((str.split(",").toSet + u.get).mkString(",")), ()) + case Some(str) => (Some((StringUtility.fastSplit(str, ",").toSet + u.get).mkString(",")), ()) }._2 /** @@ -191,7 +242,7 @@ trait Config { val uid = UniqueID.getRandom (Some(uid.get), uid) case s @ Some(str) => - (s, UniqueID(str.split(",").head)) + (s, UniqueID(StringUtility.fastSplit(str, ",").head)) } /* @@ -226,8 +277,8 @@ trait Config { */ def addReducerEstimator(clsName: String): Config = update(Config.ReducerEstimators) { - case None => Some(clsName) -> () - case Some(lst) => Some(clsName + "," + lst) -> () + case None => (Some(clsName), ()) + case Some(lst) => (Some(s"$clsName,$lst"), ()) }._2 /** Set the entire list of reducer estimators (overriding the existing list) */ @@ -252,6 +303,7 @@ trait Config { object Config { val CascadingAppName: String = "cascading.app.name" val CascadingAppId: String = "cascading.app.id" + val CascadingSerializationTokens = "cascading.serialization.tokens" val IoSerializationsKey: String = "io.serializations" val ScaldingFlowClassName: String = "scalding.flow.class.name" val ScaldingFlowClassSignature: String = "scalding.flow.class.signature" @@ -259,6 +311,7 @@ object Config { val ScaldingJobArgs: String = "scalding.job.args" val ScaldingVersion: String = "scalding.version" val HRavenHistoryUserName: String = "hraven.history.user.name" + val ScaldingRequireOrderedSerialization: String = "scalding.require.orderedserialization" /** * Parameter that actually controls the number of reduce tasks. @@ -272,6 +325,9 @@ object Config { /** Whether estimator should override manually-specified reducers. 
*/ val ReducerEstimatorOverride = "scalding.reducer.estimator.override" + /** Whether the number of reducers has been set explicitly using a `withReducers` */ + val WithReducersSetExplicitly = "scalding.with.reducers.set.explicitly" + val empty: Config = Config(Map.empty) /* diff --git a/scalding-core/src/main/scala/com/twitter/scalding/CumulativeSum.scala b/scalding-core/src/main/scala/com/twitter/scalding/CumulativeSum.scala index f0555b650e..2513854721 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/CumulativeSum.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/CumulativeSum.scala @@ -37,11 +37,13 @@ object CumulativeSum { val pipe: TypedPipe[(K, (U, V))]) { /** Takes a sortable field and a monoid and returns the cumulative sum of that monoid **/ def cumulativeSum( - implicit sg: Semigroup[V], ordU: Ordering[U], ordK: Ordering[K]): SortedGrouped[K, (U, V)] = { + implicit sg: Semigroup[V], + ordU: Ordering[U], + ordK: Ordering[K]): SortedGrouped[K, (U, V)] = { pipe.group - .sortBy { case (u: U, _) => u } + .sortBy { case (u, _) => u } .scanLeft(Nil: List[(U, V)]) { - case (acc, (u: U, v: V)) => + case (acc, (u, v)) => acc match { case List((previousU, previousSum)) => List((u, sg.plus(previousSum, v))) case _ => List((u, v)) @@ -63,7 +65,7 @@ object CumulativeSum { ordK: Ordering[K]): TypedPipe[(K, (U, V))] = { val sumPerS = pipe - .map { case (k, (u: U, v: V)) => (k, partition(u)) -> v } + .map { case (k, (u, v)) => (k, partition(u)) -> v } .sumByKey .map { case ((k, s), v) => (k, (s, v)) } .group @@ -87,7 +89,7 @@ object CumulativeSum { val summands = pipe .map { - case (k, (u: U, v: V)) => + case (k, (u, v)) => (k, partition(u)) -> (Some(u), v) } ++ sumPerS diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala b/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala index a4f3aba941..9430887f67 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala @@ -19,9 +19,10 @@ import com.twitter.algebird.monad.Reader import com.twitter.algebird.{ Monoid, Monad } import com.twitter.scalding.cascading_interop.FlowListenerPromise import com.twitter.scalding.Dsl.flowDefToRichFlowDef - -import scala.concurrent.{ Await, Future, Promise, ExecutionContext => ConcurrentExecutionContext } +import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue } +import scala.concurrent.{ Await, Future, ExecutionContext => ConcurrentExecutionContext, Promise } import scala.util.{ Failure, Success, Try } +import scala.util.control.NonFatal import cascading.flow.{ FlowDef, Flow } /** @@ -44,7 +45,7 @@ import cascading.flow.{ FlowDef, Flow } * zip to flatMap if you want to run two Executions in parallel. */ sealed trait Execution[+T] extends java.io.Serializable { - import Execution.{ emptyCache, EvalCache, FactoryExecution, FlatMapped, MapCounters, Mapped, OnComplete, RecoverWith, Zipped } + import Execution.{ EvalCache, FlatMapped, GetCounters, ResetCounters, Mapped, OnComplete, RecoverWith, Zipped } /** * Scala uses the filter method in for syntax for pattern matches that can fail. @@ -82,7 +83,7 @@ sealed trait Execution[+T] extends java.io.Serializable { * You may want .getAndResetCounters. */ def getCounters: Execution[(T, ExecutionCounters)] = - MapCounters[T, (T, ExecutionCounters)](this, { case tc @ (t, c) => (tc, c) }) + GetCounters(this) /** * Reads the counters and resets them to zero. 
Probably what @@ -94,11 +95,14 @@ sealed trait Execution[+T] extends java.io.Serializable { /** * This function is called when the current run is completed. This is - * a only a side effect (see unit return). - * Note, this is the only way to force a side effect. Map and FlatMap - * are not safe for side effects. ALSO You must run the result. If + * only a side effect (see unit return). + * + * ALSO You must .run the result. If * you throw away the result of this call, your fn will never be - * called. + * called. When you run the result, the Future you get will not + * be complete unless fn has completed running. If fn throws, it + * will be handled by the scala.concurrent.ExecutionContext.reportFailure + * NOT by returning a Failure in the Future. */ def onComplete(fn: Try[T] => Unit): Execution[T] = OnComplete(this, fn) @@ -118,7 +122,7 @@ sealed trait Execution[+T] extends java.io.Serializable { * you want to reset before a zip or a call to flatMap */ def resetCounters: Execution[T] = - MapCounters[T, T](this, { case (t, c) => (t, ExecutionCounters.empty) }) + ResetCounters(this) /** * This causes the Execution to occur. The result is not cached, so each call @@ -127,8 +131,15 @@ sealed trait Execution[+T] extends java.io.Serializable { * * Seriously: pro-style is for this to be called only once in a program. */ - def run(conf: Config, mode: Mode)(implicit cec: ConcurrentExecutionContext): Future[T] = - runStats(conf, mode, emptyCache)(cec)._2.map(_._1) + final def run(conf: Config, mode: Mode)(implicit cec: ConcurrentExecutionContext): Future[T] = { + val ec = new EvalCache + val result = runStats(conf, mode, ec)(cec).map(_._1) + // When the final future is complete we stop the submit thread + result.onComplete { _ => ec.finished() } + // wait till the end to start the thread in case the above throws + ec.start() + result + } /** * This is the internal method that must be implemented @@ -139,15 +150,11 @@ sealed trait Execution[+T] extends java.io.Serializable { */ protected def runStats(conf: Config, mode: Mode, - cache: EvalCache)(implicit cec: ConcurrentExecutionContext): (EvalCache, Future[(T, ExecutionCounters, EvalCache)]) + cache: EvalCache)(implicit cec: ConcurrentExecutionContext): Future[(T, ExecutionCounters)] /** * This is convenience for when we don't care about the result. * like .map(_ => ()) - * Note: When called after a map, the map never happens. Use onComplete - * to attach side effects. - * - * .map(fn).unit == .unit */ def unit: Execution[Unit] = map(_ => ()) /** @@ -157,10 +164,9 @@ sealed trait Execution[+T] extends java.io.Serializable { * composition. Every time someone calls this, be very suspect. It is * always code smell. Very seldom should you need to wait on a future. */ - def waitFor(conf: Config, mode: Mode): Try[T] = { + def waitFor(conf: Config, mode: Mode): Try[T] = Try(Await.result(run(conf, mode)(ConcurrentExecutionContext.global), scala.concurrent.duration.Duration.Inf)) - } /** * This is here to silence warnings in for comprehensions, but is @@ -173,11 +179,8 @@ sealed trait Execution[+T] extends java.io.Serializable { * run this and that in parallel, without any dependency. This will * be done in a single cascading flow if possible.
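 *
 * A small sketch of the intended use (writeA and writeB stand for two hypothetical
 * Execution[Unit] values):
 * {{{
 * val both: Execution[(Unit, Unit)] = writeA.zip(writeB)
 * }}}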
*/ - def zip[U](that: Execution[U]): Execution[(T, U)] = that match { - // push zips as low as possible - case fact @ FactoryExecution(_) => fact.zip(this).map(_.swap) - case _ => Zipped(this, that) - } + def zip[U](that: Execution[U]): Execution[(T, U)] = + Zipped(this, that) } /** @@ -196,132 +199,172 @@ object Execution { override def join[T, U](t: Execution[T], u: Execution[U]): Execution[(T, U)] = t.zip(u) } - trait EvalCache { self => + /** + * This is a mutable state that is kept internal to an execution + * as it is evaluating. + */ + private[scalding] class EvalCache { + private[this] val cache = + new ConcurrentHashMap[Execution[Any], Future[(Any, ExecutionCounters)]]() + /** - * For a given execution, return the EvalCache before the future is executed, - * and a Future of the result, counters, and cache after - * This takes care of merging the input cache with cache in the future - * result, so you don't need to worry about that (but it wouldn't be an - * error to add something to the cache twice clearly). + * We send messages from other threads into the submit thread here */ - def getOrElseInsert[T](ex: Execution[T], - res: => (EvalCache, Future[(T, ExecutionCounters, EvalCache)]))(implicit ec: ConcurrentExecutionContext): (EvalCache, Future[(T, ExecutionCounters, EvalCache)]) + sealed trait FlowDefAction + case class RunFlowDef(conf: Config, + mode: Mode, + fd: FlowDef, + result: Promise[JobStats]) extends FlowDefAction + case object Stop extends FlowDefAction + private val messageQueue = new LinkedBlockingQueue[FlowDefAction]() + /** + * Hadoop and/or cascading has some issues, it seems, with starting jobs + * from multiple threads. This thread does all the Flow starting. + */ + private val thread = new Thread(new Runnable { + def run() { + @annotation.tailrec + def go(): Unit = messageQueue.take match { + case Stop => () + case RunFlowDef(conf, mode, fd, promise) => + try { + promise.completeWith( + ExecutionContext.newContext(conf)(fd, mode).run) + } catch { + case t: Throwable => + // something bad happened, but this thread is a daemon + // that should only stop if all others have stopped or + // we have received the stop message. + // Stopping this thread prematurely can deadlock + // futures from the promise we have. + // In a sense, this thread does not exist logically and + // must forward all exceptions to threads that requested + // this work be started. + promise.tryFailure(t) + } + // Loop + go() + } + + // Now we actually run the recursive loop + go() + } + }) + + def runFlowDef(conf: Config, mode: Mode, fd: FlowDef): Future[JobStats] = + try { + val promise = Promise[JobStats]() + val fut = promise.future + messageQueue.put(RunFlowDef(conf, mode, fd, promise)) + // Don't do any work after the .put call, we want no chance for exception + // after the put + fut + } catch { + case NonFatal(e) => + Future.failed(e) + } - def ++(that: EvalCache): EvalCache = new EvalCache { - def getOrElseInsert[T](ex: Execution[T], - res: => (EvalCache, Future[(T, ExecutionCounters, EvalCache)]))(implicit ec: ConcurrentExecutionContext) = - that.getOrElseInsert(ex, self.getOrElseInsert(ex, res)) + def start(): Unit = { + // Make sure this thread can't keep us running if all others are gone + thread.setDaemon(true) + thread.start() } - } - /** - * This is an implementation that remembers history forever. - * Since Hadoop jobs are generally long running and not infinite loops, - * this is generally safe. If someone wants to make an infinite loop or giant loop, - * this may OOM. 
The solution might be use an immutable LRU cache. - */ - private case class MapEvalCache(cache: Map[Execution[_], Future[(_, ExecutionCounters, EvalCache)]]) extends EvalCache { - def getOrElseInsert[T](ex: Execution[T], res: => (EvalCache, Future[(T, ExecutionCounters, EvalCache)]))(implicit ec: ConcurrentExecutionContext) = cache.get(ex) match { - case None => - val (next, fut) = res - // Make sure ex is added to the cache: - val resCache = next ++ MapEvalCache(cache + (ex -> fut)) - /* - * Note in this branch, the future returned includes a - * next and the ex -> fut mapping - */ - (resCache, fut.map { case (t, ec, fcache) => (t, ec, resCache ++ fcache) }) - - case Some(fut) => - /* - * The future recorded here may not itself it it's inner cache - * (nothing else is ensuring it). So we make sure the same way we do above - */ - val typedFut = fut.asInstanceOf[Future[(T, ExecutionCounters, EvalCache)]] - (this, typedFut.map { case (t, ec, fcache) => (t, ec, this ++ fcache) }) - } - override def ++(that: EvalCache): EvalCache = that match { - case MapEvalCache(thatCache) => MapEvalCache(cache ++ thatCache) - case _ => super.++(that) + /* + * This is called after we are done submitting all jobs + */ + def finished(): Unit = messageQueue.put(Stop) + + def getOrElseInsert[T](ex: Execution[T], + res: => Future[(T, ExecutionCounters)])(implicit ec: ConcurrentExecutionContext): Future[(T, ExecutionCounters)] = { + /* + * Since we don't want to evaluate res twice, we make a promise + * which we will use if it has not already been evaluated + */ + val promise = Promise[(T, ExecutionCounters)]() + val fut = promise.future + cache.putIfAbsent(ex, fut) match { + case null => + // note res is by-name, so we just evaluate it now: + promise.completeWith(res) + fut + case exists => exists.asInstanceOf[Future[(T, ExecutionCounters)]] + } } } - private def emptyCache: EvalCache = MapEvalCache(Map.empty) private case class FutureConst[T](get: ConcurrentExecutionContext => Future[T]) extends Execution[T] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = - cache.getOrElseInsert(this, { - val fft: Future[Future[T]] = toFuture(Try(get(cec))) - (cache, for { - futt <- fft + cache.getOrElseInsert(this, + for { + futt <- toFuture(Try(get(cec))) t <- futt - } yield (t, ExecutionCounters.empty, cache)) - }) + } yield (t, ExecutionCounters.empty)) // Note that unit is not optimized away, since Futures are often used with side-effects, so, // we ensure that get is always called in contrast to Mapped, which assumes that fn is pure. 
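    // A hedged note on evaluation: within a single .run, runStats results are memoized per
    // EvalCache (getOrElseInsert above), so referencing the same Execution instance twice,
    // e.g. { val e = Execution.fromFuture(fn); e.zip(e) }, is expected to evaluate fn only once.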
} private case class FlatMapped[S, T](prev: Execution[S], fn: S => Execution[T]) extends Execution[T] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = - cache.getOrElseInsert(this, { - val (cache1, fut) = prev.runStats(conf, mode, cache) - val finalFut = for { - (s, st1, cache1a) <- fut + cache.getOrElseInsert(this, + for { + (s, st1) <- prev.runStats(conf, mode, cache) next = fn(s) - (_, fut2) = next.runStats(conf, mode, cache1a) - (t, st2, cache2a) <- fut2 - } yield (t, Monoid.plus(st1, st2), cache2a) - (cache1, finalFut) - }) + fut2 = next.runStats(conf, mode, cache) + (t, st2) <- fut2 + } yield (t, Monoid.plus(st1, st2))) } private case class Mapped[S, T](prev: Execution[S], fn: S => T) extends Execution[T] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = - cache.getOrElseInsert(this, { - val (cache1, fut) = prev.runStats(conf, mode, cache) - (cache1, fut.map { case (s, stats, c) => (fn(s), stats, c) }) - }) - - // Don't bother applying the function if we are mapped - override def unit = prev.unit + cache.getOrElseInsert(this, + prev.runStats(conf, mode, cache) + .map { case (s, stats) => (fn(s), stats) }) } - private case class MapCounters[T, U](prev: Execution[T], - fn: ((T, ExecutionCounters)) => (U, ExecutionCounters)) extends Execution[U] { + private case class GetCounters[T](prev: Execution[T]) extends Execution[(T, ExecutionCounters)] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = - cache.getOrElseInsert(this, { - val (cache1, fut) = prev.runStats(conf, mode, cache) - (cache1, fut.map { - case (t, counters, c) => - val (u, counters2) = fn((t, counters)) - (u, counters2, c) - }) - }) + cache.getOrElseInsert(this, + prev.runStats(conf, mode, cache).map { case tc @ (t, c) => (tc, c) }) } + private case class ResetCounters[T](prev: Execution[T]) extends Execution[T] { + def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = + cache.getOrElseInsert(this, + prev.runStats(conf, mode, cache).map { case (t, _) => (t, ExecutionCounters.empty) }) + } + private case class OnComplete[T](prev: Execution[T], fn: Try[T] => Unit) extends Execution[T] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = cache.getOrElseInsert(this, { val res = prev.runStats(conf, mode, cache) - res._2.map(_._1).onComplete(fn) - res + /** + * The result we give is only completed AFTER fn is run + * so callers can wait on the result of this OnComplete + */ + val finished = Promise[(T, ExecutionCounters)]() + res.onComplete { tryT => + try { + fn(tryT.map(_._1)) + } finally { + // Do our best to signal when we are done + finished.complete(tryT) + } + } + finished.future }) } private case class RecoverWith[T](prev: Execution[T], fn: PartialFunction[Throwable, Execution[T]]) extends Execution[T] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = - cache.getOrElseInsert(this, { - val (cache1, fut) = prev.runStats(conf, mode, cache) - // Note, if the future fails, we restart from the input cache - (cache1, fut.recoverWith(fn.andThen(_.runStats(conf, mode, cache)._2))) - }) + cache.getOrElseInsert(this, + prev.runStats(conf, mode, cache) + .recoverWith(fn.andThen(_.runStats(conf, mode, cache)))) } private case class Zipped[S, T](one: Execution[S], two: Execution[T]) extends Execution[(S, T)] { def runStats(conf: 
Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = cache.getOrElseInsert(this, { - val (cache1, f1) = one.runStats(conf, mode, cache) - val (cache2, f2) = two.runStats(conf, mode, cache1) - (cache2, f1.zip(f2) - .map { case ((s, ss, c1a), (t, st, c2a)) => ((s, t), Monoid.plus(ss, st), c1a ++ c2a) }) + val f1 = one.runStats(conf, mode, cache) + val f2 = two.runStats(conf, mode, cache) + f1.zip(f2) + .map { case ((s, ss), (t, st)) => ((s, t), Monoid.plus(ss, st)) } }) - - // Make sure we remove any mapping functions on both sides - override def unit = one.unit.zip(two.unit).map(_ => ()) } private case class UniqueIdExecution[T](fn: UniqueID => Execution[T]) extends Execution[T] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = @@ -331,68 +374,52 @@ object Execution { }) } /* - * This is the main class the represents a flow without any combinators + * This allows you to run any cascading flowDef as an Execution. */ - private case class FlowDefExecution[T](result: (Config, Mode) => (FlowDef, (JobStats => Future[T]))) extends Execution[T] { + private case class FlowDefExecution(result: (Config, Mode) => FlowDef) extends Execution[Unit] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = cache.getOrElseInsert(this, - (cache, for { - (flowDef, fn) <- toFuture(Try(result(conf, mode))) + for { + flowDef <- toFuture(Try(result(conf, mode))) _ = FlowStateMap.validateSources(flowDef, mode) - jobStats <- ExecutionContext.newContext(conf)(flowDef, mode).run + jobStats <- cache.runFlowDef(conf, mode, flowDef) _ = FlowStateMap.clear(flowDef) - t <- fn(jobStats) - } yield (t, ExecutionCounters.fromJobStats(jobStats), cache))) + } yield ((), ExecutionCounters.fromJobStats(jobStats))) + } - /* - * Cascading can run parallel Executions in the same flow if they are both FlowDefExecutions - */ - override def zip[U](that: Execution[U]): Execution[(T, U)] = - that match { - /* - * This merging parallelism only works if the names of the - * sources are distinct. Scalding allocates uuids to each - * pipe that starts a head, so a collision should be HIGHLY - * unlikely. - */ - case FlowDefExecution(result2) => - FlowDefExecution({ (conf, m) => - val (fd1, fn1) = result(conf, m) - val (fd2, fn2) = result2(conf, m) - val merged = fd1.copy - merged.mergeFrom(fd2) - (merged, { (js: JobStats) => fn1(js).zip(fn2(js)) }) - }) - case _ => super.zip(that) - } + /* + * This is here so we can call without knowing the type T + * but with proof that pipe matches sink + */ + private case class ToWrite[T](pipe: TypedPipe[T], sink: TypedSink[T]) { + def write(flowDef: FlowDef, mode: Mode): Unit = { + // This has the side effect of mutating flowDef + pipe.write(sink)(flowDef, mode) + () + } } - private case class FactoryExecution[T](result: (Config, Mode) => Execution[T]) extends Execution[T] { + /** + * This is the fundamental execution that actually happens in TypedPipes, all the rest + * are based on this one. By keeping the Pipe and the Sink, we can inspect the Execution + * DAG and optimize it later (a goal, but not done yet).
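+ *
+ * A sketch of how such a node is expected to be built internally (see Execution.write
+ * further down in this file):
+ * {{{
+ * WriteExecution(ToWrite(pipe, sink), Nil)
+ * }}}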
+ */ + private case class WriteExecution(head: ToWrite[_], tail: List[ToWrite[_]]) extends Execution[Unit] { def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = - cache.getOrElseInsert(this, unwrap(conf, mode, this).runStats(conf, mode, cache)) + cache.getOrElseInsert(this, + for { + flowDef <- toFuture(Try { val fd = new FlowDef; (head :: tail).foreach(_.write(fd, mode)); fd }) + _ = FlowStateMap.validateSources(flowDef, mode) + jobStats <- cache.runFlowDef(conf, mode, flowDef) + _ = FlowStateMap.clear(flowDef) + } yield ((), ExecutionCounters.fromJobStats(jobStats))) + } - @annotation.tailrec - private def unwrap[U](conf: Config, mode: Mode, that: Execution[U]): Execution[U] = - that match { - case FactoryExecution(fn) => unwrap(conf, mode, fn(conf, mode)) - case nonFactory => nonFactory - } - /* - * Cascading can run parallel Executions in the same flow if they are both FlowDefExecutions - */ - override def zip[U](that: Execution[U]): Execution[(T, U)] = - that match { - case FactoryExecution(result2) => - FactoryExecution({ (conf, m) => - val exec1 = unwrap(conf, m, result(conf, m)) - val exec2 = unwrap(conf, m, result2(conf, m)) - exec1.zip(exec2) - }) - case _ => - FactoryExecution({ (conf, m) => - val exec1 = unwrap(conf, m, result(conf, m)) - exec1.zip(that) - }) - } + /** + * This is called Reader, because it just returns its input to run as the output + */ + private case object ReaderExecution extends Execution[(Config, Mode)] { + def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = + Future.successful(((conf, mode), ExecutionCounters.empty)) } private def toFuture[R](t: Try[R]): Future[R] = @@ -404,15 +431,15 @@ object Execution { /** * This creates a definitely failed Execution. */ - def failed(t: Throwable): Execution[Nothing] = - fromFuture(_ => Future.failed(t)) + def failed(t: Throwable): Execution[Nothing] = fromTry(Failure(t)) /** * This makes a constant execution that runs no job. * Note this is a lazy parameter that is evaluated every * time run is called. */ - def from[T](t: => T): Execution[T] = fromFuture { _ => toFuture(Try(t)) } + def from[T](t: => T): Execution[T] = fromTry(Try(t)) + def fromTry[T](t: => Try[T]): Execution[T] = fromFuture { _ => toFuture(t) } /** * The call to fn will happen when the run method on the result is called. @@ -423,29 +450,44 @@ object Execution { */ def fromFuture[T](fn: ConcurrentExecutionContext => Future[T]): Execution[T] = FutureConst(fn) - private[scalding] def factory[T](fn: (Config, Mode) => Execution[T]): Execution[T] = - FactoryExecution(fn) + /** Returns a constant Execution[Unit] */ + val unit: Execution[Unit] = from(()) /** * This converts a function into an Execution monad. The flowDef returned - * is never mutated. The returned callback funcion is called after the flow - * is run and succeeds. + * is never mutated. 
*/ - def fromFn[T]( - fn: (Config, Mode) => ((FlowDef, JobStats => Future[T]))): Execution[T] = + def fromFn(fn: (Config, Mode) => FlowDef): Execution[Unit] = FlowDefExecution(fn) + /** + * Creates an Execution to do a write + */ + private[scalding] def write[T](pipe: TypedPipe[T], sink: TypedSink[T]): Execution[Unit] = + WriteExecution(ToWrite(pipe, sink), Nil) + + /** + * Convenience method to get the Args + */ + def getArgs: Execution[Args] = ReaderExecution.map(_._1.getArgs) /** * Use this to read the configuration, which may contain Args or options * which describe input on which to run */ - def getConfig: Execution[Config] = factory { case (conf, _) => from(conf) } + def getConfig: Execution[Config] = ReaderExecution.map(_._1) /** Use this to get the mode, which may contain the job conf */ - def getMode: Execution[Mode] = factory { case (_, mode) => from(mode) } + def getMode: Execution[Mode] = ReaderExecution.map(_._2) /** Use this to get the config and mode. */ - def getConfigMode: Execution[(Config, Mode)] = factory { case (conf, mode) => from((conf, mode)) } + def getConfigMode: Execution[(Config, Mode)] = ReaderExecution + + /** + * This is convenience method only here to make it slightly cleaner + * to get Args, which are in the Config + */ + def withArgs[T](fn: Args => Execution[T]): Execution[T] = + getConfig.flatMap { conf => fn(conf.getArgs) } /** * Use this to use counters/stats with Execution. You do this: diff --git a/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala b/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala index bf6619d6ce..5080256923 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala @@ -17,6 +17,7 @@ package com.twitter.scalding import cascading.flow.{ FlowDef, Flow } import com.twitter.scalding.reducer_estimation.ReducerEstimatorStepStrategy +import com.twitter.scalding.serialization.CascadingBinaryComparator import scala.concurrent.Future import scala.util.{ Failure, Success, Try } @@ -44,6 +45,11 @@ trait ExecutionContext { try { // identify the flowDef val withId = config.addUniqueId(UniqueID.getIDFor(flowDef)) + if (config.getRequireOrderedSerialization) { + // This will throw, but be caught by the outer try if + // we have groupby/cogroupby not using OrderedSerializations + CascadingBinaryComparator.checkForOrderedSerialization(flowDef).get + } val flow = mode.newFlowConnector(withId).connect(flowDef) // if any reducer estimators have been set, register the step strategy diff --git a/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala b/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala index 078cb7e803..60885c1311 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala @@ -16,15 +16,13 @@ limitations under the License. 
package com.twitter.scalding import cascading.tuple.Fields - -import scala.collection.JavaConversions._ - import cascading.pipe.Pipe -import scala.annotation.tailrec -import java.util.Comparator - import com.esotericsoftware.kryo.DefaultSerializer +import java.util.Comparator +import scala.annotation.tailrec +import scala.collection.JavaConverters._ + trait LowPriorityFieldConversions { protected def anyToFieldArg(f: Any): Comparable[_] = f match { @@ -67,7 +65,7 @@ trait FieldConversions extends LowPriorityFieldConversions { // Cascading Fields are either java.lang.String or java.lang.Integer, both are comparable. def asList(f: Fields): List[Comparable[_]] = { - f.iterator.toList.asInstanceOf[List[Comparable[_]]] + f.iterator.asScala.toList.asInstanceOf[List[Comparable[_]]] } // Cascading Fields are either java.lang.String or java.lang.Integer, both are comparable. def asSet(f: Fields): Set[Comparable[_]] = asList(f).toSet @@ -75,7 +73,7 @@ trait FieldConversions extends LowPriorityFieldConversions { // TODO get the comparator also def getField(f: Fields, idx: Int): Fields = { new Fields(f.get(idx)) } - def hasInts(f: Fields): Boolean = f.iterator.exists { _.isInstanceOf[java.lang.Integer] } + def hasInts(f: Fields): Boolean = f.iterator.asScala.exists { _.isInstanceOf[java.lang.Integer] } /** * Rather than give the full power of cascading's selectors, we have @@ -268,4 +266,10 @@ object Field { def apply[T](index: Int)(implicit ord: Ordering[T], mf: Manifest[T]) = IntField[T](index)(ord, Some(mf)) def apply[T](name: String)(implicit ord: Ordering[T], mf: Manifest[T]) = StringField[T](name)(ord, Some(mf)) def apply[T](symbol: Symbol)(implicit ord: Ordering[T], mf: Manifest[T]) = StringField[T](symbol.name)(ord, Some(mf)) + + def singleOrdered[T](name: String)(implicit ord: Ordering[T]): Fields = { + val f = new Fields(name) + f.setComparator(name, ord) + f + } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala index e783525be8..ce75e8b5f9 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala @@ -15,7 +15,7 @@ limitations under the License. */ package com.twitter.scalding -import java.io.{ InputStream, OutputStream } +import java.io.{ File, InputStream, OutputStream } import java.util.{ UUID, Properties } import cascading.scheme.Scheme @@ -59,12 +59,21 @@ abstract class SchemedSource extends Source { val sinkMode: SinkMode = SinkMode.REPLACE } +private[scalding] object CastFileTap { + // The scala compiler has problems with the generics in Cascading + def apply(tap: FileTap): Tap[JobConf, RecordReader[_, _], OutputCollector[_, _]] = + tap.asInstanceOf[Tap[JobConf, RecordReader[_, _], OutputCollector[_, _]]] +} + /** * A trait which provides a method to create a local tap. */ trait LocalSourceOverride extends SchemedSource { /** A path to use for the local tap. */ - def localPath: String + def localPaths: Iterable[String] + + // By default, we write to the last path for local paths + def localWritePath = localPaths.last /** * Creates a local tap. @@ -72,7 +81,18 @@ trait LocalSourceOverride extends SchemedSource { * @param sinkMode The mode for handling output conflicts. * @returns A tap. 
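 *
 * A sketch of the multi-path behavior introduced here (paths are hypothetical): with
 * localPaths of List("data/2015/05/01", "data/2015/05/02"), createLocalTap is expected to
 * combine the per-path taps into a single ScaldingMultiSourceTap.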
*/ - def createLocalTap(sinkMode: SinkMode): Tap[_, _, _] = new FileTap(localScheme, localPath, sinkMode) + def createLocalTap(sinkMode: SinkMode): Tap[JobConf, _, _] = { + val taps = localPaths.map { + p: String => + CastFileTap(new FileTap(localScheme, p, sinkMode)) + }.toList + + taps match { + case Nil => throw new InvalidSourceException("LocalPaths is empty") + case oneTap :: Nil => oneTap + case many => new ScaldingMultiSourceTap(many) + } + } } object HiddenFileFilter extends PathFilter { @@ -145,7 +165,10 @@ abstract class FileSource extends SchemedSource with LocalSourceOverride { mode match { // TODO support strict in Local case Local(_) => { - createLocalTap(sinkMode) + readOrWrite match { + case Read => createLocalTap(sinkMode) + case Write => new FileTap(localScheme, localWritePath, sinkMode) + } } case hdfsMode @ Hdfs(_, _) => readOrWrite match { case Read => createHdfsReadTap(hdfsMode) @@ -195,6 +218,17 @@ abstract class FileSource extends SchemedSource with LocalSourceOverride { "[" + this.toString + "] No good paths in: " + hdfsPaths.toString) } } + + case Local(strict) => { + val files = localPaths.map{ p => new java.io.File(p) } + if (strict && !files.forall(_.exists)) { + throw new InvalidSourceException( + "[" + this.toString + s"] Data is missing from: ${localPaths.filterNot { p => new java.io.File(p).exists }}") + } else if (!files.exists(_.exists)) { + throw new InvalidSourceException( + "[" + this.toString + "] No good paths in: " + hdfsPaths.toString) + } + } case _ => () } } @@ -306,12 +340,23 @@ trait SuccessFileSource extends FileSource { * Put another way, this runs a Hadoop tap outside of Hadoop in the Cascading local mode */ trait LocalTapSource extends LocalSourceOverride { - override def createLocalTap(sinkMode: SinkMode) = new LocalTap(localPath, hdfsScheme, sinkMode).asInstanceOf[Tap[_, _, _]] + override def createLocalTap(sinkMode: SinkMode): Tap[JobConf, _, _] = { + val taps = localPaths.map { p => + new LocalTap(p, hdfsScheme, sinkMode).asInstanceOf[Tap[JobConf, RecordReader[_, _], OutputCollector[_, _]]] + }.toSeq + + taps match { + case Nil => throw new InvalidSourceException("LocalPaths is empty") + case oneTap :: Nil => oneTap + case many => new ScaldingMultiSourceTap(many) + } + } } abstract class FixedPathSource(path: String*) extends FileSource { - def localPath = { assert(path.size == 1, "Cannot use multiple input files on local mode"); path(0) } + def localPaths = path.toList def hdfsPaths = path.toList + override def toString = getClass.getName + path override def hashCode = toString.hashCode override def equals(that: Any): Boolean = (that != null) && (that.toString == toString) @@ -355,7 +400,7 @@ case class Osv(p: String, f: Fields = Fields.ALL, override val sinkMode: SinkMode = SinkMode.REPLACE) extends FixedPathSource(p) with DelimitedScheme { override val fields = f - override val separator = "\1" + override val separator = "\u0001" } object TextLine { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/FlowState.scala b/scalding-core/src/main/scala/com/twitter/scalding/FlowState.scala index ccc7ab32a5..85b6b74718 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/FlowState.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/FlowState.scala @@ -15,16 +15,18 @@ limitations under the License. 
*/ package com.twitter.scalding -import cascading.pipe.Pipe import cascading.flow.FlowDef -import java.util.{ Map => JMap, WeakHashMap } -import scala.collection.JavaConverters._ +import java.util.WeakHashMap + /** * Immutable state that we attach to the Flow using the FlowStateMap */ -case class FlowState(sourceMap: Map[String, Source] = Map.empty) { +case class FlowState(sourceMap: Map[String, Source] = Map.empty, flowConfigUpdates: Set[(String, String)] = Set()) { def addSource(id: String, s: Source): FlowState = - FlowState(sourceMap + (id -> s)) + copy(sourceMap = sourceMap + (id -> s)) + + def addConfigSetting(k: String, v: String): FlowState = + copy(flowConfigUpdates = flowConfigUpdates + ((k, v))) def getSourceNamed(name: String): Option[Source] = sourceMap.get(name) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala b/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala index 0b6d40409c..e1c7d518e6 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala @@ -17,14 +17,9 @@ package com.twitter.scalding import cascading.pipe._ import cascading.pipe.assembly._ import cascading.operation._ -import cascading.operation.aggregator._ -import cascading.operation.filter._ import cascading.tuple.Fields -import cascading.tuple.{ Tuple => CTuple, TupleEntry } +import cascading.tuple.TupleEntry -import scala.collection.JavaConverters._ -import scala.annotation.tailrec -import scala.math.Ordering import scala.{ Range => ScalaRange } /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/IntegralComparator.scala b/scalding-core/src/main/scala/com/twitter/scalding/IntegralComparator.scala index a95de95bea..29b888854f 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/IntegralComparator.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/IntegralComparator.scala @@ -14,12 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package com.twitter.scalding; +package com.twitter.scalding -import cascading.tuple.Hasher; +import cascading.tuple.Hasher -import java.io.Serializable; -import java.util.Comparator; +import java.io.Serializable +import java.util.Comparator /* * Handles numerical hashing properly diff --git a/scalding-core/src/main/scala/com/twitter/scalding/IterableSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/IterableSource.scala index 2ae8045aed..9a575bd9dd 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/IterableSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/IterableSource.scala @@ -17,9 +17,6 @@ package com.twitter.scalding import com.twitter.maple.tap.MemorySourceTap -import cascading.flow.FlowProcess -import cascading.scheme.local.{ TextDelimited => CLTextDelimited } -import cascading.scheme.Scheme import cascading.tap.Tap import cascading.tuple.Tuple import cascading.tuple.Fields @@ -27,10 +24,6 @@ import cascading.scheme.NullScheme import java.io.{ InputStream, OutputStream } -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputCollector -import org.apache.hadoop.mapred.RecordReader - import scala.collection.mutable.Buffer import scala.collection.JavaConverters._ diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Job.scala b/scalding-core/src/main/scala/com/twitter/scalding/Job.scala index 076fd3979b..d4363f7af9 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Job.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Job.scala @@ -16,27 +16,19 @@ limitations under the License. package com.twitter.scalding import com.twitter.algebird.monad.Reader -import com.twitter.chill.config.{ ScalaAnyRefMapConfig, ConfiguredInstantiator } -import cascading.pipe.assembly.AggregateBy -import cascading.flow.{ Flow, FlowDef, FlowProps, FlowListener, FlowStep, FlowStepListener, FlowSkipStrategy, FlowStepStrategy } +import cascading.flow.{ Flow, FlowDef, FlowListener, FlowStep, FlowStepListener, FlowSkipStrategy, FlowStepStrategy } import cascading.pipe.Pipe import cascading.property.AppProps -import cascading.tuple.collect.SpillableProps import cascading.stats.CascadingStats -import com.twitter.scalding.reducer_estimation.EstimatorConfig import org.apache.hadoop.io.serializer.{ Serialization => HSerialization } -import org.apache.hadoop.mapred.JobConf -import org.slf4j.LoggerFactory -//For java -> scala implicits on collections -import scala.collection.JavaConversions._ import scala.concurrent.{ Future, Promise } import scala.util.Try -import java.io.{ BufferedWriter, File, FileOutputStream, OutputStreamWriter } -import java.util.{ Calendar, UUID, List => JList } +import java.io.{ BufferedWriter, FileOutputStream, OutputStreamWriter } +import java.util.{ List => JList } import java.util.concurrent.{ Executors, TimeUnit, ThreadFactory, Callable, TimeoutException } import java.util.concurrent.atomic.AtomicInteger @@ -116,6 +108,9 @@ class Job(val args: Args) extends FieldConversions with java.io.Serializable { implicit def iterableToRichPipe[T](iter: Iterable[T])(implicit set: TupleSetter[T], conv: TupleConverter[T]): RichPipe = RichPipe(toPipe(iter)(set, conv)) + // Provide args as an implicit val for extensions such as the Checkpoint extension. 
+ implicit protected def _implicitJobArgs: Args = args + // Override this if you want to change how the mapred.job.name is written in Hadoop def name: String = Config.defaultFrom(mode).toMap.getOrElse("mapred.job.name", getClass.getName) @@ -475,41 +470,6 @@ abstract class ExecutionJob[+T](args: Args) extends Job(args) { } } -/* - * this allows you to use ExecutionContext style, but wrap it in a job - * val ecFn = { (implicit ec: ExecutionContext) => - * // do stuff here - * }; - * class MyClass(args: Args) extends ExecutionContextJob(args) { - * def job = ecFn - * } - * Now you can run it with Tool as a standard Job-framework style. - * Only use this if you have an existing ExecutionContext style function - * you want to run as a Job - */ -@deprecated("Use ExecutionJob", "2014-07-29") -abstract class ExecutionContextJob[+T](args: Args) extends Job(args) { - /** - * This can be assigned from a Function1: - * def job = (ectxJob: (ExecutionContext => T)) - */ - def job: Reader[ExecutionContext, T] - /** - * This is the result of calling the job on the context for this job - * you should NOT call this in the job Reader (or reference this class at all - * in reader - */ - @transient final lazy val result: Try[T] = ec.map(job(_)) // mutate the flowDef with the job - - private[this] final def ec: Try[ExecutionContext] = - Config.tryFrom(config).map { conf => ExecutionContext.newContext(conf)(flowDef, mode) } - - override def buildFlow: Flow[_] = { - val forcedResult = result.get // make sure we have applied job once - super.buildFlow - } -} - /* * Run a list of shell commands through bash in the given order. Return success * when all commands succeed. Excution stops after the first failure. The diff --git a/scalding-core/src/main/scala/com/twitter/scalding/JobStats.scala b/scalding-core/src/main/scala/com/twitter/scalding/JobStats.scala index 5529f76816..dac1f1a720 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/JobStats.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/JobStats.scala @@ -15,9 +15,7 @@ limitations under the License. */ package com.twitter.scalding -import java.io.{ File, OutputStream } import scala.collection.JavaConverters._ -import cascading.flow.Flow import cascading.stats.{ CascadeStats, CascadingStats, FlowStats } import scala.util.{ Failure, Try } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/JoinAlgorithms.scala b/scalding-core/src/main/scala/com/twitter/scalding/JoinAlgorithms.scala index 19d60874f6..bcc6e29a70 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/JoinAlgorithms.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/JoinAlgorithms.scala @@ -15,18 +15,11 @@ limitations under the License. 
*/ package com.twitter.scalding -import cascading.tap._ -import cascading.scheme._ import cascading.pipe._ -import cascading.pipe.assembly._ import cascading.pipe.joiner._ -import cascading.flow._ -import cascading.operation._ -import cascading.operation.aggregator._ -import cascading.operation.filter._ import cascading.tuple._ -import cascading.cascade._ +import java.util.{ Iterator => JIterator } import java.util.Random // this one is serializable, scala.util.Random is not import scala.collection.JavaConverters._ @@ -120,7 +113,7 @@ trait JoinAlgorithms { /** * Flip between LeftJoin to RightJoin */ - private def flipJoiner(j: Joiner) = { + private def flipJoiner(j: Joiner): Joiner = { j match { case outer: OuterJoin => outer case inner: InnerJoin => inner @@ -224,17 +217,17 @@ trait JoinAlgorithms { def joinWithTiny(fs: (Fields, Fields), that: Pipe) = { val intersection = asSet(fs._1).intersect(asSet(fs._2)) if (intersection.isEmpty) { - new HashJoin(assignName(pipe), fs._1, assignName(that), fs._2, new InnerJoin) + new HashJoin(assignName(pipe), fs._1, assignName(that), fs._2, WrappedJoiner(new InnerJoin)) } else { val (renamedThat, newJoinFields, temp) = renameCollidingFields(that, fs._2, intersection) - (new HashJoin(assignName(pipe), fs._1, assignName(renamedThat), newJoinFields, new InnerJoin)) + (new HashJoin(assignName(pipe), fs._1, assignName(renamedThat), newJoinFields, WrappedJoiner(new InnerJoin))) .discard(temp) } } def leftJoinWithTiny(fs: (Fields, Fields), that: Pipe) = { //Rename these pipes to avoid cascading name conflicts - new HashJoin(assignName(pipe), fs._1, assignName(that), fs._2, new LeftJoin) + new HashJoin(assignName(pipe), fs._1, assignName(that), fs._2, WrappedJoiner(new LeftJoin)) } /** @@ -381,6 +374,7 @@ trait JoinAlgorithms { (otherPipe, fs._2, Fields.NONE) else // For now, we are assuming an inner join. renameCollidingFields(otherPipe, fs._2, intersection) + val mergedJoinKeys = Fields.join(fs._1, rightResolvedJoinFields) // 1. First, get an approximate count of the left join keys and the right join keys, so that we // know how much to replicate. @@ -399,7 +393,29 @@ trait JoinAlgorithms { val sampledRight = rightPipe.sample(sampleRate, Seed) .groupBy(rightResolvedJoinFields) { _.size(rightSampledCountField) } val sampledCounts = sampledLeft.joinWithSmaller(fs._1 -> rightResolvedJoinFields, sampledRight, joiner = new OuterJoin) - .project(Fields.join(fs._1, rightResolvedJoinFields, sampledCountFields)) + .project(Fields.join(mergedJoinKeys, sampledCountFields)) + .map(mergedJoinKeys -> mergedJoinKeys) { t: cascading.tuple.Tuple => + // Make the outer join look like an inner join so that we can join + // either the left or right fields for every entry. + // Accomplished by replacing any null field with the corresponding + // field from the other half. E.g., + // (1, 2, "foo", null, null, null) -> (1, 2, "foo", 1, 2, "foo") + val keysSize = t.size / 2 + val result = new cascading.tuple.Tuple(t) + + for (index <- 0 until keysSize) { + val leftValue = result.getObject(index) + val rightValue = result.getObject(index + keysSize) + + if (leftValue == null) { + result.set(index, rightValue) + } else if (rightValue == null) { + result.set(index + keysSize, leftValue) + } + } + + result + } // 2. Now replicate each group of join keys in the left and right pipes, according to the sampled counts // from the previous step. 
@@ -474,3 +490,34 @@ trait JoinAlgorithms { } class InvalidJoinModeException(args: String) extends Exception(args) + +/** + * Wraps a Joiner instance so that the active FlowProcess may be noted. This allows features of Scalding that need + * access to a FlowProcess (e.g., counters) to function properly inside a Joiner. + */ +private[scalding] class WrappedJoiner(val joiner: Joiner) extends Joiner { + override def getIterator(joinerClosure: JoinerClosure): JIterator[Tuple] = { + RuntimeStats.addFlowProcess(joinerClosure.getFlowProcess) + joiner.getIterator(joinerClosure) + } + + override def numJoins(): Int = joiner.numJoins() + + override def hashCode(): Int = joiner.hashCode() + + override def toString: String = joiner.toString + + override def equals(other: Any): Boolean = joiner.equals(other) +} + +private[scalding] object WrappedJoiner { + /** + * Wrap the given Joiner in a WrappedJoiner instance if it is not already wrapped. + */ + def apply(joiner: Joiner): WrappedJoiner = { + joiner match { + case wrapped: WrappedJoiner => wrapped + case _ => new WrappedJoiner(joiner) + } + } +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/MemoryTap.scala b/scalding-core/src/main/scala/com/twitter/scalding/MemoryTap.scala index 5793e87aa4..896c63496a 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/MemoryTap.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/MemoryTap.scala @@ -18,10 +18,11 @@ package com.twitter.scalding import cascading.tap.Tap import java.util.Properties import cascading.tuple._ -import scala.collection.JavaConversions._ import cascading.scheme.Scheme import cascading.flow.FlowProcess -import collection.mutable.{ Buffer, MutableList } + +import scala.collection.mutable.Buffer +import scala.collection.JavaConverters._ class MemoryTap[In, Out](val scheme: Scheme[Properties, In, Out, _, _], val tupleBuffer: Buffer[Tuple]) extends Tap[Properties, In, Out](scheme) { @@ -44,7 +45,7 @@ class MemoryTap[In, Out](val scheme: Scheme[Properties, In, Out, _, _], val tupl override lazy val getIdentifier: String = scala.math.random.toString override def openForRead(flowProcess: FlowProcess[Properties], input: In) = { - new TupleEntryChainIterator(scheme.getSourceFields, tupleBuffer.toIterator) + new TupleEntryChainIterator(scheme.getSourceFields, tupleBuffer.toIterator.asJava) } override def openForWrite(flowProcess: FlowProcess[Properties], output: Out): TupleEntryCollector = { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala b/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala index e6c73f6f58..71e539fc3d 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala @@ -16,18 +16,17 @@ limitations under the License. 
package com.twitter.scalding import java.io.File -import java.util.{ Map => JMap, UUID, Properties } +import java.util.{ UUID, Properties } import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{ FileSystem, Path } import org.apache.hadoop.mapred.JobConf -import cascading.flow.{ FlowConnector, FlowDef, Flow } +import cascading.flow.FlowConnector import cascading.flow.hadoop.HadoopFlowProcess import cascading.flow.hadoop.HadoopFlowConnector import cascading.flow.local.LocalFlowConnector import cascading.flow.local.LocalFlowProcess -import cascading.pipe.Pipe import cascading.property.AppProps import cascading.tap.Tap import cascading.tuple.Tuple @@ -38,10 +37,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.Buffer import scala.collection.mutable.{ Map => MMap } import scala.collection.mutable.{ Set => MSet } -import scala.collection.mutable.{ Iterable => MIterable } -import scala.util.{ Failure, Success, Try } +import scala.util.{ Failure, Success } -import org.slf4j.{ Logger, LoggerFactory } +import org.slf4j.LoggerFactory case class ModeException(message: String) extends RuntimeException(message) @@ -144,6 +142,7 @@ trait CascadingLocal extends Mode { config.toMap.foreach { case (k, v) => props.setProperty(k, v) } val fp = new LocalFlowProcess(props) ltap.retrieveSourceFields(fp) + ltap.sourceConfInit(fp, props) ltap.openForRead(fp) } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala b/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala index 69917de209..7e5a1a9507 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala @@ -19,14 +19,9 @@ package com.twitter.scalding { import cascading.tuple._ import cascading.flow._ import cascading.pipe.assembly.AggregateBy - import cascading.pipe._ import com.twitter.chill.MeatLocker import scala.collection.JavaConverters._ - import org.apache.hadoop.conf.Configuration - - import com.esotericsoftware.kryo.Kryo; - import com.twitter.algebird.{ Semigroup, SummingCache } import com.twitter.scalding.mathematics.Poisson import serialization.Externalizer @@ -483,6 +478,7 @@ package com.twitter.scalding { /** In the typed API every reduce operation is handled by this Buffer */ class TypedBufferOp[K, V, U]( + conv: TupleConverter[K], @transient reduceFn: (K, Iterator[V]) => Iterator[U], valueField: Fields) extends BaseOperation[Any](valueField) with Buffer[Any] with ScaldingPrepare[Any] { @@ -490,7 +486,7 @@ package com.twitter.scalding { def operate(flowProcess: FlowProcess[_], call: BufferCall[Any]) { val oc = call.getOutputCollector - val key = call.getGroup.getObject(0).asInstanceOf[K] + val key = conv(call.getGroup) val values = call.getArgumentsIterator .asScala .map(_.getObject(0).asInstanceOf[V]) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/PartitionSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/PartitionSource.scala index 84d8b4f133..fac95fefdd 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/PartitionSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/PartitionSource.scala @@ -15,13 +15,6 @@ limitations under the License. 
*/ package com.twitter.scalding -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.RecordReader -import org.apache.hadoop.mapred.OutputCollector - -import cascading.scheme.hadoop.{ TextDelimited => CHTextDelimited } -import cascading.scheme.hadoop.TextLine.Compress -import cascading.scheme.Scheme import cascading.tap.hadoop.Hfs import cascading.tap.hadoop.{ PartitionTap => HPartitionTap } import cascading.tap.local.FileTap diff --git a/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala b/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala index 40f66917c6..d6f89c1f32 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala @@ -16,10 +16,9 @@ limitations under the License. package com.twitter.scalding import cascading.tuple.Fields -import cascading.tuple.{ Tuple => CTuple, TupleEntry } +import cascading.tuple.{ Tuple => CTuple } import com.twitter.algebird.{ - Monoid, Semigroup, Ring, AveragedValue, @@ -107,15 +106,6 @@ trait ReduceOperations[+Self <: ReduceOperations[Self]] extends java.io.Serializ hyperLogLogMap[T, HLL](f, errPercent) { hll => hll } } - @deprecated("use of approximateUniqueCount is preferred.", "0.8.3") - def approxUniques(f: (Fields, Fields), errPercent: Double = 1.0) = { - // Legacy (pre-bijection) approximate unique count that uses in.toString.getBytes to - // obtain a long hash code. We specify the kludgy CTuple => Array[Byte] bijection - // explicitly. - implicit def kludgeHasher(in: CTuple) = in.toString.getBytes("UTF-8") - hyperLogLogMap[CTuple, Double](f, errPercent) { _.estimatedSize } - } - private[this] def hyperLogLogMap[T <% Array[Byte]: TupleConverter, U: TupleSetter](f: (Fields, Fields), errPercent: Double = 1.0)(fn: HLL => U) = { //bits = log(m) == 2 *log(104/errPercent) = 2log(104) - 2*log(errPercent) def log2(x: Double) = scala.math.log(x) / scala.math.log(2.0) @@ -321,17 +311,6 @@ trait ReduceOperations[+Self <: ReduceOperations[Self]] extends java.io.Serializ def sum[T](fs: Symbol*)(implicit sg: Semigroup[T], tconv: TupleConverter[T], tset: TupleSetter[T]): Self = sum[T](fs -> fs)(sg, tconv, tset) - @deprecated("Use sum", "0.9.0") - def plus[T](fd: (Fields, Fields))(implicit sg: Semigroup[T], tconv: TupleConverter[T], tset: TupleSetter[T]): Self = - sum[T](fd)(sg, tconv, tset) - /** - * The same as `plus(fs -> fs)` - * Assumed to be a commutative operation. If you don't want that, use .forceToReducers - */ - @deprecated("Use sum", "0.9.0") - def plus[T](fs: Symbol*)(implicit sg: Semigroup[T], tconv: TupleConverter[T], tset: TupleSetter[T]): Self = - sum[T](fs -> fs)(sg, tconv, tset) - /** * Returns the product of all the items in this grouping */ diff --git a/scalding-core/src/main/scala/com/twitter/scalding/RichFlowDef.scala b/scalding-core/src/main/scala/com/twitter/scalding/RichFlowDef.scala index 8f0b6edfc3..4a1f99ef5d 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/RichFlowDef.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/RichFlowDef.scala @@ -43,7 +43,7 @@ class RichFlowDef(val fd: FlowDef) { */ private[scalding] def mergeMiscFrom(o: FlowDef): Unit = { // See the cascading code that this string is a "," separated set. 
- o.getTags.split(",").foreach(fd.addTag) + StringUtility.fastSplit(o.getTags, ",").foreach(fd.addTag) mergeLeft(fd.getTraps, o.getTraps) mergeLeft(fd.getCheckpoints, o.getCheckpoints) @@ -84,7 +84,7 @@ class RichFlowDef(val fd: FlowDef) { .foreach { oFS => FlowStateMap.mutate(fd) { current => // overwrite the items from o with current - (FlowState(oFS.sourceMap ++ current.sourceMap), ()) + (current.copy(sourceMap = oFS.sourceMap ++ current.sourceMap), ()) } } } @@ -147,7 +147,7 @@ class RichFlowDef(val fd: FlowDef) { if (headNames(name)) newfs + kv else newfs } - FlowStateMap.mutate(newFd) { _ => (FlowState(subFlowState), ()) } + FlowStateMap.mutate(newFd) { oldFS => (oldFS.copy(sourceMap = subFlowState), ()) } } newFd } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala b/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala index a867cad1e2..12f8e2c10f 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala @@ -15,22 +15,16 @@ limitations under the License. */ package com.twitter.scalding -import cascading.tap._ -import cascading.scheme._ import cascading.pipe._ -import cascading.pipe.assembly._ -import cascading.pipe.joiner._ import cascading.flow._ import cascading.operation._ -import cascading.operation.aggregator._ import cascading.operation.filter._ import cascading.tuple._ -import cascading.cascade._ -import cascading.operation.Debug.Output import scala.util.Random import java.util.concurrent.atomic.AtomicInteger +import scala.collection.immutable.Queue object RichPipe extends java.io.Serializable { private val nextPipe = new AtomicInteger(-1) @@ -52,8 +46,10 @@ object RichPipe extends java.io.Serializable { if (reducers > 0) { p.getStepConfigDef() .setProperty(REDUCER_KEY, reducers.toString) + p.getStepConfigDef() + .setProperty(Config.WithReducersSetExplicitly, "true") } else if (reducers != -1) { - throw new IllegalArgumentException("Number of reducers must be non-negative") + throw new IllegalArgumentException(s"Number of reducers must be non-negative. Got: ${reducers}") } p } @@ -665,6 +661,45 @@ class RichPipe(val pipe: Pipe) extends java.io.Serializable with JoinAlgorithms .flatten .toSet + /** + * This finds all the boxed serializations stored in the flow state map for this + * flowdef. We then find all the pipes back in the DAG from this pipe and apply + * those serializations. 
+ */ + private[scalding] def applyFlowConfigProperties(flowDef: FlowDef): Pipe = { + case class ToVisit[T](queue: Queue[T], inQueue: Set[T]) { + def maybeAdd(t: T): ToVisit[T] = if (inQueue(t)) this else { + ToVisit(queue :+ t, inQueue + t) + } + def next: Option[(T, ToVisit[T])] = + if (inQueue.isEmpty) None + else Some((queue.head, ToVisit(queue.tail, inQueue - queue.head))) + } + + @annotation.tailrec + def go(p: Pipe, visited: Set[Pipe], toVisit: ToVisit[Pipe]): Set[Pipe] = { + val notSeen: Set[Pipe] = p.getPrevious.filter(i => !visited.contains(i)).toSet + val nextVisited: Set[Pipe] = visited + p + val nextToVisit = notSeen.foldLeft(toVisit) { case (prev, n) => prev.maybeAdd(n) } + + nextToVisit.next match { + case Some((h, innerNextToVisit)) => go(h, nextVisited, innerNextToVisit) + case _ => nextVisited + } + } + val allPipes = go(pipe, Set[Pipe](), ToVisit[Pipe](Queue.empty, Set.empty)) + + FlowStateMap.get(flowDef).foreach { fstm => + fstm.flowConfigUpdates.foreach { + case (k, v) => + allPipes.foreach { p => + p.getStepConfigDef().setProperty(k, v) + } + } + } + pipe + } + } /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Source.scala b/scalding-core/src/main/scala/com/twitter/scalding/Source.scala index b7f98a5057..70ab2c4440 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Source.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Source.scala @@ -15,29 +15,23 @@ limitations under the License. */ package com.twitter.scalding -import java.io.{ File, InputStream, OutputStream } -import java.util.{ TimeZone, Calendar, Map => JMap, Properties } +import java.io.{ InputStream, OutputStream } +import java.util.{ Map => JMap, Properties } import cascading.flow.FlowDef import cascading.flow.FlowProcess -import cascading.flow.hadoop.HadoopFlowProcess -import cascading.flow.local.LocalFlowProcess import cascading.scheme.{ NullScheme, Scheme } import cascading.tap.hadoop.Hfs -import cascading.tap.{ MultiSourceTap, SinkMode } +import cascading.tap.SinkMode import cascading.tap.{ Tap, SinkTap } -import cascading.tap.local.FileTap import cascading.tuple.{ Fields, Tuple => CTuple, TupleEntry, TupleEntryCollector } import cascading.pipe.Pipe -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputCollector; -import org.apache.hadoop.mapred.RecordReader; +import org.apache.hadoop.mapred.OutputCollector +import org.apache.hadoop.mapred.RecordReader -import collection.mutable.{ Buffer, MutableList } import scala.collection.JavaConverters._ /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala b/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala index ce55eb5ff1..56d3e203e5 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala @@ -4,7 +4,6 @@ import cascading.flow.{ FlowDef, FlowProcess } import cascading.stats.CascadingStats import java.util.concurrent.ConcurrentHashMap import org.slf4j.{ Logger, LoggerFactory } -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.ref.WeakReference @@ -108,8 +107,9 @@ object UniqueID { object RuntimeStats extends java.io.Serializable { @transient private lazy val logger: Logger = LoggerFactory.getLogger(this.getClass) - private val flowMappingStore: mutable.Map[String, WeakReference[FlowProcess[_]]] = - new 
ConcurrentHashMap[String, WeakReference[FlowProcess[_]]] + private val flowMappingStore: mutable.Map[String, WeakReference[FlowProcess[_]]] = { + (new ConcurrentHashMap[String, WeakReference[FlowProcess[_]]]).asScala + } def getFlowProcessForUniqueId(uniqueId: UniqueID): FlowProcess[_] = { (for { @@ -122,13 +122,21 @@ object RuntimeStats extends java.io.Serializable { } } + private[this] var prevFP: FlowProcess[_] = null def addFlowProcess(fp: FlowProcess[_]) { - val uniqueJobIdObj = fp.getProperty(UniqueID.UNIQUE_JOB_ID) - if (uniqueJobIdObj != null) { - uniqueJobIdObj.asInstanceOf[String].split(",").foreach { uniqueId => - logger.debug("Adding flow process id: " + uniqueId) - flowMappingStore.put(uniqueId, new WeakReference(fp)) + if (!(prevFP eq fp)) { + val uniqueJobIdObj = fp.getProperty(UniqueID.UNIQUE_JOB_ID) + if (uniqueJobIdObj != null) { + // for speed concern, use a while loop instead of foreach here + var splitted = StringUtility.fastSplit(uniqueJobIdObj.asInstanceOf[String], ",") + while (!splitted.isEmpty) { + val uniqueId = splitted.head + splitted = splitted.tail + logger.debug("Adding flow process id: " + uniqueId) + flowMappingStore.put(uniqueId, new WeakReference(fp)) + } } + prevFP = fp } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/StreamOperations.scala b/scalding-core/src/main/scala/com/twitter/scalding/StreamOperations.scala index e641bf1ad5..b7f2216b13 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/StreamOperations.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/StreamOperations.scala @@ -18,10 +18,6 @@ package com.twitter.scalding import cascading.tuple.Fields import cascading.tuple.{ Tuple => CTuple, TupleEntry } -import scala.collection.JavaConverters._ - -import Dsl._ //Get the conversion implicits - /** * Implements reductions on top of a simple abstraction for the Fields-API * We use the f-bounded polymorphism trick to return the type called Self diff --git a/scalding-core/src/main/scala/com/twitter/scalding/StringUtility.scala b/scalding-core/src/main/scala/com/twitter/scalding/StringUtility.scala new file mode 100644 index 0000000000..1e421e0990 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/StringUtility.scala @@ -0,0 +1,21 @@ +package com.twitter.scalding + +object StringUtility { + private def fastSplitHelper(text: String, key: String, from: Int, textLength: Int, keyLength: Int): List[String] = { + val firstIndex = text.indexOf(key, from) + if (firstIndex == -1) { + if (from < textLength) { + List(text.substring(from)) + } else { + List("") + } + } else { + // the text till the separator should be kept in any case + text.substring(from, firstIndex) :: fastSplitHelper(text, key, firstIndex + keyLength, textLength, keyLength) + } + } + + def fastSplit(text: String, key: String): List[String] = { + fastSplitHelper(text, key, 0, text.length, key.length) + } +} \ No newline at end of file diff --git a/scalding-core/src/main/scala/com/twitter/scalding/TemplateSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/TemplateSource.scala index 9b6d889017..09d469ef50 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/TemplateSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/TemplateSource.scala @@ -15,13 +15,6 @@ limitations under the License. 
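StringUtility.fastSplit, added above, is a literal (non-regex) split that keeps a trailing empty segment, unlike java.lang.String#split with its default limit of 0. A small sketch of the expected behavior, assuming the helper behaves as written:

    object FastSplitCheck extends App {
      import com.twitter.scalding.StringUtility

      val ids = "job-1,job-2,"
      // fastSplit keeps the empty string after the last separator...
      assert(StringUtility.fastSplit(ids, ",") == List("job-1", "job-2", ""))
      // ...whereas String#split drops trailing empty segments.
      assert(ids.split(",").toList == List("job-1", "job-2"))
      // Text without the separator comes back as a single-element list.
      assert(StringUtility.fastSplit("single", ",") == List("single"))
    }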
*/ package com.twitter.scalding -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.RecordReader -import org.apache.hadoop.mapred.OutputCollector - -import cascading.scheme.hadoop.{ TextDelimited => CHTextDelimited } -import cascading.scheme.hadoop.TextLine.Compress -import cascading.scheme.Scheme import cascading.tap.hadoop.Hfs import cascading.tap.hadoop.{ TemplateTap => HTemplateTap } import cascading.tap.local.FileTap diff --git a/scalding-core/src/main/scala/com/twitter/scalding/TimePathedSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/TimePathedSource.scala index e2db997c0c..126a56c4b9 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/TimePathedSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/TimePathedSource.scala @@ -18,9 +18,6 @@ package com.twitter.scalding import java.util.TimeZone import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.JobConf object TimePathedSource { val YEAR_MONTH_DAY = "/%1$tY/%1$tm/%1$td" @@ -34,6 +31,36 @@ object TimePathedSource { "%1$tm" -> Months(1)(tz), "%1$tY" -> Years(1)(tz)) .find { unitDur: (String, Duration) => pattern.contains(unitDur._1) } .map(_._2) + + /** + * Gives all paths in the given daterange with windows based on the provided duration. + */ + def allPathsWithDuration(pattern: String, duration: Duration, dateRange: DateRange, tz: TimeZone): Iterable[String] = + // This method is exhaustive, but too expensive for Cascading's JobConf writing. + dateRange.each(duration) + .map { dr: DateRange => + toPath(pattern, dr.start, tz) + } + + /** + * Gives all read paths in the given daterange. + */ + def readPathsFor(pattern: String, dateRange: DateRange, tz: TimeZone): Iterable[String] = { + TimePathedSource.stepSize(pattern, DateOps.UTC) match { + case Some(duration) => allPathsWithDuration(pattern, duration, dateRange, DateOps.UTC) + case None => sys.error(s"No suitable step size for pattern: $pattern") + } + } + + /** + * Gives the write path based on daterange end. + */ + def writePathFor(pattern: String, dateRange: DateRange, tz: TimeZone): String = { + assert(pattern.takeRight(2) == "/*", "Pattern must end with /* " + pattern) + val lastSlashPos = pattern.lastIndexOf('/') + val stripped = pattern.slice(0, lastSlashPos) + toPath(stripped, dateRange.end, tz) + } } abstract class TimeSeqPathedSource(val patterns: Seq[String], val dateRange: DateRange, val tz: TimeZone) extends FileSource { @@ -51,15 +78,10 @@ abstract class TimeSeqPathedSource(val patterns: Seq[String], val dateRange: Dat TimePathedSource.stepSize(pattern, tz) protected def allPathsFor(pattern: String): Iterable[String] = - defaultDurationFor(pattern) - .map { dur => - // This method is exhaustive, but too expensive for Cascading's JobConf writing. 
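The new readPathsFor and writePathFor helpers factor the path computation out of TimePathedSource. A hedged sketch of what they produce for a hypothetical daily pattern (the path, dates, and object name are invented):

    import java.util.TimeZone
    import com.twitter.scalding.{ DateRange, RichDate, TimePathedSource }

    object TimePathExample {
      val tz = TimeZone.getTimeZone("UTC")
      // The finest placeholder present is %1$td, so the step should resolve to one day.
      val pattern = "/logs/events/%1$tY/%1$tm/%1$td/*"
      // 2015-03-01 through 2015-03-03 UTC, given as epoch milliseconds.
      val range = DateRange(RichDate(1425168000000L), RichDate(1425340800000L))

      // One globbed path per day: /logs/events/2015/03/01/*, .../02/*, .../03/*
      val readPaths = TimePathedSource.readPathsFor(pattern, range, tz)

      // Taken from the end of the range, with the trailing "/*" stripped:
      // /logs/events/2015/03/03
      val writePath = TimePathedSource.writePathFor(pattern, range, tz)
    }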
- dateRange.each(dur) - .map { dr: DateRange => - TimePathedSource.toPath(pattern, dr.start, tz) - } - } - .getOrElse(Nil) + defaultDurationFor(pattern) match { + case Some(duration) => TimePathedSource.allPathsWithDuration(pattern, duration, dateRange, tz) + case None => sys.error(s"No suitable step size for pattern: $pattern") + } /** These are all the paths we will read for this data completely enumerated */ def allPaths: Iterable[String] = @@ -107,15 +129,12 @@ abstract class TimePathedSource(val pattern: String, tz: TimeZone) extends TimeSeqPathedSource(Seq(pattern), dateRange, tz) { //Write to the path defined by the end time: - override def hdfsWritePath = { - // TODO this should be required everywhere but works on read without it - // maybe in 0.9.0 be more strict - assert(pattern.takeRight(2) == "/*", "Pattern must end with /* " + pattern) - val lastSlashPos = pattern.lastIndexOf('/') - val stripped = pattern.slice(0, lastSlashPos) - TimePathedSource.toPath(stripped, dateRange.end, tz) - } - override def localPath = pattern + override def hdfsWritePath = TimePathedSource.writePathFor(pattern, dateRange, tz) + + override def localPaths = patterns + .flatMap { pattern: String => + Globifier(pattern)(tz).globify(dateRange) + } } /* diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala b/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala index d0bb880599..ad9af8449a 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala @@ -15,15 +15,13 @@ limitations under the License. */ package com.twitter.scalding -import org.apache.hadoop -import cascading.tuple.Tuple -import collection.mutable.{ ListBuffer, Buffer } +import org.apache.hadoop.conf.Configured +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.util.{ GenericOptionsParser, Tool => HTool, ToolRunner } + import scala.annotation.tailrec -import scala.util.Try -import java.io.{ BufferedWriter, File, FileOutputStream, OutputStreamWriter } -import java.util.UUID -class Tool extends hadoop.conf.Configured with hadoop.util.Tool { +class Tool extends Configured with HTool { // This mutable state is not my favorite, but we are constrained by the Hadoop API: var rootJob: Option[(Args) => Job] = None @@ -53,7 +51,7 @@ class Tool extends hadoop.conf.Configured with hadoop.util.Tool { // and returns all the non-hadoop arguments. Should be called once if // you want to process hadoop arguments (like -libjars). protected def nonHadoopArgsFrom(args: Array[String]): Array[String] = { - (new hadoop.util.GenericOptionsParser(getConf, args)).getRemainingArgs + (new GenericOptionsParser(getConf, args)).getRemainingArgs } def parseModeArgs(args: Array[String]): (Mode, Args) = { @@ -125,7 +123,7 @@ class Tool extends hadoop.conf.Configured with hadoop.util.Tool { object Tool { def main(args: Array[String]) { try { - hadoop.util.ToolRunner.run(new hadoop.mapred.JobConf, new Tool, args) + ToolRunner.run(new JobConf, new Tool, args) } catch { case t: Throwable => { //re-throw the exception with extra info diff --git a/scalding-core/src/main/scala/com/twitter/scalding/TupleConversions.scala b/scalding-core/src/main/scala/com/twitter/scalding/TupleConversions.scala index 4e5caa7777..e52511d2af 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/TupleConversions.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/TupleConversions.scala @@ -15,14 +15,5 @@ limitations under the License. 
*/ package com.twitter.scalding -import cascading.tuple.TupleEntry -import cascading.tuple.TupleEntryIterator -import cascading.tuple.{ Tuple => CTuple } -import cascading.tuple.Tuples - -import java.io.Serializable - -import scala.collection.JavaConverters._ - @deprecated("This trait does nothing now", "0.9.0") trait TupleConversions diff --git a/scalding-core/src/main/scala/com/twitter/scalding/TuplePacker.scala b/scalding-core/src/main/scala/com/twitter/scalding/TuplePacker.scala index efd6547f94..32262a8be0 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/TuplePacker.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/TuplePacker.scala @@ -15,8 +15,6 @@ limitations under the License. */ package com.twitter.scalding -import cascading.pipe._ -import cascading.pipe.joiner._ import cascading.tuple._ import java.lang.reflect.Method diff --git a/scalding-core/src/main/scala/com/twitter/scalding/TupleUnpacker.scala b/scalding-core/src/main/scala/com/twitter/scalding/TupleUnpacker.scala index 82fa09bb20..fbef771c81 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/TupleUnpacker.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/TupleUnpacker.scala @@ -15,12 +15,9 @@ limitations under the License. */ package com.twitter.scalding -import cascading.pipe._ -import cascading.pipe.joiner._ import cascading.tuple._ import scala.reflect.Manifest -import scala.collection.JavaConverters._ /** * Typeclass for objects which unpack an object into a tuple. diff --git a/scalding-core/src/main/scala/com/twitter/scalding/TypedDelimited.scala b/scalding-core/src/main/scala/com/twitter/scalding/TypedDelimited.scala index e9ea62b217..b24573e19c 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/TypedDelimited.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/TypedDelimited.scala @@ -69,7 +69,7 @@ object TypedPsv extends TypedSeperatedFile { * Typed one separated values file (commonly used by Pig) */ object TypedOsv extends TypedSeperatedFile { - val separator = "\1" + val separator = "\u0001" } object FixedPathTypedDelimited { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/bdd/BddDsl.scala b/scalding-core/src/main/scala/com/twitter/scalding/bdd/BddDsl.scala index 1ca5fe3fdd..65592dac5e 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/bdd/BddDsl.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/bdd/BddDsl.scala @@ -3,9 +3,7 @@ package com.twitter.scalding.bdd import com.twitter.scalding._ import scala.collection.mutable.Buffer import cascading.tuple.Fields -import scala.Predef._ import com.twitter.scalding.Tsv -import org.slf4j.LoggerFactory trait BddDsl extends FieldConversions with PipeOperationsConversions { def Given(source: TestSource): TestCaseGiven1 = new TestCaseGiven1(source) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/bdd/TBddDsl.scala b/scalding-core/src/main/scala/com/twitter/scalding/bdd/TBddDsl.scala index 3594ac2198..ce7440f1cb 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/bdd/TBddDsl.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/bdd/TBddDsl.scala @@ -3,8 +3,6 @@ package com.twitter.scalding.bdd import cascading.flow.FlowDef import com.twitter.scalding._ import scala.collection.mutable.Buffer -import cascading.tuple.Fields -import scala.Predef._ import TDsl._ trait TBddDsl extends FieldConversions with TypedPipeOperationsConversions { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala 
b/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala index d62c49537d..967f80486d 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala @@ -1,7 +1,7 @@ package com.twitter.scalding.bdd import com.twitter.scalding.TypedPipe -import com.twitter.scalding.{ Dsl, RichPipe } +import com.twitter.scalding.Dsl trait TypedPipeOperationsConversions { import Dsl._ diff --git a/scalding-core/src/main/scala/com/twitter/scalding/examples/WeightedPageRankFromMatrix.scala b/scalding-core/src/main/scala/com/twitter/scalding/examples/WeightedPageRankFromMatrix.scala index 852cffc3b2..32e9dee952 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/examples/WeightedPageRankFromMatrix.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/examples/WeightedPageRankFromMatrix.scala @@ -1,7 +1,5 @@ package com.twitter.scalding.examples -import scala.collection._ - import com.twitter.scalding._ import com.twitter.scalding.mathematics.{ Matrix, ColVector } import com.twitter.scalding.mathematics.Matrix._ diff --git a/scalding-core/src/main/scala/com/twitter/scalding/examples/WordCountJob.scala b/scalding-core/src/main/scala/com/twitter/scalding/examples/WordCountJob.scala index 8801642b38..83a8dd0175 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/examples/WordCountJob.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/examples/WordCountJob.scala @@ -3,8 +3,10 @@ package com.twitter.scalding.examples import com.twitter.scalding._ class WordCountJob(args: Args) extends Job(args) { - TextLine(args("input")).read. - flatMap('line -> 'word) { line: String => line.split("\\s+") }. - groupBy('word) { _.size }. - write(Tsv(args("output"))) + TypedPipe.from(TextLine(args("input"))) + .flatMap { line => line.split("\\s+") } + .map { word => (word, 1L) } + .sumByKey + // The compiler will enforce the type coming out of the sumByKey is the same as the type we have for our sink + .write(TypedTsv[(String, Long)](args("output"))) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/filecache/DistributedCacheFile.scala b/scalding-core/src/main/scala/com/twitter/scalding/filecache/DistributedCacheFile.scala index 33dbc1845a..0dbe743d28 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/filecache/DistributedCacheFile.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/filecache/DistributedCacheFile.scala @@ -75,7 +75,7 @@ object DistributedCacheFile { val hexsum = URIHasher(uri) val fileName = new File(uri.toString).getName - Seq(fileName, hexsum).mkString("-") + Seq(hexsum, fileName).mkString("-") } def symlinkedUriFor(sourceUri: URI): URI = diff --git a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala index 7f2be86a29..66dc078272 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala @@ -130,18 +130,31 @@ sealed trait Matrix2[R, C, V] extends Serializable { } /** - * Row L2 normalization (can only be called for Double) + * Row L2 normalization * After this operation, the sum(|x|^2) along each row will be 1. 
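The rowL1Normalize method added below mirrors the existing rowL2Normalize. A plain-Scala, REPL-style illustration of the two scalings on a single row, independent of Matrix2:

    // For the row (3.0, 4.0):
    //   L2: divide by sqrt(3^2 + 4^2) = 5 -> (0.6, 0.8), and 0.6^2 + 0.8^2 = 1
    //   L1: divide by |3| + |4| = 7       -> (3/7, 4/7), and 3/7 + 4/7 = 1
    val row = Seq(3.0, 4.0)
    val l2 = row.map(_ / math.sqrt(row.map(x => x * x).sum)) // Seq(0.6, 0.8)
    val l1 = row.map(_ / row.map(math.abs).sum)              // Seq(0.4285..., 0.5714...)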
*/ - def rowL2Normalize(implicit ev: =:=[V, Double], mj: MatrixJoiner2): Matrix2[R, C, Double] = { - val matD = this.asInstanceOf[Matrix2[R, C, Double]] - lazy val result = MatrixLiteral(matD.toTypedPipe.map { case (r, c, x) => (r, c, x * x) }, this.sizeHint) + def rowL2Normalize(implicit num: Numeric[V], mj: MatrixJoiner2): Matrix2[R, C, Double] = { + val matD = MatrixLiteral(this.toTypedPipe.map{ case (r, c, x) => (r, c, num.toDouble(x)) }, this.sizeHint) + lazy val result = MatrixLiteral(this.toTypedPipe.map { case (r, c, x) => (r, c, num.toDouble(x) * num.toDouble(x)) }, this.sizeHint) .sumColVectors .toTypedPipe .map { case (r, c, x) => (r, r, 1 / scala.math.sqrt(x)) } // diagonal + inverse MatrixLiteral(result, SizeHint.asDiagonal(this.sizeHint.setRowsToCols)) * matD } + /** + * Row L1 normalization + * After this operation, the sum(|x|) along each row will be 1. + */ + def rowL1Normalize(implicit num: Numeric[V], mj: MatrixJoiner2): Matrix2[R, C, Double] = { + val matD = MatrixLiteral(this.toTypedPipe.map{ case (r, c, x) => (r, c, num.toDouble(x).abs) }, this.sizeHint) + lazy val result = matD + .sumColVectors + .toTypedPipe + .map { case (r, c, x) => (r, r, 1 / x) } // diagonal + inverse + MatrixLiteral(result, SizeHint.asDiagonal(this.sizeHint.setRowsToCols)) * matD + } + def getRow(index: R): Matrix2[Unit, C, V] = MatrixLiteral( toTypedPipe diff --git a/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/Common.scala b/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/Common.scala index 5657224737..116afe1313 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/Common.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/Common.scala @@ -2,8 +2,9 @@ package com.twitter.scalding.reducer_estimation import cascading.flow.{ FlowStep, Flow, FlowStepStrategy } import com.twitter.algebird.Monoid -import com.twitter.scalding.Config +import com.twitter.scalding.{ StringUtility, Config } import org.apache.hadoop.mapred.JobConf +import org.slf4j.LoggerFactory import java.util.{ List => JList } import scala.collection.JavaConverters._ @@ -51,6 +52,8 @@ case class FallbackEstimator(first: ReducerEstimator, fallback: ReducerEstimator object ReducerEstimatorStepStrategy extends FlowStepStrategy[JobConf] { + private val LOG = LoggerFactory.getLogger(this.getClass) + implicit val estimatorMonoid: Monoid[ReducerEstimator] = new Monoid[ReducerEstimator] { override def zero: ReducerEstimator = new ReducerEstimator override def plus(l: ReducerEstimator, r: ReducerEstimator): ReducerEstimator = @@ -67,17 +70,24 @@ object ReducerEstimatorStepStrategy extends FlowStepStrategy[JobConf] { final override def apply(flow: Flow[JobConf], preds: JList[FlowStep[JobConf]], step: FlowStep[JobConf]): Unit = { + val conf = step.getConfig + // for steps with reduce phase, mapred.reduce.tasks is set in the jobconf at this point + // so we check that to determine if this is a map-only step. + conf.getNumReduceTasks match { + case 0 => LOG.info(s"${flow.getName} is a map-only step.
Skipping reducer estimation.") + case _ => estimate(flow, preds, step) + } + } - val flowNumReducers = flow.getConfig.get(Config.HadoopNumReducers) + private def estimate(flow: Flow[JobConf], + preds: JList[FlowStep[JobConf]], + step: FlowStep[JobConf]): Unit = { + val conf = step.getConfig val stepNumReducers = conf.get(Config.HadoopNumReducers) - // assuming that if the step's reducers is different than the default for the flow, - // it was probably set by `withReducers` explicitly. This isn't necessarily true -- - // Cascading may have changed it for its own reasons. - // TODO: disambiguate this by setting something in JobConf when `withReducers` is called - // (will be addressed by https://github.com/twitter/scalding/pull/973) - val setExplicitly = flowNumReducers != stepNumReducers + // whether the reducers have been set explicitly with `withReducers` + val setExplicitly = conf.getBoolean(Config.WithReducersSetExplicitly, false) // log in JobConf what was explicitly set by 'withReducers' if (setExplicitly) conf.set(EstimatorConfig.originalNumReducers, stepNumReducers) @@ -89,7 +99,7 @@ object ReducerEstimatorStepStrategy extends FlowStepStrategy[JobConf] { val clsLoader = Thread.currentThread.getContextClassLoader - val estimators = clsNames.split(",") + val estimators = StringUtility.fastSplit(clsNames, ",") .map(clsLoader.loadClass(_).newInstance.asInstanceOf[ReducerEstimator]) val combinedEstimator = Monoid.sum(estimators) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/InputSizeReducerEstimator.scala b/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/InputSizeReducerEstimator.scala index edfce5c246..221451c3a0 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/InputSizeReducerEstimator.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/InputSizeReducerEstimator.scala @@ -2,7 +2,7 @@ package com.twitter.scalding.reducer_estimation import scala.collection.JavaConverters._ import cascading.flow.FlowStep -import cascading.tap.{ Tap, MultiSourceTap } +import cascading.tap.{ Tap, CompositeTap } import cascading.tap.hadoop.Hfs import org.apache.hadoop.mapred.JobConf import org.slf4j.LoggerFactory @@ -26,7 +26,7 @@ class InputSizeReducerEstimator extends ReducerEstimator { private def unrollTaps(taps: Seq[Tap[_, _, _]]): Seq[Tap[_, _, _]] = taps.flatMap { - case multi: MultiSourceTap[_, _, _] => + case multi: CompositeTap[_] => unrollTaps(multi.getChildTaps.asScala.toSeq) case t => Seq(t) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/AlgebirdSerializers.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/AlgebirdSerializers.scala deleted file mode 100644 index 8b749123e2..0000000000 --- a/scalding-core/src/main/scala/com/twitter/scalding/serialization/AlgebirdSerializers.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* -Copyright 2012 Twitter, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -package com.twitter.scalding.serialization - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.{ Serializer => KSerializer } -import com.esotericsoftware.kryo.io.{ Input, Output } - -import com.twitter.algebird.{ - AveragedValue, - DecayedValue, - HLL, - HyperLogLog, - HyperLogLogMonoid, - Moments -} - -import scala.collection.mutable.{ Map => MMap } - -class AveragedValueSerializer extends KSerializer[AveragedValue] { - setImmutable(true) - def write(kser: Kryo, out: Output, s: AveragedValue) { - out.writeLong(s.count, true) - out.writeDouble(s.value) - } - def read(kser: Kryo, in: Input, cls: Class[AveragedValue]): AveragedValue = - AveragedValue(in.readLong(true), in.readDouble) -} - -class MomentsSerializer extends KSerializer[Moments] { - setImmutable(true) - def write(kser: Kryo, out: Output, s: Moments) { - out.writeLong(s.m0, true) - out.writeDouble(s.m1) - out.writeDouble(s.m2) - out.writeDouble(s.m3) - out.writeDouble(s.m4) - } - def read(kser: Kryo, in: Input, cls: Class[Moments]): Moments = { - Moments(in.readLong(true), - in.readDouble, - in.readDouble, - in.readDouble, - in.readDouble) - } -} - -class DecayedValueSerializer extends KSerializer[DecayedValue] { - setImmutable(true) - def write(kser: Kryo, out: Output, s: DecayedValue) { - out.writeDouble(s.value) - out.writeDouble(s.scaledTime) - } - def read(kser: Kryo, in: Input, cls: Class[DecayedValue]): DecayedValue = - DecayedValue(in.readDouble, in.readDouble) -} - -class HLLSerializer extends KSerializer[HLL] { - setImmutable(true) - def write(kser: Kryo, out: Output, s: HLL) { - val bytes = HyperLogLog.toBytes(s) - out.writeInt(bytes.size, true) - out.writeBytes(bytes) - } - def read(kser: Kryo, in: Input, cls: Class[HLL]): HLL = { - HyperLogLog.fromBytes(in.readBytes(in.readInt(true))) - } -} - -class HLLMonoidSerializer extends KSerializer[HyperLogLogMonoid] { - setImmutable(true) - val hllMonoids = MMap[Int, HyperLogLogMonoid]() - def write(kser: Kryo, out: Output, mon: HyperLogLogMonoid) { - out.writeInt(mon.bits, true) - } - def read(kser: Kryo, in: Input, cls: Class[HyperLogLogMonoid]): HyperLogLogMonoid = { - val bits = in.readInt(true) - hllMonoids.getOrElseUpdate(bits, new HyperLogLogMonoid(bits)) - } -} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/CascadingBinaryComparator.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/CascadingBinaryComparator.scala new file mode 100644 index 0000000000..abed86b59b --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/CascadingBinaryComparator.scala @@ -0,0 +1,85 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package com.twitter.scalding.serialization + +import java.io.InputStream +import java.util.Comparator +import cascading.flow.FlowDef +import cascading.tuple.{ Hasher => CHasher, StreamComparator } + +import scala.util.{ Failure, Success, Try } + +/** + * This is the type that should be fed to cascading to enable binary comparators + */ +class CascadingBinaryComparator[T](ob: OrderedSerialization[T]) extends Comparator[T] + with StreamComparator[InputStream] + with CHasher[T] + with Serializable { + + override def compare(a: T, b: T) = ob.compare(a, b) + override def hashCode(t: T): Int = ob.hash(t) + override def compare(a: InputStream, b: InputStream) = + ob.compareBinary(a, b).unsafeToInt +} + +object CascadingBinaryComparator { + /** + * This method will walk the flowDef and make sure all the + * groupBy/cogroups are using a CascadingBinaryComparator + */ + def checkForOrderedSerialization(fd: FlowDef): Try[Unit] = { + import collection.JavaConverters._ + import cascading.pipe._ + import com.twitter.scalding.RichPipe + + // all successes or empty returns success + def reduce(it: TraversableOnce[Try[Unit]]): Try[Unit] = + it.find(_.isFailure).getOrElse(Success(())) + + def check(s: Splice): Try[Unit] = { + val m = s.getKeySelectors.asScala + + def error(s: String): Try[Unit] = + Failure(new RuntimeException("Cannot verify OrderedSerialization: " + s)) + + if (m.isEmpty) error(s"Splice must have KeySelectors: $s") + else { + reduce(m.map { + case (pipename, fields) => + /* + * Scalding typed-API ALWAYS puts the key into field position 0. + * If OrderedSerialization is enabled, this must be a CascadingBinaryComparator + */ + if (fields.getComparators()(0).isInstanceOf[CascadingBinaryComparator[_]]) + Success(()) + else error(s"pipe: $s, fields: $fields, comparators: ${fields.getComparators.toList}") + }) + } + } + + val allPipes: Set[Pipe] = fd.getTails.asScala.map(p => RichPipe(p).upstreamPipes).flatten.toSet + reduce(allPipes.iterator.map { + /* + * There are only two cascading primitives used by scalding that do key-sorting: + */ + case gb: GroupBy => check(gb) + case cg: CoGroup => check(cg) + case _ => Success(()) + }) + } +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/Externalizer.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/Externalizer.scala index 0946454c12..dbe5ca4826 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/serialization/Externalizer.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/Externalizer.scala @@ -14,12 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ package com.twitter.scalding.serialization + import com.twitter.chill.{ Externalizer => ChillExtern } import com.esotericsoftware.kryo.DefaultSerializer import com.esotericsoftware.kryo.serializers.JavaSerializer import com.twitter.chill.config.ScalaAnyRefMapConfig + /** * We need to control the Kryo created */ diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoHadoop.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoHadoop.scala index 97fc2055d3..5f27104ee2 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoHadoop.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoHadoop.scala @@ -15,34 +15,18 @@ limitations under the License. 
*/ package com.twitter.scalding.serialization -import java.io.InputStream -import java.io.OutputStream -import java.io.Serializable -import java.nio.ByteBuffer - -import org.apache.hadoop.io.serializer.{ Serialization, Deserializer, Serializer, WritableSerialization } - import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.{ Serializer => KSerializer } -import com.esotericsoftware.kryo.io.{ Input, Output } import com.esotericsoftware.kryo.serializers.FieldSerializer -import cascading.tuple.hadoop.TupleSerialization -import cascading.tuple.hadoop.io.BufferedInputStream - -import scala.annotation.tailrec -import scala.collection.immutable.ListMap -import scala.collection.immutable.HashMap - import com.twitter.scalding.DateRange import com.twitter.scalding.RichDate import com.twitter.scalding.Args -import com.twitter.chill._ +import com.twitter.chill.algebird._ import com.twitter.chill.config.Config +import com.twitter.chill.{ SingletonSerializer, ScalaKryoInstantiator, KryoInstantiator } class KryoHadoop(config: Config) extends KryoInstantiator { - /** * TODO!!! * Deal with this issue. The problem is grouping by Kryo serialized @@ -66,7 +50,8 @@ class KryoHadoop(config: Config) extends KryoInstantiator { newK.register(classOf[com.twitter.algebird.HyperLogLogMonoid], new HLLMonoidSerializer) newK.register(classOf[com.twitter.algebird.Moments], new MomentsSerializer) newK.addDefaultSerializer(classOf[com.twitter.algebird.HLL], new HLLSerializer) - + // Don't serialize Boxed instances using Kryo. + newK.addDefaultSerializer(classOf[com.twitter.scalding.serialization.Boxed[_]], new ThrowingSerializer) /** * AdaptiveVector is IndexedSeq, which picks up the chill IndexedSeq serializer * (which is its own bug), force using the fields serializer here diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoSerializers.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoSerializers.scala index 4e20dc59f6..859d223431 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoSerializers.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/KryoSerializers.scala @@ -15,23 +15,23 @@ limitations under the License. */ package com.twitter.scalding.serialization -import java.io.InputStream -import java.io.OutputStream -import java.io.Serializable -import java.nio.ByteBuffer - -import org.apache.hadoop.io.serializer.{ Serialization, Deserializer, Serializer, WritableSerialization } - import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.{ Serializer => KSerializer } import com.esotericsoftware.kryo.io.{ Input, Output } -import scala.annotation.tailrec -import scala.collection.immutable.ListMap -import scala.collection.mutable.{ Map => MMap } - import com.twitter.scalding._ +/** + * This is a runtime check for types we should never be serializing + */ +class ThrowingSerializer[T] extends KSerializer[T] { + override def write(kryo: Kryo, output: Output, t: T) { + sys.error(s"Kryo should never be used to serialize an instance: $t") + } + override def read(kryo: Kryo, input: Input, t: Class[T]): T = + sys.error(s"Kryo should never be used to serialize an instance, class: $t") +} + /** * * * Below are some serializers for objects in the scalding project.
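The ThrowingSerializer above is a guard rather than a real serializer: registering it turns any accidental Kryo round trip of that class into an immediate failure. A REPL-style sketch mirroring the KryoHadoop change above:

    import com.esotericsoftware.kryo.Kryo
    import com.twitter.scalding.serialization.{ Boxed, ThrowingSerializer }

    val kryo = new Kryo()
    // Boxed keys are meant to go through OrderedSerialization; if Kryo ever sees
    // one, something upstream is misconfigured, so fail loudly instead of guessing.
    kryo.addDefaultSerializer(classOf[Boxed[_]], new ThrowingSerializer)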
diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/WrappedSerialization.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/WrappedSerialization.scala new file mode 100644 index 0000000000..961a09f909 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/WrappedSerialization.scala @@ -0,0 +1,122 @@ +/* +Copyright 2014 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import org.apache.hadoop.io.serializer.{ Serialization => HSerialization, Deserializer, Serializer } +import org.apache.hadoop.conf.{ Configurable, Configuration } + +import java.io.{ InputStream, OutputStream } +import com.twitter.bijection.{ Injection, JavaSerializationInjection, Base64String } +import scala.collection.JavaConverters._ + +/** + * WrappedSerialization wraps a value in a wrapper class that + * has an associated Binary that is used to deserialize + * items wrapped in the wrapper + */ +class WrappedSerialization[T] extends HSerialization[T] with Configurable { + + import WrappedSerialization.ClassSerialization + + private var conf: Option[Configuration] = None + private var serializations: Map[Class[_], Serialization[_]] = Map.empty + + override def getConf: Configuration = conf.get + override def setConf(config: Configuration) { + conf = Some(config) + serializations = WrappedSerialization.getBinary(config) + } + + def accept(c: Class[_]): Boolean = serializations.contains(c) + + def getSerialization(c: Class[T]): Option[Serialization[T]] = + serializations.get(c) + // This cast should never fail since we matched the class + .asInstanceOf[Option[Serialization[T]]] + + def getSerializer(c: Class[T]): Serializer[T] = + new BinarySerializer(getSerialization(c) + .getOrElse(sys.error(s"Serialization for class: ${c} not found"))) + + def getDeserializer(c: Class[T]): Deserializer[T] = + new BinaryDeserializer(getSerialization(c) + .getOrElse(sys.error(s"Serialization for class: ${c} not found"))) + +} + +class BinarySerializer[T](buf: Serialization[T]) extends Serializer[T] { + private var out: OutputStream = _ + def open(os: OutputStream): Unit = { + out = os + } + def close(): Unit = { out = null } + def serialize(t: T): Unit = { + if (out == null) throw new NullPointerException("OutputStream is null") + buf.write(out, t).get + } +} + +class BinaryDeserializer[T](buf: Serialization[T]) extends Deserializer[T] { + private var is: InputStream = _ + def open(i: InputStream): Unit = { is = i } + def close(): Unit = { is = null } + def deserialize(t: T): T = { + if (is == null) throw new NullPointerException("InputStream is null") + buf.read(is).get + } +} + +object WrappedSerialization { + type ClassSerialization[T] = (Class[T], Serialization[T]) + + private def getSerializer[U]: Injection[Externalizer[U], String] = { + implicit val initialInj = JavaSerializationInjection[Externalizer[U]] + Injection.connect[Externalizer[U], Array[Byte], Base64String, String] + } + + private def serialize[T](b: T): String 
= + getSerializer[T](Externalizer(b)) + + private def deserialize[T](str: String): T = + getSerializer[T].invert(str).get.get + + private val confKey = "com.twitter.scalding.serialization.WrappedSerialization" + + def rawSetBinary(bufs: Iterable[ClassSerialization[_]], fn: (String, String) => Unit) = { + fn(confKey, bufs.map { case (cls, buf) => s"${cls.getName}:${serialize(buf)}" }.mkString(",")) + } + def setBinary(conf: Configuration, bufs: Iterable[ClassSerialization[_]]): Unit = + rawSetBinary(bufs, { case (k, v) => conf.set(k, v) }) + + def getBinary(conf: Configuration): Map[Class[_], Serialization[_]] = + conf + .iterator + .asScala + .map { it => + (it.getKey, it.getValue) + } + .filter(_._1.startsWith(confKey)) + .map { + case (_, clsbuf) => + clsbuf.split(":") match { + case Array(className, serialization) => + // Jump through a hoop to get scalac happy + def deser[T](cls: Class[T]): ClassSerialization[T] = (cls, deserialize[Serialization[T]](serialization)) + deser(conf.getClassByName(className)) + case _ => sys.error(s"ill formed bufferables: ${clsbuf}") + } + }.toMap +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/source/CodecSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/source/CodecSource.scala index 72e963f427..394d7b7d10 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/source/CodecSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/source/CodecSource.scala @@ -56,10 +56,8 @@ class CodecSource[T] private (val hdfsPaths: Seq[String], val maxFailures: Int = lazy val field = new Fields(fieldSym.name) val injectionBox = Externalizer(injection andThen BytesWritableCodec.get) - def localPath = { - require(hdfsPaths.size == 1, "Local mode only supports a single path"); - hdfsPaths(0) - } + def localPaths = hdfsPaths + override def converter[U >: T] = TupleConverter.asSuperConverter[T, U](TupleConverter.singleConverter[T]) override def hdfsScheme = HadoopSchemeInstance(new WritableSequenceFile(field, classOf[BytesWritable]).asInstanceOf[Scheme[_, _, _, _, _]]) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala index 3f654f63bc..91400fd8be 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala @@ -204,70 +204,87 @@ trait CoGrouped[K, +R] extends KeyedListLike[K, R, CoGrouped] with CoGroupable[K val ord = keyOrdering TypedPipeFactory({ (flowDef, mode) => - val newPipe = if (firstCount == inputs.size) { - /** - * This is a self-join - * Cascading handles this by sending the data only once, spilling to disk if - * the groups don't fit in RAM, then doing the join on this one set of data. - * This is fundamentally different than the case where the first item is - * not repeated. That case is below - */ - val NUM_OF_SELF_JOINS = firstCount - 1 - new CoGroup(assignName(inputs.head.toPipe[(Any, Any)](("key", "value"))(flowDef, mode, tup2Setter)), - RichFields(StringField("key")(ord, None)), - NUM_OF_SELF_JOINS, - outFields(firstCount), - new DistinctCoGroupJoiner(firstCount, joinFunction)) - } else if (firstCount == 1) { - /** - * As long as the first one appears only once, we can handle self joins on the others: - * Cascading does this by maybe spilling all the streams other than the first item. - * This is handled by a different CoGroup constructor than the above case. 
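The self-join branch above (firstCount == inputs.size) is what a typed job hits when the same grouped pipe appears on both sides of a join. A small REPL-style sketch with invented data:

    import com.twitter.scalding.typed.TypedPipe

    val clicks: TypedPipe[(Long, String)] =
      TypedPipe.from(List(1L -> "a", 1L -> "b", 2L -> "c"))

    // Both sides are the same pipe, so Cascading sends the data once and spills
    // each group to disk if it does not fit in memory.
    val selfJoined = clicks.group.join(clicks.group).toTypedPipe
    // => (1,(a,a)), (1,(a,b)), (1,(b,a)), (1,(b,b)), (2,(c,c))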
- */ - def renamePipe(idx: Int, p: TypedPipe[(K, Any)]): Pipe = - p.toPipe[(K, Any)](List("key%d".format(idx), "value%d".format(idx)))(flowDef, mode, tup2Setter) - - // This is tested for the properties we need (non-reordering) - val distincts = CoGrouped.distinctBy(inputs)(identity) - val dsize = distincts.size - val isize = inputs.size - - val groupFields: Array[Fields] = (0 until dsize) - .map { idx => RichFields(StringField("key%d".format(idx))(ord, None)) } - .toArray - - val pipes: Array[Pipe] = distincts - .zipWithIndex - .map { case (item, idx) => assignName(renamePipe(idx, item)) } - .toArray - - val cjoiner = if (isize != dsize) { - // avoid capturing anything other than the mapping ints: - val mapping: Map[Int, Int] = inputs.zipWithIndex.map { - case (item, idx) => - idx -> distincts.indexWhere(_ == item) - }.toMap - - new CoGroupedJoiner(isize, joinFunction) { - val distinctSize = dsize - def distinctIndexOf(orig: Int) = mapping(orig) + val newPipe = Grouped.maybeBox[K, Any](ord, flowDef) { (tupset, ordKeyField) => + if (firstCount == inputs.size) { + /** + * This is a self-join + * Cascading handles this by sending the data only once, spilling to disk if + * the groups don't fit in RAM, then doing the join on this one set of data. + * This is fundamentally different than the case where the first item is + * not repeated. That case is below + */ + val NUM_OF_SELF_JOINS = firstCount - 1 + new CoGroup(assignName(inputs.head.toPipe[(K, Any)](("key", "value"))(flowDef, mode, + tupset)), + ordKeyField, + NUM_OF_SELF_JOINS, + outFields(firstCount), + WrappedJoiner(new DistinctCoGroupJoiner(firstCount, Grouped.keyGetter(ord), joinFunction))) + } else if (firstCount == 1) { + + def keyId(idx: Int): String = "key%d".format(idx) + /** + * As long as the first one appears only once, we can handle self joins on the others: + * Cascading does this by maybe spilling all the streams other than the first item. + * This is handled by a different CoGroup constructor than the above case. + */ + def renamePipe(idx: Int, p: TypedPipe[(K, Any)]): Pipe = + p.toPipe[(K, Any)](List(keyId(idx), "value%d".format(idx)))(flowDef, mode, + tupset) + + // This is tested for the properties we need (non-reordering) + val distincts = CoGrouped.distinctBy(inputs)(identity) + val dsize = distincts.size + val isize = inputs.size + + def makeFields(id: Int): Fields = { + val comp = ordKeyField.getComparators.apply(0) + val fieldName = keyId(id) + val f = new Fields(fieldName) + f.setComparator(fieldName, comp) + f } - } else new DistinctCoGroupJoiner(isize, joinFunction) - new CoGroup(pipes, groupFields, outFields(dsize), cjoiner) - } else { - /** - * This is non-trivial to encode in the type system, so we throw this exception - * at the planning phase. - */ - sys.error("Except for self joins, where you are joining something with only itself,\n" + - "left-most pipe can only appear once. 
Firsts: " + - inputs.collect { case x if x == inputs.head => x }.toString) + val groupFields: Array[Fields] = (0 until dsize) + .map(makeFields) + .toArray + + val pipes: Array[Pipe] = distincts + .zipWithIndex + .map { case (item, idx) => assignName(renamePipe(idx, item)) } + .toArray + + val cjoiner = if (isize != dsize) { + // avoid capturing anything other than the mapping ints: + val mapping: Map[Int, Int] = inputs.zipWithIndex.map { + case (item, idx) => + idx -> distincts.indexWhere(_ == item) + }.toMap + + new CoGroupedJoiner(isize, Grouped.keyGetter(ord), joinFunction) { + val distinctSize = dsize + def distinctIndexOf(orig: Int) = mapping(orig) + } + } else { + new DistinctCoGroupJoiner(isize, Grouped.keyGetter(ord), joinFunction) + } + + new CoGroup(pipes, groupFields, outFields(dsize), WrappedJoiner(cjoiner)) + } else { + /** + * This is non-trivial to encode in the type system, so we throw this exception + * at the planning phase. + */ + sys.error("Except for self joins, where you are joining something with only itself,\n" + + "left-most pipe can only appear once. Firsts: " + + inputs.collect { case x if x == inputs.head => x }.toString) + } } /* * the CoGrouped only populates the first two fields, the second two * are null. We then project out at the end of the method. */ + val pipeWithRed = RichPipe.setReducers(newPipe, reducers.getOrElse(-1)).project('key, 'value) //Construct the new TypedPipe TypedPipe.from[(K, R)](pipeWithRed, ('key, 'value))(flowDef, mode, tuple2Converter) @@ -275,7 +292,7 @@ trait CoGrouped[K, +R] extends KeyedListLike[K, R, CoGrouped] with CoGroupable[K } } -abstract class CoGroupedJoiner[K](inputSize: Int, joinFunction: (K, Iterator[CTuple], Seq[Iterable[CTuple]]) => Iterator[Any]) extends CJoiner { +abstract class CoGroupedJoiner[K](inputSize: Int, getter: TupleGetter[K], joinFunction: (K, Iterator[CTuple], Seq[Iterable[CTuple]]) => Iterator[Any]) extends CJoiner { val distinctSize: Int def distinctIndexOf(originalPos: Int): Int @@ -288,11 +305,10 @@ abstract class CoGroupedJoiner[K](inputSize: Int, joinFunction: (K, Iterator[CTu override def getIterator(jc: JoinerClosure) = { val iters = (0 until distinctSize).map { jc.getIterator(_).asScala.buffered } - val key = iters + val keyTuple = iters .collectFirst { case iter if iter.nonEmpty => iter.head } .get // One of these must have a key - .getObject(0) - .asInstanceOf[K] + val key = getter.get(keyTuple, 0) val leftMost = iters.head @@ -315,8 +331,9 @@ abstract class CoGroupedJoiner[K](inputSize: Int, joinFunction: (K, Iterator[CTu // If all the input pipes are unique, this works: class DistinctCoGroupJoiner[K](count: Int, + getter: TupleGetter[K], joinFunction: (K, Iterator[CTuple], Seq[Iterable[CTuple]]) => Iterator[Any]) - extends CoGroupedJoiner[K](count, joinFunction) { + extends CoGroupedJoiner[K](count, getter, joinFunction) { val distinctSize = count def distinctIndexOf(idx: Int) = idx } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala index c25f0f7ca9..f60184b0ac 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala @@ -23,12 +23,22 @@ import com.twitter.scalding.TupleConverter.tuple2Converter import com.twitter.scalding.TupleSetter.tup2Setter import com.twitter.scalding._ +import com.twitter.scalding.serialization.{ + Boxed, + BoxedOrderedSerialization, + CascadingBinaryComparator, + 
OrderedSerialization, + WrappedSerialization +} import cascading.flow.FlowDef import cascading.pipe.Pipe -import cascading.tuple.Fields - +import cascading.property.ConfigDef +import cascading.tuple.{ Fields, Tuple => CTuple } +import java.util.Comparator import scala.collection.JavaConverters._ +import scala.util.Try +import scala.collection.immutable.Queue import Dsl._ @@ -77,15 +87,59 @@ object Grouped { def apply[K, V](pipe: TypedPipe[(K, V)])(implicit ordering: Ordering[K]): Grouped[K, V] = IdentityReduce(ordering, pipe, None) - def keySorting[T](ord: Ordering[T]): Fields = sorting("key", ord) - def valueSorting[T](implicit ord: Ordering[T]): Fields = sorting("value", ord) + def valueSorting[V](ord: Ordering[V]): Fields = Field.singleOrdered[V]("value")(ord) - def sorting[T](key: String, ord: Ordering[T]): Fields = { - val f = new Fields(key) - f.setComparator(key, ord) - f + /** + * If we are using OrderedComparable, we need to box the key + * to prevent other serializers from handling the key + */ + private[scalding] def maybeBox[K, V](ord: Ordering[K], flowDef: FlowDef)(op: (TupleSetter[(K, V)], Fields) => Pipe): Pipe = ord match { + case ordser: OrderedSerialization[K] => + val (boxfn, cls) = Boxed.next[K] + val boxordSer = BoxedOrderedSerialization(boxfn, ordser) + + WrappedSerialization.rawSetBinary(List((cls, boxordSer)), + { + case (k: String, v: String) => + FlowStateMap.mutate(flowDef) { st => + val newSt = st.addConfigSetting(k + cls, v) + (newSt, ()) + } + }) + + val ts = tup2Setter[(Boxed[K], V)].contraMap { kv1: (K, V) => (boxfn(kv1._1), kv1._2) } + val keyF = new Fields("key") + keyF.setComparator("key", new CascadingBinaryComparator(boxordSer)) + op(ts, keyF) + case _ => + val ts = tup2Setter[(K, V)] + val keyF = Field.singleOrdered("key")(ord) + op(ts, keyF) } + def tuple2Conv[K, V](ord: Ordering[K]): TupleConverter[(K, V)] = + ord match { + case _: OrderedSerialization[_] => + tuple2Converter[Boxed[K], V].andThen { kv => + (kv._1.get, kv._2) + } + case _ => tuple2Converter[K, V] + } + def keyConverter[K](ord: Ordering[K]): TupleConverter[K] = + ord match { + case _: OrderedSerialization[_] => + TupleConverter.singleConverter[Boxed[K]].andThen(_.get) + case _ => TupleConverter.singleConverter[K] + } + def keyGetter[K](ord: Ordering[K]): TupleGetter[K] = + ord match { + case _: OrderedSerialization[K] => + new TupleGetter[K] { + def get(tup: CTuple, i: Int) = tup.getObject(i).asInstanceOf[Boxed[K]].get + } + case _ => TupleGetter.castingGetter + } + def addEmptyGuard[K, V1, V2](fn: (K, Iterator[V1]) => Iterator[V2]): (K, Iterator[V1]) => Iterator[V2] = { (key: K, iter: Iterator[V1]) => if (iter.nonEmpty) fn(key, iter) else Iterator.empty } @@ -126,14 +180,15 @@ sealed trait ReduceStep[K, V1] extends KeyedPipe[K] { */ def mapped: TypedPipe[(K, V1)] // make the pipe and group it, only here because it is common - protected def groupOp[V2](gb: GroupBuilder => GroupBuilder): TypedPipe[(K, V2)] = { + protected def groupOp[V2](gb: GroupBuilder => GroupBuilder): TypedPipe[(K, V2)] = TypedPipeFactory({ (fd, mode) => - val reducedPipe = mapped - .toPipe(Grouped.kvFields)(fd, mode, tup2Setter) - .groupBy(Grouped.keySorting(keyOrdering))(gb) - TypedPipe.from(reducedPipe, Grouped.kvFields)(fd, mode, tuple2Converter[K, V2]) + val pipe = Grouped.maybeBox[K, V1](keyOrdering, fd) { (tupleSetter, fields) => + mapped + .toPipe(Grouped.kvFields)(fd, mode, tupleSetter) + .groupBy(fields)(gb) + } + TypedPipe.from(pipe, Grouped.kvFields)(fd, mode, Grouped.tuple2Conv[K, V2](keyOrdering)) }) - 
} } case class IdentityReduce[K, V1]( @@ -370,7 +425,7 @@ case class ValueSortedReduce[K, V1, V2]( groupOp { _.sortBy(vSort) .every(new cascading.pipe.Every(_, Grouped.valueField, - new TypedBufferOp(reduceFn, Grouped.valueField), Fields.REPLACE)) + new TypedBufferOp(Grouped.keyConverter(keyOrdering), reduceFn, Grouped.valueField), Fields.REPLACE)) .reducers(reducers.getOrElse(-1)) } } @@ -409,7 +464,7 @@ case class IteratorMappedReduce[K, V1, V2]( override lazy val toTypedPipe = groupOp { _.every(new cascading.pipe.Every(_, Grouped.valueField, - new TypedBufferOp(reduceFn, Grouped.valueField), Fields.REPLACE)) + new TypedBufferOp(Grouped.keyConverter(keyOrdering), reduceFn, Grouped.valueField), Fields.REPLACE)) .reducers(reducers.getOrElse(-1)) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/HashEqualsArrayWrapper.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/HashEqualsArrayWrapper.scala new file mode 100644 index 0000000000..c5eabbb05d --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/HashEqualsArrayWrapper.scala @@ -0,0 +1,319 @@ +package com.twitter.scalding.typed + +import java.util + +import reflect.ClassTag + +sealed trait HashEqualsArrayWrapper[T] { + def wrapped: Array[T] +} + +object HashEqualsArrayWrapper { + + // gross way to make specialized wrappers for primitives + // relies on the fact that Array generic types are not erased + def wrap[T](arr: Array[T]): HashEqualsArrayWrapper[T] = + wrapFn[T, Array[T]](arr.getClass.asInstanceOf[Class[Array[T]]])(arr) + + def wrapFn[T, A <: Array[T]](clazz: Class[A]): A => HashEqualsArrayWrapper[T] = { + val fn = clazz match { + case c if classOf[Array[Long]].equals(c) => a: Array[Long] => new HashEqualsLongArrayWrapper(a) + case c if classOf[Array[Int]].equals(c) => a: Array[Int] => new HashEqualsIntArrayWrapper(a) + case c if classOf[Array[Short]].equals(c) => a: Array[Short] => new HashEqualsShortArrayWrapper(a) + case c if classOf[Array[Char]].equals(c) => a: Array[Char] => new HashEqualsCharArrayWrapper(a) + case c if classOf[Array[Byte]].equals(c) => a: Array[Byte] => new HashEqualsByteArrayWrapper(a) + case c if classOf[Array[Boolean]].equals(c) => a: Array[Boolean] => new HashEqualsBooleanArrayWrapper(a) + case c if classOf[Array[Float]].equals(c) => a: Array[Float] => new HashEqualsFloatArrayWrapper(a) + case c if classOf[Array[Double]].equals(c) => a: Array[Double] => new HashEqualsDoubleArrayWrapper(a) + case c => a: Array[T] => new HashEqualsObjectArrayWrapper(a) + } + + fn.asInstanceOf[(Array[T] => HashEqualsArrayWrapper[T])] + } + + def wrapFn[T: ClassTag]: Array[T] => HashEqualsArrayWrapper[T] = + wrapFn(scala.reflect.classTag[T].runtimeClass.asInstanceOf[Class[Array[T]]]) + + implicit val longArrayOrd: Ordering[Array[Long]] = new Ordering[Array[Long]] { + override def compare(x: Array[Long], y: Array[Long]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = java.lang.Long.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Long.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val intArrayOrd: Ordering[Array[Int]] = new Ordering[Array[Int]] { + override def compare(x: Array[Int], y: Array[Int]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = 
java.lang.Integer.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Integer.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val shortArrayOrd: Ordering[Array[Short]] = new Ordering[Array[Short]] { + override def compare(x: Array[Short], y: Array[Short]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = java.lang.Short.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Short.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val charArrayOrd: Ordering[Array[Char]] = new Ordering[Array[Char]] { + override def compare(x: Array[Char], y: Array[Char]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = java.lang.Character.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Character.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val byteArrayOrd: Ordering[Array[Byte]] = new Ordering[Array[Byte]] { + override def compare(x: Array[Byte], y: Array[Byte]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = java.lang.Byte.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Byte.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val booleanArrayOrd: Ordering[Array[Boolean]] = new Ordering[Array[Boolean]] { + override def compare(x: Array[Boolean], y: Array[Boolean]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = java.lang.Boolean.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Boolean.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val floatArrayOrd: Ordering[Array[Float]] = new Ordering[Array[Float]] { + override def compare(x: Array[Float], y: Array[Float]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = java.lang.Float.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Float.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val doubleArrayOrd: Ordering[Array[Double]] = new Ordering[Array[Double]] { + override def compare(x: Array[Double], y: Array[Double]): Int = { + val lenCmp = java.lang.Integer.compare(x.length, y.length) + + if (lenCmp != 0) { + lenCmp + } else if (x.length == 0) { + 0 + } else { + val len = x.length + var i = 1 + var cmp = java.lang.Double.compare(x(0), y(0)) + while (i < len && cmp == 0) { + cmp = java.lang.Double.compare(x(i), y(i)) + i = i + 1 + } + cmp + } + } + } + + implicit val hashEqualsLongOrdering: Ordering[HashEqualsArrayWrapper[Long]] = new Ordering[HashEqualsArrayWrapper[Long]] { + override def compare(x: HashEqualsArrayWrapper[Long], y: HashEqualsArrayWrapper[Long]): Int = + longArrayOrd.compare(x.wrapped, y.wrapped) + } + + implicit val hashEqualsIntOrdering: Ordering[HashEqualsArrayWrapper[Int]] = new Ordering[HashEqualsArrayWrapper[Int]] { + override def compare(x: HashEqualsArrayWrapper[Int], y: HashEqualsArrayWrapper[Int]): Int = + 
intArrayOrd.compare(x.wrapped, y.wrapped) + } + + implicit val hashEqualsShortOrdering: Ordering[HashEqualsArrayWrapper[Short]] = new Ordering[HashEqualsArrayWrapper[Short]] { + override def compare(x: HashEqualsArrayWrapper[Short], y: HashEqualsArrayWrapper[Short]): Int = + shortArrayOrd.compare(x.wrapped, y.wrapped) + } + + implicit val hashEqualsCharOrdering: Ordering[HashEqualsArrayWrapper[Char]] = new Ordering[HashEqualsArrayWrapper[Char]] { + override def compare(x: HashEqualsArrayWrapper[Char], y: HashEqualsArrayWrapper[Char]): Int = + charArrayOrd.compare(x.wrapped, y.wrapped) + } + + implicit val hashEqualsByteOrdering: Ordering[HashEqualsArrayWrapper[Byte]] = new Ordering[HashEqualsArrayWrapper[Byte]] { + override def compare(x: HashEqualsArrayWrapper[Byte], y: HashEqualsArrayWrapper[Byte]): Int = + byteArrayOrd.compare(x.wrapped, y.wrapped) + } + + implicit val hashEqualsBooleanOrdering: Ordering[HashEqualsArrayWrapper[Boolean]] = new Ordering[HashEqualsArrayWrapper[Boolean]] { + override def compare(x: HashEqualsArrayWrapper[Boolean], y: HashEqualsArrayWrapper[Boolean]): Int = + booleanArrayOrd.compare(x.wrapped, y.wrapped) + } + + implicit val hashEqualsFloatOrdering: Ordering[HashEqualsArrayWrapper[Float]] = new Ordering[HashEqualsArrayWrapper[Float]] { + override def compare(x: HashEqualsArrayWrapper[Float], y: HashEqualsArrayWrapper[Float]): Int = + floatArrayOrd.compare(x.wrapped, y.wrapped) + } + + implicit val hashEqualsDoubleOrdering: Ordering[HashEqualsArrayWrapper[Double]] = new Ordering[HashEqualsArrayWrapper[Double]] { + override def compare(x: HashEqualsArrayWrapper[Double], y: HashEqualsArrayWrapper[Double]): Int = + doubleArrayOrd.compare(x.wrapped, y.wrapped) + } + +} + +final class HashEqualsLongArrayWrapper(override val wrapped: Array[Long]) extends HashEqualsArrayWrapper[Long] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsLongArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsIntArrayWrapper(override val wrapped: Array[Int]) extends HashEqualsArrayWrapper[Int] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsIntArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsShortArrayWrapper(override val wrapped: Array[Short]) extends HashEqualsArrayWrapper[Short] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsShortArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsCharArrayWrapper(override val wrapped: Array[Char]) extends HashEqualsArrayWrapper[Char] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsCharArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsByteArrayWrapper(override val wrapped: Array[Byte]) extends HashEqualsArrayWrapper[Byte] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsByteArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsBooleanArrayWrapper(override val wrapped: 
Array[Boolean]) extends HashEqualsArrayWrapper[Boolean] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsBooleanArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsFloatArrayWrapper(override val wrapped: Array[Float]) extends HashEqualsArrayWrapper[Float] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsFloatArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsDoubleArrayWrapper(override val wrapped: Array[Double]) extends HashEqualsArrayWrapper[Double] { + override def hashCode(): Int = util.Arrays.hashCode(wrapped) + + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsDoubleArrayWrapper => util.Arrays.equals(wrapped, other.wrapped) + case _ => false + } +} + +final class HashEqualsObjectArrayWrapper[T](override val wrapped: Array[T]) extends HashEqualsArrayWrapper[T] { + private val wrappedInternal = wrapped.toSeq + override def hashCode(): Int = wrappedInternal.hashCode() + override def equals(obj: scala.Any): Boolean = obj match { + case other: HashEqualsObjectArrayWrapper[T] => wrappedInternal.equals(other.wrappedInternal) + case _ => false + } +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/HashJoinable.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/HashJoinable.scala index 940b1a96fb..37764cb3ba 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/HashJoinable.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/HashJoinable.scala @@ -49,10 +49,10 @@ trait HashJoinable[K, +V] extends CoGroupable[K, V] with KeyedPipe[K] { TypedPipeFactory({ (fd, mode) => val newPipe = new HashJoin( RichPipe.assignName(mapside.toPipe(('key, 'value))(fd, mode, tup2Setter)), - RichFields(StringField("key")(keyOrdering, None)), + Field.singleOrdered("key")(keyOrdering), mapped.toPipe(('key1, 'value1))(fd, mode, tup2Setter), - RichFields(StringField("key1")(keyOrdering, None)), - new HashJoiner(joinFunction, joiner)) + Field.singleOrdered("key1")(keyOrdering), + WrappedJoiner(new HashJoiner(joinFunction, joiner))) //Construct the new TypedPipe TypedPipe.from[(K, R)](newPipe.project('key, 'value), ('key, 'value))(fd, mode, tuple2Converter) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala index aab6db8429..8f812ef471 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala @@ -303,6 +303,19 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] /** For each key, give the number of values */ def size: This[K, Long] = mapValues { x => 1L }.sum + + /** + * For each key, give the number of unique values. WARNING: May OOM. + * This assumes the values for each key can fit in memory. + */ + def distinctSize: This[K, Long] = toSet[T].mapValues(_.size) + + /** + * For each key, remove duplicate values. WARNING: May OOM. + * This assumes the values for each key can fit in memory. + */ + def distinctValues: This[K, T] = toSet[T].flattenValues + /** * AVOID THIS IF POSSIBLE * For each key, accumulate all the values into a List. 
WARNING: May OOM diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/LookupJoin.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/LookupJoin.scala index 1ada8894e8..4df34601a0 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/LookupJoin.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/LookupJoin.scala @@ -18,6 +18,24 @@ package com.twitter.scalding.typed import java.io.Serializable +import com.twitter.algebird.Semigroup + +/* + Copyright 2013 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + /** * lookupJoin simulates the behavior of a realtime system attempting * to leftJoin (K, V) pairs against some other value type (JoinedV) @@ -38,7 +56,7 @@ import java.io.Serializable * The entries in the left pipe's tuples have the following * meaning: * - * T: The time at which the (K, W) lookup occurred. + * T: The time at which the (K, W) lookup occurred. * K: the join key. * W: the current value for the join key. * @@ -53,16 +71,55 @@ import java.io.Serializable * right side will return None only if the key is absent, * else, the service will return Some(joinedV). */ + object LookupJoin extends Serializable { + + /** + * This is the "infinite history" join and always joins regardless of how + * much time is between the left and the right + */ + def apply[T: Ordering, K: Ordering, V, JoinedV]( left: TypedPipe[(T, (K, V))], right: TypedPipe[(T, (K, JoinedV))], - reducers: Option[Int] = None): TypedPipe[(T, (K, (V, Option[JoinedV])))] = { + reducers: Option[Int] = None): TypedPipe[(T, (K, (V, Option[JoinedV])))] = + + withWindow(left, right, reducers)((_, _) => true) + + /** + * In this case, the right pipe is fed through a scanLeft doing a Semigroup.plus + * before joined to the left + */ + def rightSumming[T: Ordering, K: Ordering, V, JoinedV: Semigroup](left: TypedPipe[(T, (K, V))], + right: TypedPipe[(T, (K, JoinedV))], + reducers: Option[Int] = None): TypedPipe[(T, (K, (V, Option[JoinedV])))] = + withWindowRightSumming(left, right, reducers)((_, _) => true) + + /** + * This ensures that gate(Tleft, Tright) == true, else the None is emitted + * as the joined value. 
+ * Useful for bounding the time of the join to a recent window + */ + def withWindow[T: Ordering, K: Ordering, V, JoinedV](left: TypedPipe[(T, (K, V))], + right: TypedPipe[(T, (K, JoinedV))], + reducers: Option[Int] = None)(gate: (T, T) => Boolean): TypedPipe[(T, (K, (V, Option[JoinedV])))] = { + + implicit val keepNew: Semigroup[JoinedV] = Semigroup.from { (older, newer) => newer } + withWindowRightSumming(left, right, reducers)(gate) + } + /** + * This ensures that gate(Tleft, Tright) == true, else the None is emitted + * as the joined value, and sums are only done as long as they come + * within the gate interval as well + */ + def withWindowRightSumming[T: Ordering, K: Ordering, V, JoinedV: Semigroup](left: TypedPipe[(T, (K, V))], + right: TypedPipe[(T, (K, JoinedV))], + reducers: Option[Int] = None)(gate: (T, T) => Boolean): TypedPipe[(T, (K, (V, Option[JoinedV])))] = { /** * Implicit ordering on an either that doesn't care about the - * actual container values, puts the lookups before the service - * writes Since we assume it takes non-zero time to do a lookup. + * actual container values, puts the lookups before the service writes + * Since we assume it takes non-zero time to do a lookup. */ implicit def eitherOrd[T, U]: Ordering[Either[T, U]] = new Ordering[Either[T, U]] { @@ -75,12 +132,15 @@ object LookupJoin extends Serializable { } } - val joined: TypedPipe[(K, (Option[JoinedV], Option[(T, V, Option[JoinedV])]))] = left.map { case (t, (k, v)) => (k, (t, Left(v): Either[V, JoinedV])) } - .++(right.map { case (t, (k, joinedV)) => (k, (t, Right(joinedV): Either[V, JoinedV])) }) + val joined: TypedPipe[(K, (Option[(T, JoinedV)], Option[(T, V, Option[JoinedV])]))] = left.map { case (t, (k, v)) => (k, (t, Left(v): Either[V, JoinedV])) } + .++(right.map { + case (t, (k, joinedV)) => + (k, (t, Right(joinedV): Either[V, JoinedV])) + }) .group .withReducers(reducers.getOrElse(-1)) // -1 means default in scalding - .sortBy(identity) // time then left before right + .sorted /** * Grouping by K leaves values of (T, Either[V, JoinedV]). Sort * by time and scanLeft. The iterator will now represent pairs of @@ -99,30 +159,45 @@ object LookupJoin extends Serializable { * JoinedV is updated and Some(newValue) when a (K, V) * shows up and a new join occurs. */ - (None: Option[JoinedV], None: Option[(T, V, Option[JoinedV])])) { - case ((lastJoined, _), (thisTime, leftOrRight)) => - leftOrRight match { - // Left(v) means that we have a new value from the left - // pipe that we need to join against the current - // "lastJoined" value sitting in scanLeft's state. This - // is equivalent to a lookup on the data in the right - // pipe at time "thisTime". - case Left(v) => (lastJoined, Some((thisTime, v, lastJoined))) - - // Right(joinedV) means that we've received a new value - // to use in the simulated realtime service described in - // the comments above - case Right(joined) => (Some(joined), None) - } - }.toTypedPipe + (Option.empty[(T, JoinedV)], Option.empty[(T, V, Option[JoinedV])])) { + case ((None, result), (time, Left(v))) => { + // There was no value previously + (None, Some((time, v, None))) + } + + case ((prev @ Some((oldt, jv)), result), (time, Left(v))) => { + // Left(v) means that we have a new value from the left + // pipe that we need to join against the current + // "lastJoined" value sitting in scanLeft's state. This + // is equivalent to a lookup on the data in the right + // pipe at time "thisTime".
+ val filteredJoined = if (gate(time, oldt)) Some(jv) else None + (prev, Some((time, v, filteredJoined))) + } - for { - // Now, get rid of residual state from the scanLeft above: - (k, (_, optV)) <- joined + case ((None, result), (time, Right(joined))) => { + // There was no value before, so we just update to joined + (Some((time, joined)), None) + } + + case ((Some((oldt, oldJ)), result), (time, Right(joined))) => { + // Right(joinedV) means that we've received a new value + // to use in the simulated realtime service + // described in the comments above + // did it fall out of cache? + val nextJoined = if (gate(time, oldt)) Semigroup.plus(oldJ, joined) else joined + (Some((time, nextJoined)), None) + } + }.toTypedPipe - // filter out every event that produced a Right(delta) above, - // leaving only the leftJoin events that occurred above: - (t, v, optJoined) <- optV - } yield (t, (k, (v, optJoined))) + // Now, get rid of residual state from the scanLeft above: + joined.flatMap { + case (k, (_, optV)) => + // filter out every event that produced a Right(delta) above, + // leaving only the leftJoin events that occurred above: + optV.map { + case (t, v, optJoined) => (t, (k, (v, optJoined))) + } + } } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/PartitionedDelimitedSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/PartitionedDelimitedSource.scala index 554e8a0bb5..0008779ff1 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/PartitionedDelimitedSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/PartitionedDelimitedSource.scala @@ -118,5 +118,5 @@ object PartitionedPsv extends PartitionedDelimited { /** Partitioned typed `\1` separated source (commonly used by Pig).*/ object PartitionedOsv extends PartitionedDelimited { - val separator = "\1" + val separator = "\u0001" } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/Sketched.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/Sketched.scala index 47783a71b6..407a586cbd 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/Sketched.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/Sketched.scala @@ -53,12 +53,12 @@ case class Sketched[K, V](pipe: TypedPipe[(K, V)], eps: Double, seed: Int)(implicit serialization: K => Array[Byte], ordering: Ordering[K]) - extends HasReducers { + extends MustHaveReducers { import Sketched._ def serialize(k: K): Array[Byte] = serialization(k) - val reducers = Some(numReducers) + def reducers = Some(numReducers) private lazy implicit val cms = CMS.monoid[Array[Byte]](eps, delta, seed) lazy val sketch: TypedPipe[CMS[Array[Byte]]] = @@ -91,9 +91,9 @@ case class Sketched[K, V](pipe: TypedPipe[(K, V)], case class SketchJoined[K: Ordering, V, V2, R](left: Sketched[K, V], right: TypedPipe[(K, V2)], numReducers: Int)(joiner: (K, V, Iterable[V2]) => Iterator[R]) - extends HasReducers { + extends MustHaveReducers { - val reducers = Some(numReducers) + def reducers = Some(numReducers) //the most of any one reducer we want to try to take up with a single key private val maxReducerFraction = 0.1 diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala index 0b2f918dc0..bbaebf555f 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala @@ -142,7 +142,16 @@ trait 
TypedPipe[+T] extends Serializable { * Fields API or with Cascading code. * Avoid this if possible. Prefer to write to TypedSink. */ - def toPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe + final def toPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = { + import Dsl._ + // Ensure we hook into all pipes coming out of the typed API to apply the FlowState's properties on their pipes + asPipe[U](fieldNames).applyFlowConfigProperties(flowDef) + } + + /** + * Provide the internal implementation to get from a typed pipe to a cascading Pipe + */ + protected def asPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe ///////////////////////////////////////////// // @@ -457,29 +466,26 @@ trait TypedPipe[+T] extends Serializable { * This writes the current TypedPipe into a temporary file * and then opens it after complete so that you can continue from that point */ - def forceToDiskExecution: Execution[TypedPipe[T]] = Execution.fromFn { (conf, mode) => - val flowDef = new FlowDef - mode match { - case _: CascadingLocal => // Local or Test mode - val dest = new MemorySink[T] - write(dest)(flowDef, mode) - - // We can't read until the job finishes - (flowDef, { (js: JobStats) => Future.successful(TypedPipe.from(dest.readResults)) }) - case _: HadoopMode => - // come up with unique temporary filename, use the config here - // TODO: refactor into TemporarySequenceFile class - val tmpDir = conf.get("hadoop.tmp.dir") - .orElse(conf.get("cascading.tmp.dir")) - .getOrElse("/tmp") - - val tmpSeq = tmpDir + "/scalding/snapshot-" + java.util.UUID.randomUUID + ".seq" - val dest = source.TypedSequenceFile[T](tmpSeq) - write(dest)(flowDef, mode) - - (flowDef, { (js: JobStats) => Future.successful(TypedPipe.from(dest)) }) + def forceToDiskExecution: Execution[TypedPipe[T]] = Execution + .getConfigMode + .flatMap { + case (conf, mode) => + mode match { + case _: CascadingLocal => // Local or Test mode + val dest = new MemorySink[T] + writeExecution(dest).map { _ => TypedPipe.from(dest.readResults) } + case _: HadoopMode => + // come up with unique temporary filename, use the config here + // TODO: refactor into TemporarySequenceFile class + val tmpDir = conf.get("hadoop.tmp.dir") + .orElse(conf.get("cascading.tmp.dir")) + .getOrElse("/tmp") + + val tmpSeq = tmpDir + "/scalding/snapshot-" + java.util.UUID.randomUUID + ".seq" + val dest = source.TypedSequenceFile[T](tmpSeq) + writeThrough(dest) + } } - } /** * This gives an Execution that when run evaluates the TypedPipe, @@ -524,12 +530,7 @@ trait TypedPipe[+T] extends Serializable { * into an Execution that is run for anything to happen here. */ def writeExecution(dest: TypedSink[T]): Execution[Unit] = - Execution.fromFn { (conf: Config, m: Mode) => - val fd = new FlowDef - write(dest)(fd, m) - - (fd, { (js: JobStats) => Future.successful(()) }) - } + Execution.write(this, dest) /** * If you want to write to a specific location, and then read from @@ -539,6 +540,20 @@ trait TypedPipe[+T] extends Serializable { writeExecution(dest) .map(_ => TypedPipe.from(dest)) + /** + * If you want to writeThrough to a specific file if it doesn't already exist, + * and otherwise just read from it going forward, use this. 
+ */ + def make[U >: T](dest: FileSource with TypedSink[T] with TypedSource[U]): Execution[TypedPipe[U]] = + Execution.getMode.flatMap { mode => + try { + dest.validateTaps(mode) + Execution.from(TypedPipe.from(dest)) + } catch { + case ivs: InvalidSourceException => writeThrough(dest) + } + } + /** Just keep the keys, or ._1 (if this type is a Tuple2) */ def keys[K](implicit ev: <:<[T, (K, Any)]): TypedPipe[K] = // avoid capturing ev in the closure: @@ -701,7 +716,7 @@ final case object EmptyTypedPipe extends TypedPipe[Nothing] { override def ++[U >: Nothing](other: TypedPipe[U]): TypedPipe[U] = other - override def toPipe[U >: Nothing](fieldNames: Fields)(implicit fd: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = + override def asPipe[U >: Nothing](fieldNames: Fields)(implicit fd: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = IterableSource(Iterable.empty, fieldNames)(setter, singleConverter[U]).read(fd, mode) override def toIterableExecution: Execution[Iterable[Nothing]] = Execution.from(Iterable.empty) @@ -787,7 +802,7 @@ final case class IterablePipe[T](iterable: Iterable[T]) extends TypedPipe[T] { }) } - override def toPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = + override def asPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = // It is slightly more efficient to use this rather than toSourcePipe.toPipe(fieldNames) IterableSource[U](iterable, fieldNames)(setter, singleConverter[U]).read(flowDef, mode) @@ -845,16 +860,17 @@ class TypedPipeFactory[T] private (@transient val next: NoStackAndThen[(FlowDef, override def sumByLocalKeys[K, V](implicit ev: T <:< (K, V), sg: Semigroup[V]) = andThen(_.sumByLocalKeys[K, V]) - override def toPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]) = + override def asPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]) = // unwrap in a loop, without recursing unwrap(this).toPipe[U](fieldNames)(flowDef, mode, setter) - override def toIterableExecution: Execution[Iterable[T]] = Execution.factory { (conf, mode) => - // This can only terminate in TypedPipeInst, which will - // keep the reference to this flowDef - val flowDef = new FlowDef - val nextPipe = unwrap(this)(flowDef, mode) - nextPipe.toIterableExecution + override def toIterableExecution: Execution[Iterable[T]] = Execution.getConfigMode.flatMap { + case (conf, mode) => + // This can only terminate in TypedPipeInst, which will + // keep the reference to this flowDef + val flowDef = new FlowDef + val nextPipe = unwrap(this)(flowDef, mode) + nextPipe.toIterableExecution } @annotation.tailrec @@ -919,17 +935,14 @@ class TypedPipeInst[T] private[scalding] (@transient inpipe: Pipe, * This approach is more efficient than untyped scalding because we * don't use TupleConverters/Setters after each map. 
*/ - override def toPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, m: Mode, setter: TupleSetter[U]): Pipe = { + override def asPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, m: Mode, setter: TupleSetter[U]): Pipe = { import Dsl.flowDefToRichFlowDef checkMode(m) flowDef.mergeFrom(localFlowDef) RichPipe(inpipe).flatMapTo[TupleEntry, U](fields -> fieldNames)(flatMapFn) } - override def toIterableExecution: Execution[Iterable[T]] = Execution.factory { (conf, m) => - // To convert from java iterator to scala below - import scala.collection.JavaConverters._ - checkMode(m) + override def toIterableExecution: Execution[Iterable[T]] = openIfHead match { // TODO: it might be good to apply flatMaps locally, // since we obviously need to iterate all, @@ -937,12 +950,18 @@ class TypedPipeInst[T] private[scalding] (@transient inpipe: Pipe, // for us. So unwind until you hit the first filter, snapshot, // then apply the unwound functions case Some((tap, fields, Converter(conv))) => - Execution.from(new Iterable[T] { - def iterator = m.openForRead(conf, tap).asScala.map(tup => conv(tup.selectEntry(fields))) - }) + // To convert from java iterator to scala below + import scala.collection.JavaConverters._ + Execution.getConfigMode.map { + case (conf, m) => + // Verify the mode has not changed due to invalid TypedPipe DAG construction + checkMode(m) + new Iterable[T] { + def iterator = m.openForRead(conf, tap).asScala.map(tup => conv(tup.selectEntry(fields))) + } + } case _ => forceToDiskExecution.flatMap(_.toIterableExecution) } - } } final case class MergedTypedPipe[T](left: TypedPipe[T], right: TypedPipe[T]) extends TypedPipe[T] { @@ -974,14 +993,6 @@ final case class MergedTypedPipe[T](left: TypedPipe[T], right: TypedPipe[T]) ext override def fork: TypedPipe[T] = MergedTypedPipe(left.fork, right.fork) - /* - * This relies on the fact that two executions that are zipped will run in the - * same cascading flow, so we don't have to worry about it here. - */ - override def forceToDiskExecution = - left.forceToDiskExecution.zip(right.forceToDiskExecution) - .map { case (l, r) => l ++ r } - @annotation.tailrec private def flattenMerge(toFlatten: List[TypedPipe[T]], acc: List[TypedPipe[T]])(implicit fd: FlowDef, m: Mode): List[TypedPipe[T]] = toFlatten match { @@ -991,7 +1002,7 @@ final case class MergedTypedPipe[T](left: TypedPipe[T], right: TypedPipe[T]) ext case Nil => acc } - override def toPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = { + override def asPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = { /* * Cascading can't handle duplicate pipes in merges. What we do here is see if any pipe appears * multiple times and if it does we can do self merges using flatMap. @@ -1018,20 +1029,13 @@ final case class MergedTypedPipe[T](left: TypedPipe[T], right: TypedPipe[T]) ext } } - /** - * This relies on the fact that two executions that are zipped will run in the - * same cascading flow, so we don't have to worry about it here. 
- */ - override def toIterableExecution: Execution[Iterable[T]] = - left.toIterableExecution.zip(right.toIterableExecution) - .map { case (l, r) => l ++ r } - + override def toIterableExecution: Execution[Iterable[T]] = forceToDiskExecution.flatMap(_.toIterableExecution) override def hashCogroup[K, V, W, R](smaller: HashJoinable[K, W])(joiner: (K, V, Iterable[W]) => Iterator[R])(implicit ev: TypedPipe[T] <:< TypedPipe[(K, V)]): TypedPipe[(K, R)] = MergedTypedPipe(left.hashCogroup(smaller)(joiner), right.hashCogroup(smaller)(joiner)) } class WithOnComplete[T](typedPipe: TypedPipe[T], fn: () => Unit) extends TypedPipe[T] { - override def toPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]) = { + override def asPipe[U >: T](fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]) = { val pipe = typedPipe.toPipe[U](fieldNames)(flowDef, mode, setter) new Each(pipe, Fields.ALL, new CleanupIdentityFunction(fn), Fields.REPLACE) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipeDiff.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipeDiff.scala new file mode 100644 index 0000000000..c49582b3e1 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipeDiff.scala @@ -0,0 +1,125 @@ +package com.twitter.scalding.typed + +import java.io.{ BufferedWriter, File, FileWriter } + +import com.twitter.scalding.Execution + +import scala.reflect.ClassTag + +/** + * Some methods for comparing two typed pipes and finding out the difference between them. + * + * Has support for the normal case where the typed pipes are pipes of objects usable as keys + * in scalding (have an ordering, proper equals and hashCode), as well as some special cases + * for dealing with Arrays and thrift objects. + * + * See diffByHashCode for comparing typed pipes of objects that have no ordering but a stable hash code + * (such as Scrooge thrift). + * + * See diffByGroup for comparing typed pipes of objects that have no ordering *and* an unstable hash code. + */ +object TypedPipeDiff { + + /** + * Returns a mapping from T to a count of the occurrences of T in the left and right pipes, + * only for cases where the counts are not equal. + * + * Requires that T have an ordering and a hashCode and equals that is stable across JVMs (not reference based). + * See diffArrayPipes for diffing pipes of arrays, since arrays do not meet these requirements by default. + */ + def diff[T: Ordering](left: TypedPipe[T], right: TypedPipe[T], reducers: Option[Int] = None): UnsortedGrouped[T, (Long, Long)] = { + val lefts = left.map { x => (x, (1L, 0L)) } + val rights = right.map { x => (x, (0L, 1L)) } + val counts = (lefts ++ rights).sumByKey + val diff = counts.filter { case (key, (lCount, rCount)) => lCount != rCount } + reducers.map(diff.withReducers).getOrElse(diff) + } + + /** + * Same as diffByHashCode, but takes care to wrap the Array[T] in a wrapper, + * which has the correct hashCode and equals needed. This does not involve + * copying the arrays, just wrapping them, and is specialized for primitive arrays. 
+ */ + def diffArrayPipes[T: ClassTag](left: TypedPipe[Array[T]], + right: TypedPipe[Array[T]], + reducers: Option[Int] = None): TypedPipe[(Array[T], (Long, Long))] = { + + // cache this instead of reflecting on every single array + val wrapFn = HashEqualsArrayWrapper.wrapFn[T] + + diffByHashCode(left.map(wrapFn), right.map(wrapFn), reducers) + .map { case (k, counts) => (k.wrapped, counts) } + } + + /** + * NOTE: Prefer diff over this method if you can find or construct an Ordering[T]. + * + * Returns a mapping from T to a count of the occurrences of T in the left and right pipes, + * only for cases where the counts are not equal. + * + * This implementation does not require an ordering on T, but does require a function (groupByFn) + * that extracts a value of type K (which has an ordering) from a record of type T. + * + * The groupByFn should be something that partitions records as evenly as possible, + * because all unique records that result in the same groupByFn value will be materialized into an in memory map. + * + * groupByFn must be a pure function, such that: + * x == y implies that groupByFn(x) == groupByFn(y) + * + * T must have a hash code suitable for use in a hash map on a single JVM (doesn't have to be stable cross JVM) + * K must have a hash code that *is* stable across JVMs. + * K must have an ordering. + * + * Example groupByFns would be x => x.hashCode, assuming x's hashCode is stable across jvms, + * or maybe x => x.timestamp, if x's hashCode is not stable, assuming there shouldn't be too + * many records with the same timestamp. + */ + def diffByGroup[T, K: Ordering]( + left: TypedPipe[T], + right: TypedPipe[T], + reducers: Option[Int] = None)(groupByFn: T => K): TypedPipe[(T, (Long, Long))] = { + + val lefts = left.map { t => (groupByFn(t), Map(t -> (1L, 0L))) } + val rights = right.map { t => (groupByFn(t), Map(t -> (0L, 1L))) } + + val diff = (lefts ++ rights) + .sumByKey + .flattenValues + .filter { case (k, (t, (lCount, rCount))) => lCount != rCount } + + reducers.map(diff.withReducers).getOrElse(diff).values + } + + /** + * NOTE: Prefer diff over this method if you can find or construct an Ordering[T]. + * + * Same as diffByGroup but uses T.hashCode as the groupByFn + * + * This method does an exact diff, it does not use the hashCode as a proxy for equality.
+ */ + def diffByHashCode[T]( + left: TypedPipe[T], + right: TypedPipe[T], + reducers: Option[Int] = None): TypedPipe[(T, (Long, Long))] = diffByGroup(left, right, reducers)(_.hashCode) + + object Enrichments { + + implicit class Diff[T](val left: TypedPipe[T]) extends AnyVal { + + def diff(right: TypedPipe[T], reducers: Option[Int] = None)(implicit ev: Ordering[T]): UnsortedGrouped[T, (Long, Long)] = + TypedPipeDiff.diff(left, right, reducers) + + def diffByGroup[K: Ordering](right: TypedPipe[T], reducers: Option[Int] = None)(groupByFn: T => K): TypedPipe[(T, (Long, Long))] = + TypedPipeDiff.diffByGroup(left, right, reducers)(groupByFn) + + def diffByHashCode(right: TypedPipe[T], reducers: Option[Int] = None): TypedPipe[(T, (Long, Long))] = TypedPipeDiff.diffByHashCode(left, right, reducers) + } + + implicit class DiffArray[T](val left: TypedPipe[Array[T]]) extends AnyVal { + + def diffArrayPipes(right: TypedPipe[Array[T]], reducers: Option[Int] = None)(implicit ev: ClassTag[T]): TypedPipe[(Array[T], (Long, Long))] = + TypedPipeDiff.diffArrayPipes(left, right, reducers) + } + + } +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/WithReducers.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/WithReducers.scala index 8b42ff30dc..ff4c350236 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/WithReducers.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/WithReducers.scala @@ -23,6 +23,14 @@ trait HasReducers { def reducers: Option[Int] } +/** + * used for types that must know how many reducers they need + * e.g. Sketched + */ +trait MustHaveReducers extends HasReducers { + def reducers: Some[Int] +} + /** * used for objects that may _set_ how many reducers they need * e.g. CoGrouped, Grouped, SortedGrouped, UnsortedGrouped diff --git a/scalding-core/src/test/scala/com/twitter/scalding/BlockJoinTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/BlockJoinTest.scala index c6353973cd..26a4f671d7 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/BlockJoinTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/BlockJoinTest.scala @@ -19,8 +19,6 @@ import org.scalatest.{ Matchers, WordSpec } import cascading.pipe.joiner._ -import java.lang.reflect.InvocationTargetException - import scala.collection.mutable.Buffer class InnerProductJob(args: Args) extends Job(args) { diff --git a/scalding-core/src/test/scala/com/twitter/scalding/CascadeTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/CascadeTest.scala index d0a280d7f1..819e30c62c 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/CascadeTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/CascadeTest.scala @@ -23,7 +23,6 @@ import scala.io.Source.fromFile import java.io.File import cascading.cascade.Cascade import cascading.flow.FlowSkipIfSinkNotStale -import cascading.tuple.Fields class Job1(args: Args) extends Job(args) { Tsv(args("input0"), ('line)).pipe.map[String, String]('line -> 'line)((x: String) => "job1:" + x).write(Tsv(args("output0"), fields = 'line)) diff --git a/scalding-core/src/test/scala/com/twitter/scalding/CoGroupTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/CoGroupTest.scala index c8b9cfc659..5d615be772 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/CoGroupTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/CoGroupTest.scala @@ -15,7 +15,6 @@ limitations under the License. 
*/ package com.twitter.scalding -import cascading.pipe.joiner._ import org.scalatest.{ WordSpec, Matchers } class StarJoinJob(args: Args) extends Job(args) { diff --git a/scalding-core/src/test/scala/com/twitter/scalding/ConfigTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/ConfigTest.scala index 9cc897317e..39a1552ae2 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/ConfigTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/ConfigTest.scala @@ -17,10 +17,8 @@ package com.twitter.scalding import org.scalatest.{ WordSpec, Matchers } import org.scalacheck.Arbitrary -import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Properties import org.scalacheck.Prop.forAll -import org.scalacheck.Gen._ import scala.util.Success @@ -47,6 +45,35 @@ class ConfigTest extends WordSpec with Matchers { stillOld should contain (date) new2 shouldBe newConf } + "adding UniqueIDs works" in { + assert(Config.empty.getUniqueIds.size === 0) + val (id, conf) = Config.empty.ensureUniqueId + assert(conf.getUniqueIds === (Set(id))) + } + "Default serialization should have tokens" in { + Config.default.getCascadingSerializationTokens should not be empty + Config.default.getCascadingSerializationTokens + .values + .map(Class.forName) + .filter(c => c.isPrimitive || c.isArray) shouldBe empty + + Config.empty.getCascadingSerializationTokens shouldBe empty + + // tokenClasses are a subset that don't include primites or arrays. + val tokenClasses = Config.default.getCascadingSerializationTokens.values.toSet + val kryoClasses = Config.default.getKryoRegisteredClasses.map(_.getName) + // Tokens are a subset of Kryo registered classes + (kryoClasses & tokenClasses) shouldBe tokenClasses + // the only Kryo classes we don't assign tokens for are the primitives + array + (kryoClasses -- tokenClasses).forall { c => + // primitives cannot be forName'd + val prim = Set(classOf[Boolean], classOf[Byte], classOf[Short], + classOf[Int], classOf[Long], classOf[Float], classOf[Double], classOf[Char]) + .map(_.getName) + + prim(c) || Class.forName(c).isArray + } shouldBe true + } } } @@ -68,4 +95,10 @@ object ConfigProps extends Properties("Config") { val testKeys = c1.toMap.keySet | c2.toMap.keySet ++ keys testKeys.forall { k => merged.get(k) == c2.get(k).orElse(c1.get(k)) } } + property("adding many UniqueIDs works") = forAll { (l: List[String]) => + val uids = l.filterNot { s => s.isEmpty || s.contains(",") }.map(UniqueID(_)) + (uids.foldLeft(Config.empty) { (conf, id) => + conf.addUniqueId(id) + }.getUniqueIds == uids.toSet) + } } diff --git a/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala index 1244e42771..666128a04e 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala @@ -1320,33 +1320,6 @@ class NormalizeTest extends WordSpec with Matchers { } } -class ApproxUniqJob(args: Args) extends Job(args) { - Tsv("in", ('x, 'y)) - .read - .groupBy('x) { _.approxUniques('y -> 'ycnt) } - .write(Tsv("out")) -} - -class ApproxUniqTest extends WordSpec with Matchers { - import Dsl._ - - "A ApproxUniqJob" should { - val input = (1 to 1000).flatMap { i => List(("x0", i), ("x1", i)) }.toList - JobTest(new ApproxUniqJob(_)) - .source(Tsv("in", ('x, 'y)), input) - .sink[(String, Double)](Tsv("out")) { outBuf => - "must approximately count" in { - outBuf should have size 2 - val kvresult = outBuf.groupBy { _._1 }.mapValues { 
_.head._2 } - kvresult("x0") shouldBe 1000.0 +- 30.0 //We should be 1%, but this is on average, so - kvresult("x1") shouldBe 1000.0 +- 30.0 //We should be 1%, but this is on average, so - } - } - .run - .finish - } -} - class ForceToDiskJob(args: Args) extends Job(args) { val x = Tsv("in", ('x, 'y)) .read diff --git a/scalding-core/src/test/scala/com/twitter/scalding/CumulitiveSumTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/CumulativeSumTest.scala similarity index 98% rename from scalding-core/src/test/scala/com/twitter/scalding/CumulitiveSumTest.scala rename to scalding-core/src/test/scala/com/twitter/scalding/CumulativeSumTest.scala index 7f8fb5faf7..aa40e0eb91 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/CumulitiveSumTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/CumulativeSumTest.scala @@ -2,8 +2,6 @@ package com.twitter.scalding import org.scalatest.WordSpec -import com.twitter.scalding._ - import com.twitter.scalding.typed.CumulativeSum._ class AddRankingWithCumulativeSum(args: Args) extends Job(args) { diff --git a/scalding-core/src/test/scala/com/twitter/scalding/DistinctByTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/DistinctByTest.scala index d5405b40ea..ed4be54953 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/DistinctByTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/DistinctByTest.scala @@ -17,11 +17,8 @@ package com.twitter.scalding import com.twitter.scalding.typed.CoGrouped.distinctBy -import org.scalacheck.Arbitrary -import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Properties import org.scalacheck.Prop.forAll -import org.scalacheck.Gen._ object DistinctByProps extends Properties("CoGrouped.DistinctBy") { diff --git a/scalding-core/src/test/scala/com/twitter/scalding/ExecutionAppProperties.scala b/scalding-core/src/test/scala/com/twitter/scalding/ExecutionAppProperties.scala index fdc8d6d028..2a7524b315 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/ExecutionAppProperties.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/ExecutionAppProperties.scala @@ -15,10 +15,8 @@ limitations under the License. */ package com.twitter.scalding -import org.scalacheck.Arbitrary import org.scalacheck.Properties import org.scalacheck.Prop.forAll -import org.scalacheck.Gen.choose import org.scalacheck.Prop._ // Be careful here in that Array[String] equality isn't contents based. its java referenced based. 
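The comment just above touches on the same pitfall that motivates the new HashEqualsArrayWrapper and diffArrayPipes earlier in this patch: JVM arrays compare by reference and use identity hashCode, so arrays with equal contents are still distinct grouping keys. A minimal illustrative sketch, not part of the patch, of the behavior being worked around (the object name is invented for illustration):

// Sketch only: shows why raw arrays are unsafe as keys, and how delegating to
// java.util.Arrays (as HashEqualsArrayWrapper does) restores contents-based semantics.
object ArrayEqualitySketch {
  def main(args: Array[String]): Unit = {
    val a = Array(1, 2, 3)
    val b = Array(1, 2, 3)
    assert(!(a == b)) // reference equality: same contents, still "not equal"
    assert(java.util.Arrays.equals(a, b)) // contents-based equality
    assert(java.util.Arrays.hashCode(a) == java.util.Arrays.hashCode(b)) // stable contents hash
  }
}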
diff --git a/scalding-core/src/test/scala/com/twitter/scalding/FileSourceTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/FileSourceTest.scala index f6872aea7e..da84e777c9 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/FileSourceTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/FileSourceTest.scala @@ -170,7 +170,7 @@ object TestFileSource extends FileSource { import TestPath.testfsPathRoot override def hdfsPaths: Iterable[String] = Iterable.empty - override def localPath: String = "" + override def localPaths: Iterable[String] = Iterable.empty val conf = new Configuration() @@ -180,7 +180,7 @@ object TestFileSource extends FileSource { object TestSuccessFileSource extends FileSource with SuccessFileSource { import TestPath.testfsPathRoot override def hdfsPaths: Iterable[String] = Iterable.empty - override def localPath: String = "" + override def localPaths: Iterable[String] = Iterable.empty val conf = new Configuration() diff --git a/scalding-core/src/test/scala/com/twitter/scalding/KryoTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/KryoTest.scala index aabae9197a..a0c97a8112 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/KryoTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/KryoTest.scala @@ -15,8 +15,6 @@ limitations under the License. */ package com.twitter.scalding -import com.twitter.scalding.serialization._ - import org.scalatest.{ Matchers, WordSpec } import java.io.{ ByteArrayOutputStream => BOS } @@ -39,6 +37,7 @@ import com.twitter.chill.hadoop.HadoopConfig import com.twitter.chill.hadoop.KryoSerialization import org.apache.hadoop.conf.Configuration + /* * This is just a test case for Kryo to deal with. It should * be outside KryoTest, otherwise the enclosing class, KryoTest diff --git a/scalding-core/src/test/scala/com/twitter/scalding/LookupJoinTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/LookupJoinTest.scala index bb616a9fc0..754c7f3633 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/LookupJoinTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/LookupJoinTest.scala @@ -18,7 +18,26 @@ package com.twitter.scalding import com.twitter.scalding.typed.LookupJoin import org.scalatest.{ Matchers, WordSpec } +import com.twitter.algebird.{ Monoid, Semigroup, Group } + +object LookupJoinedTest { + + // Not defined if there is a collision in K and T, so make those unique: + def genList(maxTime: Int, maxKey: Int, sz: Int): List[(Int, Int, Int)] = { + val rng = new java.util.Random + (0 until sz).view.map { _ => + (rng.nextInt(maxTime), rng.nextInt(maxKey), rng.nextInt) + } + .groupBy { case (t, k, v) => (t, k) } + .mapValues(_.headOption.toList) + .values + .flatten + .toList + } +} + class LookupJoinerJob(args: Args) extends Job(args) { + import TDsl._ val in0 = TypedTsv[(Int, Int, Int)]("input0") @@ -31,42 +50,168 @@ class LookupJoinerJob(args: Args) extends Job(args) { (t.toString, k.toString, v.toString, opt.toString) } .write(TypedTsv[(String, String, String, String)]("output")) + + LookupJoin.rightSumming(TypedPipe.from(in0).map { case (t, k, v) => (t, (k, v)) }, + TypedPipe.from(in1).map { case (t, k, v) => (t, (k, v)) }) + .map { + case (t, (k, (v, opt))) => + (t.toString, k.toString, v.toString, opt.toString) + } + .write(TypedTsv[(String, String, String, String)]("output2")) } class LookupJoinedTest extends WordSpec with Matchers { + import Dsl._ + import LookupJoinedTest.genList + def lookupJoin[T: Ordering, K, V, W](in0: Iterable[(T, 
K, V)], in1: Iterable[(T, K, W)]) = { - // super inefficient, but easy to verify: + val serv = in1.groupBy(_._2) + def lookup(t: T, k: K): Option[W] = { + val ord = Ordering.by { tkw: (T, K, W) => tkw._1 } + serv.get(k).flatMap { in1s => + in1s.filter { case (t1, _, _) => Ordering[T].lt(t1, t) } + .reduceOption(ord.max(_, _)) + .map { + _._3 + } + } + } + in0.map { case (t, k, v) => (t.toString, k.toString, v.toString, lookup(t, k).toString) } + } + + def lookupSumJoin[T: Ordering, K, V, W: Semigroup](in0: Iterable[(T, K, V)], in1: Iterable[(T, K, W)]) = { + implicit val ord: Ordering[(T, K, W)] = Ordering.by { + _._1 + } + val serv = in1.groupBy(_._2).mapValues { + _.toList + .sorted + .scanLeft(None: Option[(T, K, W)]) { (old, newer) => + old.map { case (_, _, w) => (newer._1, newer._2, Semigroup.plus(w, newer._3)) } + .orElse(Some(newer)) + } + .filter { + _.isDefined + } + .map { + _.get + } + }.toMap // Force the map + def lookup(t: T, k: K): Option[W] = { - implicit val ord = Ordering.by { tkw: (T, K, W) => tkw._1 } - in1.filter { case (t1, k1, _) => (k1 == k) && Ordering[T].lt(t1, t) } - .reduceOption(Ordering[(T, K, W)].max(_, _)) - .map { _._3 } + val ord = Ordering.by { tkw: (T, K, W) => tkw._1 } + serv.get(k).flatMap { in1s => + in1s.filter { case (t1, _, _) => Ordering[T].lt(t1, t) } + .reduceOption(ord.max(_, _)) + .map { + _._3 + } + } } in0.map { case (t, k, v) => (t.toString, k.toString, v.toString, lookup(t, k).toString) } } + "A LookupJoinerJob" should { "correctly lookup" in { - val rng = new java.util.Random - val MAX_KEY = 10 - def genList(sz: Int): List[(Int, Int, Int)] = { - (0 until sz).map { _ => - (rng.nextInt, rng.nextInt(MAX_KEY), rng.nextInt) - }.toList - } - val in0 = genList(1000) - val in1 = genList(1000) + val MAX_KEY = 100 + val VAL_COUNT = 10000 + val in0 = genList(Int.MaxValue, MAX_KEY, VAL_COUNT) + val in1 = genList(Int.MaxValue, MAX_KEY, VAL_COUNT) JobTest(new LookupJoinerJob(_)) .source(TypedTsv[(Int, Int, Int)]("input0"), in0) .source(TypedTsv[(Int, Int, Int)]("input1"), in1) .sink[(String, String, String, String)]( TypedTsv[(String, String, String, String)]("output")) { outBuf => - outBuf.toSet shouldBe (lookupJoin(in0, in1).toSet) - in0 should have size (outBuf.size) + outBuf.toSet should equal (lookupJoin(in0, in1).toSet) + in0.size should equal (outBuf.size) + } + .sink[(String, String, String, String)]( + TypedTsv[(String, String, String, String)]("output2")) { outBuf => + outBuf.toSet should equal(lookupSumJoin(in0, in1).toSet) + in0.size should equal(outBuf.size) + } + .run + //.runHadoop + .finish + } + } +} + +class WindowLookupJoinerJob(args: Args) extends Job(args) { + + import TDsl._ + + val in0 = TypedTsv[(Int, Int, Int)]("input0") + val in1 = TypedTsv[(Int, Int, Int)]("input1") + val window = args("window").toInt + + def gate(left: Int, right: Int) = + (left.toLong - right.toLong) < window + + LookupJoin.withWindow(TypedPipe.from(in0).map { case (t, k, v) => (t, (k, v)) }, + TypedPipe.from(in1).map { case (t, k, v) => (t, (k, v)) })(gate _) + .map { + case (t, (k, (v, opt))) => + (t.toString, k.toString, v.toString, opt.toString) + } + .write(TypedTsv[(String, String, String, String)]("output")) +} + +class WindowLookupJoinedTest extends WordSpec with Matchers { + + import Dsl._ + import LookupJoinedTest.genList + + def windowLookupJoin[K, V, W](in0: Iterable[(Int, K, V)], in1: Iterable[(Int, K, W)], win: Int) = { + val serv = in1.groupBy(_._2) + // super inefficient, but easy to verify: + def lookup(t: Int, k: K): Option[W] = { + val 
ord = Ordering.by { tkw: (Int, K, W) => tkw._1 } + serv.get(k).flatMap { in1s => + in1s.filter { + case (t1, _, _) => + (t1 < t) && ((t.toLong - t1.toLong) < win) + } + .reduceOption(ord.max(_, _)) + .map { + _._3 + } + } + } + in0.map { case (t, k, v) => (t.toString, k.toString, v.toString, lookup(t, k).toString) } + } + + "A WindowLookupJoinerJob" should { + //Set up the job: + "correctly lookup" in { + val MAX_KEY = 10 + val MAX_TIME = 10000 + val sz: Int = 10000; + val in0 = genList(MAX_TIME, MAX_KEY, 10000) + val in1 = genList(MAX_TIME, MAX_KEY, 10000) + JobTest(new WindowLookupJoinerJob(_)) + .arg("window", "100") + .source(TypedTsv[(Int, Int, Int)]("input0"), in0) + .source(TypedTsv[(Int, Int, Int)]("input1"), in1) + .sink[(String, String, String, String)]( + TypedTsv[(String, String, String, String)]("output")) { outBuf => + val results = outBuf.toList.sorted + val correct = windowLookupJoin(in0, in1, 100).toList.sorted + def some(it: List[(String, String, String, String)]) = + it.filter(_._4.startsWith("Some")) + + def none(it: List[(String, String, String, String)]) = + it.filter(_._4.startsWith("None")) + + some(results) shouldBe (some(correct)) + none(results) shouldBe (none(correct)) + in0.size should equal (outBuf.size) } .run - .runHadoop + //.runHadoop .finish } } } + diff --git a/scalding-core/src/test/scala/com/twitter/scalding/PartitionSourceTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/PartitionSourceTest.scala index 204f0f96fa..b9adb9a037 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/PartitionSourceTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/PartitionSourceTest.scala @@ -26,7 +26,6 @@ import cascading.tuple.Fields import cascading.tuple.TupleEntry import cascading.util.Util import cascading.tap.partition.Partition -import cascading.tap.partition.DelimitedPartition import com.twitter.scalding.{ PartitionedTsv => StandardPartitionedTsv, _ } diff --git a/scalding-core/src/test/scala/com/twitter/scalding/ReduceOperationsTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/ReduceOperationsTest.scala index 2a1c484b01..bdd29eebad 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/ReduceOperationsTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/ReduceOperationsTest.scala @@ -16,7 +16,6 @@ limitations under the License. 
package com.twitter.scalding import org.scalatest.{ Matchers, WordSpec } -import com.twitter.scalding._ class SortWithTakeJob(args: Args) extends Job(args) { try { @@ -79,6 +78,9 @@ class ApproximateUniqueCountJob(args: Args) extends Job(args) { .groupBy('category) { _.approximateUniqueCount[String]('os -> 'os_count) } + .map('os_count -> 'os_count) { + osCount: Double => osCount.toLong + } .write(Tsv("output0")) } catch { case e: Exception => e.printStackTrace() @@ -147,12 +149,12 @@ class ReduceOperationsTest extends WordSpec with Matchers { JobTest(new ApproximateUniqueCountJob(_)) .source(Tsv("input0", ('category, 'model, 'os)), inputData) - .sink[(String, Double)](Tsv("output0")) { buf => + .sink[(String, Long)](Tsv("output0")) { buf => "grouped OS count" in { - val whatWeWant: Map[String, Double] = Map( - "laptop" -> 1.0, - "mobile" -> 2.0) - val whatWeGet: Map[String, Double] = buf.toMap + val whatWeWant: Map[String, Long] = Map( + "laptop" -> 1, + "mobile" -> 2) + val whatWeGet: Map[String, Long] = buf.toMap whatWeGet should have size 2 whatWeGet.get("laptop").getOrElse("apples") shouldBe (whatWeWant.get("laptop").getOrElse("oranges")) whatWeGet.get("mobile").getOrElse("apples") shouldBe (whatWeWant.get("mobile").getOrElse("oranges")) diff --git a/scalding-core/src/test/scala/com/twitter/scalding/ScanLeftTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/ScanLeftTest.scala index 21970aa1bf..154a20536c 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/ScanLeftTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/ScanLeftTest.scala @@ -1,7 +1,6 @@ package com.twitter.scalding import org.scalatest.{ Matchers, WordSpec } -import com.twitter.scalding._ /** * Simple Example: First group data by gender and then sort by height reverse order. diff --git a/scalding-core/src/test/scala/com/twitter/scalding/SideEffectTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/SideEffectTest.scala index 83ce7bfa30..54f3351e8c 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/SideEffectTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/SideEffectTest.scala @@ -15,8 +15,6 @@ limitations under the License. 
*/ package com.twitter.scalding -import scala.annotation.tailrec -import cascading.pipe._ import org.scalatest.{ Matchers, WordSpec } /* diff --git a/scalding-core/src/test/scala/com/twitter/scalding/SkewJoinTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/SkewJoinTest.scala index eb6e0b6560..af3140b684 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/SkewJoinTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/SkewJoinTest.scala @@ -17,10 +17,6 @@ package com.twitter.scalding import org.scalatest.{ Matchers, WordSpec } -import cascading.pipe.joiner._ - -import java.lang.reflect.InvocationTargetException - import scala.collection.mutable.Buffer class SkewJoinJob(args: Args) extends Job(args) { @@ -68,9 +64,9 @@ object JoinTestHelper { .arg("replicationFactor", replicationFactor.toString) .arg("replicator", replicator.toString) .source(Tsv("input0"), generateInput(1000, 100)) - .source(Tsv("input1"), generateInput(1000, 100)) - .sink[(Int, Int, Int, Int, Int, Int)](Tsv("output")) { outBuf => skewResult ++ outBuf } - .sink[(Int, Int, Int, Int, Int, Int)](Tsv("jws-output")) { outBuf => innerResult ++ outBuf } + .source(Tsv("input1"), generateInput(100, 100)) + .sink[(Int, Int, Int, Int, Int, Int)](Tsv("output")) { outBuf => skewResult ++= outBuf } + .sink[(Int, Int, Int, Int, Int, Int)](Tsv("jws-output")) { outBuf => innerResult ++= outBuf } .run //.runHadoop //this takes MUCH longer to run. Commented out by default, but tests pass on my machine .finish diff --git a/scalding-core/src/test/scala/com/twitter/scalding/StringUtilityTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/StringUtilityTest.scala new file mode 100644 index 0000000000..3961825fb8 --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/StringUtilityTest.scala @@ -0,0 +1,67 @@ +package com.twitter.scalding + +import org.scalatest.{ PropSpec, Matchers, WordSpec } +import org.scalacheck.{ Arbitrary, Properties } +import org.scalacheck.Prop.forAll +import org.scalatest.prop.Checkers +import org.scalacheck.Gen + +import scala.collection.mutable.ListBuffer + +class StringUtilityTest extends WordSpec with Matchers { + "fastSplitTest" should { + "be able to split white space" in { + val text1 = "this is good time" + val res1 = StringUtility.fastSplit(text1, " ") // split single white space + res1 should be { + Seq("this", "is", "good", "time") + } + } + } + "be able to split other separators" in { + val text2 = "a:b:c:d:" + val res2 = StringUtility.fastSplit(text2, ":") + res2 should be { + Seq("a", "b", "c", "d", "") + } + } + "be able to split only one separators" in { + val text2 = "a@" + val res2 = StringUtility.fastSplit(text2, "@") + res2 should be { + Seq("a", "") + } + } + "be able to split when separator doesn't show up" in { + val text2 = "a" + val res2 = StringUtility.fastSplit(text2, "@") + res2 should be { + Seq("a") + } + } +} + +class StringUtilityPropertyTest extends PropSpec with Checkers { + val randomStringGen = for { + s <- Gen.pick(5, List.fill(100)(List("k", "l", "m", "x", "//.", "@")).flatten) + + } yield s + + // test for one separator and two + val randomSeparator = for { + s <- Gen.oneOf("@@", "@", "x", "//.") + } yield s + + property("fastSplit(s, sep) should match s.split(sep, -1) for non-regex sep") { + check { + forAll(randomStringGen, randomSeparator) { + (str, separator) => + val t = str.mkString("") + val r1 = t.split(separator, -1).toList + val r2 = StringUtility.fastSplit(t, separator) + r1 == r2 + } + } + } + +} diff --git 
a/scalding-core/src/test/scala/com/twitter/scalding/TemplateSourceTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/TemplateSourceTest.scala index 74809e6b4a..366b5c6676 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/TemplateSourceTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/TemplateSourceTest.scala @@ -21,9 +21,6 @@ import scala.io.{ Source => ScalaSource } import org.scalatest.{ Matchers, WordSpec } -import cascading.tap.SinkMode -import cascading.tuple.Fields - class TemplateTestJob(args: Args) extends Job(args) { try { Tsv("input", ('col1, 'col2)).read.write(TemplatedTsv("base", "%s", 'col1)) diff --git a/scalding-core/src/test/scala/com/twitter/scalding/TestTapFactoryTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/TestTapFactoryTest.scala index 152d822a2c..b071edabfc 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/TestTapFactoryTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/TestTapFactoryTest.scala @@ -2,7 +2,6 @@ package com.twitter.scalding import cascading.tap.Tap import cascading.tuple.{ Fields, Tuple } -import java.lang.IllegalArgumentException import scala.collection.mutable.Buffer import org.scalatest.{ Matchers, WordSpec } diff --git a/scalding-core/src/test/scala/com/twitter/scalding/TypedDelimitedTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/TypedDelimitedTest.scala index 4197e18e3f..16435b180f 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/TypedDelimitedTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/TypedDelimitedTest.scala @@ -16,7 +16,6 @@ limitations under the License. package com.twitter.scalding import org.scalatest.{ Matchers, WordSpec } -import com.twitter.scalding._ import com.twitter.scalding.source.DailySuffixTypedTsv class TypedTsvJob(args: Args) extends Job(args) { diff --git a/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala index 3a345c051c..10ff9e67bf 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala @@ -194,6 +194,41 @@ class TypedPipeDistinctByTest extends WordSpec with Matchers { } } +class TypedPipeGroupedDistinctJob(args: Args) extends Job(args) { + val groupedTP = Tsv("inputFile").read.toTypedPipe[(Int, Int)](0, 1) + .group + + groupedTP + .distinctValues + .write(TypedTsv[(Int, Int)]("outputFile1")) + groupedTP + .distinctSize + .write(TypedTsv[(Int, Long)]("outputFile2")) +} + +class TypedPipeGroupedDistinctJobTest extends WordSpec with Matchers { + import Dsl._ + "A TypedPipeGroupedDistinctJob" should { + JobTest(new TypedPipeGroupedDistinctJob(_)) + .source(Tsv("inputFile"), List((0, 0), (0, 1), (0, 1), (1, 0), (1, 1))) + .sink[(Int, Int)](TypedTsv[(Int, Int)]("outputFile1")){ outputBuffer => + val outSet = outputBuffer.toSet + "correctly generate unique items" in { + outSet should have size 4 + } + } + .sink[(Int, Int)](TypedTsv[(Int, Long)]("outputFile2")){ outputBuffer => + val outMap = outputBuffer.toMap + "correctly count unique item sizes" in { + outMap(0) shouldBe 2 + outMap(1) shouldBe 2 + } + } + .run + .finish + } +} + class TypedPipeHashJoinJob(args: Args) extends Job(args) { TypedTsv[(Int, Int)]("inputFile0") .group diff --git a/scalding-core/src/test/scala/com/twitter/scalding/WordCountTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/WordCountTest.scala index 
8f0f907c23..fec67ae0cf 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/WordCountTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/WordCountTest.scala @@ -23,7 +23,7 @@ class WordCountTest extends WordSpec with Matchers { .arg("input", "inputFile") .arg("output", "outputFile") .source(TextLine("inputFile"), List((0, "hack hack hack and hack"))) - .sink[(String, Int)](Tsv("outputFile")){ outputBuffer => + .sink[(String, Int)](TypedTsv[(String, Long)]("outputFile")){ outputBuffer => val outMap = outputBuffer.toMap "count words correctly" in { outMap("hack") shouldBe 4 diff --git a/scalding-core/src/test/scala/com/twitter/scalding/WrappedJoinerTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/WrappedJoinerTest.scala new file mode 100644 index 0000000000..b507ca6460 --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/WrappedJoinerTest.scala @@ -0,0 +1,71 @@ +package com.twitter.scalding + +import cascading.flow.FlowException +import cascading.pipe.CoGroup +import cascading.pipe.joiner.{ JoinerClosure, InnerJoin } +import cascading.tuple.Tuple +import org.scalatest.{ Matchers, WordSpec } + +import java.util.{ Iterator => JIterator } + +class CheckFlowProcessJoiner(uniqueID: UniqueID) extends InnerJoin { + override def getIterator(joinerClosure: JoinerClosure): JIterator[Tuple] = { + val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) + if (flowProcess == null) { + throw new NullPointerException("No active FlowProcess was available.") + } + + super.getIterator(joinerClosure) + } +} + +class TestWrappedJoinerJob(args: Args) extends Job(args) { + val uniqueID = UniqueID.getIDFor(flowDef) + + val inA = Tsv("inputA", ('a, 'b)) + val inB = Tsv("inputB", ('x, 'y)) + + val joiner = { + val checkJoiner = new CheckFlowProcessJoiner(uniqueID) + if (args.boolean("wrapJoiner")) WrappedJoiner(checkJoiner) else checkJoiner + } + + val p1 = new CoGroup(inA, 'a, inB, 'x, joiner) + + // The .forceToDisk is necessary to have the test work properly. + p1.forceToDisk.write(Tsv("output")) +} + +class WrappedJoinerTest extends WordSpec with Matchers { + "Methods called from a Joiner" should { + "have access to a FlowProcess when WrappedJoiner is used" in { + JobTest(new TestWrappedJoinerJob(_)) + .arg("wrapJoiner", "true") + .source(Tsv("inputA"), Seq(("1", "alpha"), ("2", "beta"))) + .source(Tsv("inputB"), Seq(("1", "first"), ("2", "second"))) + .sink[(Int, String)](Tsv("output")) { outBuf => + // The job will fail with an exception if the FlowProcess is unavailable. + } + .runHadoop + .finish + } + + "have no access to a FlowProcess when WrappedJoiner is not used" in { + try { + JobTest(new TestWrappedJoinerJob(_)) + .source(Tsv("inputA"), Seq(("1", "alpha"), ("2", "beta"))) + .source(Tsv("inputB"), Seq(("1", "first"), ("2", "second"))) + .sink[(Int, String)](Tsv("output")) { outBuf => + // The job will fail with an exception if the FlowProcess is unavailable. 
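+ // No assertion on outBuf is needed here: CheckFlowProcessJoiner.getIterator throws a + // NullPointerException when no FlowProcess is available, so completing .runHadoop without + // an exception is the success condition. (A stricter check, not part of the original test, + // could also assert that outBuf is non-empty.)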
+ } + .runHadoop + .finish + + fail("The test Job without WrappedJoiner should fail.") + } catch { + case ex: FlowException => + ex.getCause.getMessage should include ("the FlowProcess for unique id") + } + } + } +} diff --git a/scalding-core/src/test/scala/com/twitter/scalding/filecache/DistributedCacheFileSpec.scala b/scalding-core/src/test/scala/com/twitter/scalding/filecache/DistributedCacheFileSpec.scala index 70aec69590..4bca622e3e 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/filecache/DistributedCacheFileSpec.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/filecache/DistributedCacheFileSpec.scala @@ -22,11 +22,11 @@ import java.net.URI import org.apache.hadoop.conf.Configuration import org.scalatest.{ Matchers, WordSpec } import scala.collection.mutable -/* -TODO: fix? is it worth having the dep on mockito just for this? + +// TODO: fix? is it worth having the dep on mockito just for this? class DistributedCacheFileSpec extends WordSpec with Matchers { case class UnknownMode(buffers: Map[Source, mutable.Buffer[Tuple]]) extends TestMode with CascadingLocal - + /* val conf = smartMock[Configuration] lazy val hdfsMode = { @@ -44,11 +44,11 @@ class DistributedCacheFileSpec extends WordSpec with Matchers { lazy val testMode = smartMock[Test] lazy val localMode = smartMock[Local] - +*/ val uriString = "hdfs://foo.example:1234/path/to/the/stuff/thefilename.blah" val uri = new URI(uriString) val hashHex = URIHasher(uri) - val hashedFilename = "thefilename.blah-" + hashHex + val hashedFilename = hashHex + "-thefilename.blah" "DistributedCacheFile" should { "symlinkNameFor must return a hashed name" in { @@ -56,6 +56,7 @@ class DistributedCacheFileSpec extends WordSpec with Matchers { } } + /* "UncachedFile.add" should { val dcf = new UncachedFile(Right(uri)) @@ -81,5 +82,5 @@ class DistributedCacheFileSpec extends WordSpec with Matchers { an[RuntimeException] should be thrownBy (dcf.add()(mode)) } } + */ } -*/ diff --git a/scalding-core/src/test/scala/com/twitter/scalding/mathematics/Matrix2Test.scala b/scalding-core/src/test/scala/com/twitter/scalding/mathematics/Matrix2Test.scala index f3faee0c1b..39e880b958 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/mathematics/Matrix2Test.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/mathematics/Matrix2Test.scala @@ -221,6 +221,29 @@ class Matrix2Cosine(args: Args) extends Job(args) { cosine.write(TypedTsv[(Int, Int, Double)]("cosine")) } +class Matrix2Normalize(args: Args) extends Job(args) { + + import Matrix2._ + import cascading.pipe.Pipe + import cascading.tuple.Fields + import com.twitter.scalding.TDsl._ + + val tp1 = TypedPipe.from(TypedTsv[(Int, Int, Double)]("mat1")) + val mat1 = MatrixLiteral(tp1, NoClue) + + // Now test for the case when value is Long type + val matL1Norm = mat1.rowL1Normalize + matL1Norm.write(TypedTsv[(Int, Int, Double)]("normalized")) + + //val p2: Pipe = Tsv("mat2", ('x2, 'y2, 'v2)).read // test Long type as value is OK + val tp2 = TypedPipe.from(TypedTsv[(Int, Int, Long)]("mat2")) + //val tp2 = p2.toTypedPipe[(Int, Int, Long)](('x2, 'y2, 'v2)) + val mat2 = MatrixLiteral(tp2, NoClue) + + val mat2L1Norm = mat2.rowL1Normalize + mat2L1Norm.write(TypedTsv[(Int, Int, Double)]("long_normalized")) +} + class Scalar2Ops(args: Args) extends Job(args) { import Matrix2._ @@ -433,6 +456,27 @@ class Matrix2Test extends WordSpec with Matchers { } } + "A Matrix2 Normalize job" should { + TUtil.printStack { + JobTest(new Matrix2Normalize(_)) + .source(TypedTsv[(Int, Int, 
Double)]("mat1"), List((1, 1, 4.0), (1, 2, 1.0), (2, 2, 1.0), (3, 1, 1.0), (3, 2, 3.0), (3, 3, 4.0))) + .source(TypedTsv[(Int, Int, Long)]("mat2"), List((1, 1, 4L), (1, 2, 1L), (2, 2, 1L), (3, 1, 1L), (3, 2, 3L), (3, 3, 4L))) + .typedSink(TypedTsv[(Int, Int, Double)]("normalized")) { ob => + "correctly compute l1 normalization for matrix with double values" in { + toSparseMat(ob) shouldBe Map((1, 1) -> 0.8, (1, 2) -> 0.2, (2, 2) -> 1.0, (3, 1) -> 0.125, (3, 2) -> 0.375, (3, 3) -> 0.5) + } + } + .typedSink(TypedTsv[(Int, Int, Double)]("long_normalized")){ ob => + "correctly compute l1 normalization for matrix with long values" in { + toSparseMat(ob) shouldBe Map((1, 1) -> 0.8, (1, 2) -> 0.2, (2, 2) -> 1.0, (3, 1) -> 0.125, (3, 2) -> 0.375, (3, 3) -> 0.5) + } + + } + .runHadoop + .finish + } + } + "A Matrix2 Scalar2Ops job" should { TUtil.printStack { JobTest(new Scalar2Ops(_)) diff --git a/scalding-core/src/test/scala/com/twitter/scalding/mathematics/SizeHintTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/mathematics/SizeHintTest.scala index 37e0e3f924..09e26128f3 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/mathematics/SizeHintTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/mathematics/SizeHintTest.scala @@ -25,7 +25,7 @@ import org.scalacheck.Gen._ object SizeHintProps extends Properties("SizeHint") { - val noClueGen = value(NoClue) + val noClueGen = const(NoClue) val finiteHintGen = for ( rows <- choose(-1L, 1000000L); diff --git a/scalding-core/src/test/scala/com/twitter/scalding/typed/HashEqualsArrayWrapperTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/typed/HashEqualsArrayWrapperTest.scala new file mode 100644 index 0000000000..f504e647b0 --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/typed/HashEqualsArrayWrapperTest.scala @@ -0,0 +1,72 @@ +package com.twitter.scalding.typed + +import org.scalacheck.{ Arbitrary, Prop } +import org.scalatest.PropSpec +import org.scalatest.prop.{ Checkers, PropertyChecks } + +object HashArrayEqualsWrapperLaws { + + def check2[T](ordToTest: Ordering[HashEqualsArrayWrapper[T]])(implicit ord: Ordering[T], arb: Arbitrary[Array[T]]): Prop = + + Prop.forAll { (left: Array[T], right: Array[T]) => + + val leftWrapped = HashEqualsArrayWrapper.wrap(left) + val rightWrapped = HashEqualsArrayWrapper.wrap(right) + + import scala.Ordering.Implicits.seqDerivedOrdering + + val slowOrd: Ordering[Seq[T]] = seqDerivedOrdering[Seq, T](ord) + + val cmp = ordToTest.compare(leftWrapped, rightWrapped) + + val lenCmp = java.lang.Integer.compare(leftWrapped.wrapped.length, rightWrapped.wrapped.length) + if (lenCmp != 0) { + cmp.signum == lenCmp.signum + } else { + cmp.signum == slowOrd.compare(leftWrapped.wrapped.toSeq, rightWrapped.wrapped.toSeq).signum + } + } + + def check[T](ordToTest: Ordering[Array[T]])(implicit ord: Ordering[T], arb: Arbitrary[Array[T]]): Prop = + + Prop.forAll { (left: Array[T], right: Array[T]) => + import scala.Ordering.Implicits.seqDerivedOrdering + + val slowOrd: Ordering[Seq[T]] = seqDerivedOrdering[Seq, T](ord) + + val cmp = ordToTest.compare(left, right) + + val lenCmp = java.lang.Integer.compare(left.length, right.length) + if (lenCmp != 0) { + cmp.signum == lenCmp.signum + } else { + cmp.signum == slowOrd.compare(left.toSeq, right.toSeq).signum + } + } +} + +class HashArrayEqualsWrapperTest extends PropSpec with PropertyChecks with Checkers { + + property("Specialized orderings obey all laws for Arrays") { + 
check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.longArrayOrd)) + check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.intArrayOrd)) + check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.shortArrayOrd)) + check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.charArrayOrd)) + check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.byteArrayOrd)) + check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.booleanArrayOrd)) + check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.floatArrayOrd)) + check(HashArrayEqualsWrapperLaws.check(HashEqualsArrayWrapper.doubleArrayOrd)) + } + + property("Specialized orderings obey all laws for wrapped Arrays") { + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsLongOrdering)) + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsIntOrdering)) + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsShortOrdering)) + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsCharOrdering)) + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsByteOrdering)) + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsBooleanOrdering)) + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsFloatOrdering)) + check(HashArrayEqualsWrapperLaws.check2(HashEqualsArrayWrapper.hashEqualsDoubleOrdering)) + } + +} \ No newline at end of file diff --git a/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala new file mode 100644 index 0000000000..a096222518 --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala @@ -0,0 +1,94 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +package com.twitter.scalding + +import com.twitter.scalding.serialization.CascadingBinaryComparator +import com.twitter.scalding.serialization.OrderedSerialization +import com.twitter.scalding.serialization.StringOrderedSerialization + +import org.scalatest.{ Matchers, WordSpec } + +class NoOrderdSerJob(args: Args) extends Job(args) { + + override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> "true") + + TypedPipe.from(TypedTsv[(String, String)]("input")) + .group + .max + .write(TypedTsv[(String, String)]("output")) + + // This should fail + if (args.boolean("check")) { + CascadingBinaryComparator.checkForOrderedSerialization(flowDef).get + } +} + +class OrderdSerJob(args: Args) extends Job(args) { + + implicit def stringOS: OrderedSerialization[String] = new StringOrderedSerialization + + override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> "true") + + TypedPipe.from(TypedTsv[(String, String)]("input")) + .group + .max + .write(TypedTsv[(String, String)]("output")) + + // This should not fail + if (args.boolean("check")) { + CascadingBinaryComparator.checkForOrderedSerialization(flowDef).get + } +} + +class RequireOrderedSerializationTest extends WordSpec with Matchers { + "A NoOrderedSerJob" should { + // This should throw + "throw with --check" in { + an[Exception] should be thrownBy { + (new NoOrderdSerJob(Mode.putMode(Local(true), Args("--check")))) + } + } + "not throw without --check" in { + (new NoOrderdSerJob(Mode.putMode(Local(true), Args("")))) + } + // throw if we try to run in: + "throw when run" in { + an[Exception] should be thrownBy { + JobTest(new NoOrderdSerJob(_)) + .source(TypedTsv[(String, String)]("input"), List(("a", "a"), ("b", "b"))) + .sink[(String, String)](TypedTsv[(String, String)]("output")) { outBuf => () } + .run + .finish + } + } + } + "A OrderedSerJob" should { + "not throw with --check" in { + // This should not throw + val osj = (new OrderdSerJob(Mode.putMode(Local(true), Args("--check")))) + } + // throw if we try to run in: + "run" in { + JobTest(new OrderdSerJob(_)) + .source(TypedTsv[(String, String)]("input"), List(("a", "a"), ("a", "b"), ("b", "b"))) + .sink[(String, String)](TypedTsv[(String, String)]("output")) { outBuf => + outBuf.toSet shouldBe Set(("a", "b"), ("b", "b")) + } + .run + .finish + } + } +} diff --git a/scalding-core/src/test/scala/com/twitter/scalding/typed/TypedPipeDiffTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/typed/TypedPipeDiffTest.scala new file mode 100644 index 0000000000..dc203d7fc1 --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/typed/TypedPipeDiffTest.scala @@ -0,0 +1,219 @@ +package com.twitter.scalding.typed + +import java.io.File +import java.nio.file.Files + +import com.twitter.algebird.MapAlgebra +import com.twitter.scalding.{ Config, Local } +import org.scalacheck.{ Arbitrary, Prop } +import org.scalatest.prop.{ Checkers, PropertyChecks } +import org.scalatest.{ FunSuite, PropSpec } + +import scala.reflect.ClassTag + +class NoOrdering(val x: String) { + + override def equals(other: Any): Boolean = other match { + case that: NoOrdering => x.equals(that.x) + case _ => false + } + + override def hashCode(): Int = x.hashCode +} + +class NoOrderingHashCollisions(val x: String) { + + override def equals(other: Any): Boolean = other match { + case that: NoOrderingHashCollisions => x.equals(that.x) + case _ => false + } + + override def hashCode(): Int = 0 +} + +object TypedPipeRunner { + def runToList[T](output: 
TypedPipe[T]): List[T] = + output + .toIterableExecution + .waitFor(Config.default, Local(strictSources = true)) + .get + .toList +} + +class TypedPipeDiffTest extends FunSuite { + import com.twitter.scalding.typed.TypedPipeRunner._ + + val left = List("hi", "hi", "bye", "foo", "bar") + val right = List("hi", "bye", "foo", "baz") + val expectedSortedDiff = List(("bar", (1, 0)), ("baz", (0, 1)), ("hi", (2, 1))).sorted + + val leftArr = List( + Array[Byte](3, 3, 5, 3, 2), + Array[Byte](2, 2, 2), + Array[Byte](0, 1, 0)) + + val rightArr = List( + Array[Byte](2, 2, 2), + Array[Byte](2, 2, 2), + Array[Byte](3, 3, 5, 3, 2), + Array[Byte](0, 1, 1)) + + val expectedSortedArrDiff = List( + (Array[Byte](0, 1, 0).toSeq, (1, 0)), + (Array[Byte](0, 1, 1).toSeq, (0, 1)), + (Array[Byte](2, 2, 2).toSeq, (1, 2))) + + test("diff works for objects with ordering and good hashcodes") { + val pipe1 = TypedPipe.from(left) + val pipe2 = TypedPipe.from(right) + val diff = TypedPipeDiff.diff(pipe1, pipe2) + + assert(expectedSortedDiff === runToList(diff.toTypedPipe).sorted) + } + + // this lets us sort the results, + // without bringing an ordering into scope + private def sort(x: List[(Seq[Byte], (Long, Long))]): List[(Seq[Byte], (Long, Long))] = { + import scala.Ordering.Implicits.seqDerivedOrdering + x.sorted + } + + test("diffArrayPipes works for arrays") { + val pipe1 = TypedPipe.from(leftArr) + val pipe2 = TypedPipe.from(rightArr) + + val diff = TypedPipeDiff.diffArrayPipes(pipe1, pipe2).map { case (arr, counts) => (arr.toSeq, counts) } + + assert(expectedSortedArrDiff === sort(runToList(diff))) + } + + test("diffWithoutOrdering works for objects with ordering and good hashcodes") { + val pipe1 = TypedPipe.from(left) + val pipe2 = TypedPipe.from(right) + val diff = TypedPipeDiff.diffByHashCode(pipe1, pipe2) + + assert(expectedSortedDiff === runToList(diff).sorted) + } + + test("diffWithoutOrdering does not require ordering") { + val pipe1 = TypedPipe.from(left.map(new NoOrdering(_))) + val pipe2 = TypedPipe.from(right.map(new NoOrdering(_))) + val diff = TypedPipeDiff.diffByHashCode(pipe1, pipe2) + + assert(expectedSortedDiff === runToList(diff).map { case (nord, counts) => (nord.x, counts) }.sorted) + } + + test("diffWithoutOrdering works even with hash collisions") { + val pipe1 = TypedPipe.from(left.map(new NoOrderingHashCollisions(_))) + val pipe2 = TypedPipe.from(right.map(new NoOrderingHashCollisions(_))) + val diff = TypedPipeDiff.diffByHashCode(pipe1, pipe2) + assert(expectedSortedDiff === runToList(diff).map { case (nord, counts) => (nord.x, counts) }.sorted) + } + + test("diffArrayPipesWithoutOrdering works for arrays of objects with no ordering") { + val pipe1 = TypedPipe.from(leftArr.map { arr => arr.map { b => new NoOrdering(b.toString) } }) + val pipe2 = TypedPipe.from(rightArr.map { arr => arr.map { b => new NoOrdering(b.toString) } }) + val diff = TypedPipeDiff.diffArrayPipes(pipe1, pipe2) + + assert(expectedSortedArrDiff === sort(runToList(diff).map{ case (arr, counts) => (arr.map(_.x.toByte).toSeq, counts) })) + } + +} + +object TypedPipeDiffLaws { + import com.twitter.scalding.typed.TypedPipeDiff.Enrichments._ + import com.twitter.scalding.typed.TypedPipeRunner._ + + def checkDiff[T](left: List[T], right: List[T], diff: List[(T, (Long, Long))]): Boolean = { + val noDuplicates = diff.size == diff.map(_._1).toSet.size + val expected = MapAlgebra.sumByKey(left.map((_, (1L, 0L))).iterator ++ right.map((_, (0L, 1L))).iterator) + .filter { case (t, (rCount, lCount)) => rCount != lCount } + + 
noDuplicates && expected == diff.toMap + } + + def checkArrayDiff[T](left: List[Array[T]], right: List[Array[T]], diff: List[(Seq[T], (Long, Long))]): Boolean = { + checkDiff(left.map(_.toSeq), right.map(_.toSeq), diff) + } + + def diffLaw[T: Ordering: Arbitrary]: Prop = Prop.forAll { (left: List[T], right: List[T]) => + val diff = runToList(TypedPipe.from(left).diff(TypedPipe.from(right)).toTypedPipe) + checkDiff(left, right, diff) + } + + def diffArrayLaw[T](implicit arb: Arbitrary[List[Array[T]]], ct: ClassTag[T]): Prop = Prop.forAll { (left: List[Array[T]], right: List[Array[T]]) => + val diff = runToList(TypedPipe.from(left).diffArrayPipes(TypedPipe.from(right))) + .map { case (arr, counts) => (arr.toSeq, counts) } + checkArrayDiff(left, right, diff) + } + + def diffByGroupLaw[T: Arbitrary]: Prop = Prop.forAll { (left: List[T], right: List[T]) => + val diff = runToList(TypedPipe.from(left).diffByHashCode(TypedPipe.from(right))) + checkDiff(left, right, diff) + } + +} + +class TypedPipeDiffLaws extends PropSpec with PropertyChecks with Checkers { + + property("diffLaws") { + check(TypedPipeDiffLaws.diffLaw[Int]) + check(TypedPipeDiffLaws.diffLaw[String]) + } + + property("diffArrayLaws") { + + implicit val arbNoOrdering = Arbitrary { + for { + strs <- Arbitrary.arbitrary[Array[String]] + } yield { + strs.map { new NoOrdering(_) } + } + } + + implicit val arbNoOrderingHashCollision = Arbitrary { + for { + strs <- Arbitrary.arbitrary[Array[String]] + } yield { + strs.map { new NoOrderingHashCollisions(_) } + } + } + + check(TypedPipeDiffLaws.diffArrayLaw[Long]) + check(TypedPipeDiffLaws.diffArrayLaw[Int]) + check(TypedPipeDiffLaws.diffArrayLaw[Short]) + check(TypedPipeDiffLaws.diffArrayLaw[Char]) + check(TypedPipeDiffLaws.diffArrayLaw[Byte]) + check(TypedPipeDiffLaws.diffArrayLaw[Boolean]) + check(TypedPipeDiffLaws.diffArrayLaw[Float]) + check(TypedPipeDiffLaws.diffArrayLaw[Double]) + check(TypedPipeDiffLaws.diffArrayLaw[String]) + check(TypedPipeDiffLaws.diffArrayLaw[NoOrdering]) + check(TypedPipeDiffLaws.diffArrayLaw[NoOrderingHashCollisions]) + } + + property("diffByGroupLaws") { + + implicit val arbNoOrdering = Arbitrary { + for { + name <- Arbitrary.arbitrary[String] + } yield { + new NoOrdering(name) + } + } + + implicit val arbNoOrderingHashCollision = Arbitrary { + for { + name <- Arbitrary.arbitrary[String] + } yield { + new NoOrderingHashCollisions(name) + } + } + + check(TypedPipeDiffLaws.diffByGroupLaw[Int]) + check(TypedPipeDiffLaws.diffByGroupLaw[String]) + check(TypedPipeDiffLaws.diffByGroupLaw[NoOrdering]) + check(TypedPipeDiffLaws.diffByGroupLaw[NoOrderingHashCollisions]) + } + +} \ No newline at end of file diff --git a/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala b/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala index dabe312fa7..442862af86 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala @@ -15,8 +15,6 @@ limitations under the License. 
*/ package com.twitter.scalding -import java.util.Calendar -import java.util.Date import java.util.TimeZone import java.text.SimpleDateFormat @@ -53,10 +51,13 @@ object DateOps extends java.io.Serializable { /** * Return the guessed format for this datestring */ - def getFormat(s: String): Option[String] = { - DATE_FORMAT_VALIDATORS.find{ _._2.findFirstIn(prepare(s)).isDefined }.map(_._1) - } + def getFormat(s: String): Option[String] = + DATE_FORMAT_VALIDATORS.find { _._2.findFirstIn(prepare(s)).isDefined }.map(_._1) + /** + * The DateParser returned here is based on SimpleDateFormat, which is not thread-safe. + * Do not share the result across threads. + */ def getDateParser(s: String): Option[DateParser] = getFormat(s).map { fmt => DateParser.from(new SimpleDateFormat(fmt)).contramap(prepare) } } diff --git a/scalding-date/src/main/scala/com/twitter/scalding/DateParser.scala b/scalding-date/src/main/scala/com/twitter/scalding/DateParser.scala index 2236cfa060..5a20c344b6 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/DateParser.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/DateParser.scala @@ -41,8 +41,11 @@ object DateParser { /** * This is scalding's default date parser. You can choose this * by setting an implicit val DateParser. + * Note that DateParsers using SimpleDateFormat from Java are + * not thread-safe, thus the def here. You can cache the result + * if you are sure it will not be shared across threads. */ - val default: DateParser = new DateParser { + def default: DateParser = new DateParser { def parse(s: String)(implicit tz: TimeZone) = DateOps.getDateParser(s) .map { p => p.parse(s) } @@ -56,6 +59,10 @@ object DateParser { /** Using the type-class pattern */ def parse(s: String)(implicit tz: TimeZone, p: DateParser): Try[RichDate] = p.parse(s)(tz) + /** + * Note that DateFormats in Java are generally not thread-safe, + * so you should not share the result here across threads + */ implicit def from(df: DateFormat): DateParser = new DateParser { def parse(s: String)(implicit tz: TimeZone) = Try { df.setTimeZone(tz) @@ -63,6 +70,9 @@ object DateParser { } } + /** + * This ignores the time-zone, assuming it must be in the String + */ def from(fn: String => RichDate) = new DateParser { def parse(s: String)(implicit tz: TimeZone) = Try(fn(s)) } diff --git a/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala b/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala index d8b0abf633..44d235e935 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala @@ -56,6 +56,13 @@ object DateRange extends java.io.Serializable { case Seq(o) => parse(o) case x => sys.error("--date must have exactly one or two date[time]s. Got: " + x.toString) } + + /** + * DateRanges are inclusive. Use this to create a DateRange that excludes + * the last millisecond from the second argument. + */ + def exclusiveUpper(include: RichDate, exclude: RichDate): DateRange = + DateRange(include, exclude - Millisecs(1)) } /** diff --git a/scalding-date/src/main/scala/com/twitter/scalding/RichDate.scala b/scalding-date/src/main/scala/com/twitter/scalding/RichDate.scala index a06a91d815..ebe1c77c68 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/RichDate.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/RichDate.scala @@ -64,7 +64,8 @@ object RichDate { } /** - * A value class wrapper for milliseconds since the epoch + * A value class wrapper for milliseconds since the epoch.
It's tempting to extend + this with AnyVal, but doing so causes problems with Java code. */ case class RichDate(val timestamp: Long) extends Ordered[RichDate] { // these are mutable, don't keep them around @@ -77,7 +78,7 @@ case class RichDate(val timestamp: Long) extends Ordered[RichDate] { def -(that: RichDate) = AbsoluteDuration.fromMillisecs(timestamp - that.timestamp) override def compare(that: RichDate): Int = - Ordering[Long].compare(timestamp, that.timestamp) + java.lang.Long.compare(timestamp, that.timestamp) //True of the other is a RichDate with equal value, or a Date equal to value override def equals(that: Any) = diff --git a/scalding-date/src/test/scala/com/twitter/scalding/DateProperties.scala b/scalding-date/src/test/scala/com/twitter/scalding/DateProperties.scala index 1e504e567c..3380d6f42d 100644 --- a/scalding-date/src/test/scala/com/twitter/scalding/DateProperties.scala +++ b/scalding-date/src/test/scala/com/twitter/scalding/DateProperties.scala @@ -21,7 +21,6 @@ import org.scalacheck.Prop.forAll import org.scalacheck.Gen.choose import org.scalacheck.Prop._ -import scala.util.control.Exception.allCatch import AbsoluteDuration.fromMillisecs object DateProperties extends Properties("Date Properties") { @@ -113,6 +112,17 @@ object DateProperties extends Properties("Date Properties") { dr.start + dr.length - AbsoluteDuration.fromMillisecs(1L) == dr.end } + property("DateRange.exclusiveUpper works") = forAll { (a: RichDate, b: RichDate) => + val lower = Ordering[RichDate].min(a, b) + val upper = Ordering[RichDate].max(a, b) + val ex = DateRange.exclusiveUpper(lower, upper) + val in = DateRange(lower, upper) + val upperPred = upper - Millisecs(1) + + (false == ex.contains(upper)) && + (ex.contains(upperPred) || (lower == upper)) + } + def toRegex(glob: String) = (glob.flatMap { c => if (c == '*') ".*" else c.toString }).r def matches(l: List[String], arg: String): Int = l diff --git a/scalding-date/src/test/scala/com/twitter/scalding/GlobifierOps.scala b/scalding-date/src/test/scala/com/twitter/scalding/GlobifierOps.scala index 4cd4291c0a..55fca6a397 100644 --- a/scalding-date/src/test/scala/com/twitter/scalding/GlobifierOps.scala +++ b/scalding-date/src/test/scala/com/twitter/scalding/GlobifierOps.scala @@ -15,7 +15,6 @@ limitations under the License.
*/ package com.twitter.scalding -import com.twitter.scalding._ import java.util.TimeZone import scala.util.{ Try, Success, Failure } diff --git a/scalding-date/src/test/scala/com/twitter/scalding/GlobifierProperties.scala b/scalding-date/src/test/scala/com/twitter/scalding/GlobifierProperties.scala index 4724f29823..e79f6dc3dc 100644 --- a/scalding-date/src/test/scala/com/twitter/scalding/GlobifierProperties.scala +++ b/scalding-date/src/test/scala/com/twitter/scalding/GlobifierProperties.scala @@ -21,8 +21,6 @@ import org.scalacheck.Prop.forAll import org.scalacheck.Gen.choose import org.scalacheck.Prop._ -import scala.util.control.Exception.allCatch -import AbsoluteDuration.fromMillisecs import java.util.TimeZone object GlobifierProperties extends Properties("Globifier Properties") { diff --git a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala index a5de9e29fd..c0d527eb85 100644 --- a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala +++ b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala @@ -22,8 +22,6 @@ import java.io.{ BufferedWriter, File, FileWriter } import org.apache.hadoop.mapred.JobConf -import scala.collection.mutable.Buffer - import org.slf4j.LoggerFactory /** @@ -90,6 +88,7 @@ case class HadoopPlatformJobTest( } def run { + System.setProperty("cascading.update.skip", "true") val job = initJob(cons) cluster.addClassSourceToClassPath(cons.getClass) cluster.addClassSourceToClassPath(job.getClass) diff --git a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopSharedPlatformTest.scala b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopSharedPlatformTest.scala new file mode 100644 index 0000000000..325f83858a --- /dev/null +++ b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopSharedPlatformTest.scala @@ -0,0 +1,49 @@ +/* +Copyright 2014 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.platform + +import org.scalatest.{ BeforeAndAfterAll, Suite } + +trait HadoopSharedPlatformTest extends BeforeAndAfterAll { this: Suite => + org.apache.log4j.Logger.getLogger("org.apache.hadoop").setLevel(org.apache.log4j.Level.ERROR) + org.apache.log4j.Logger.getLogger("org.mortbay").setLevel(org.apache.log4j.Level.ERROR) + org.apache.log4j.Logger.getLogger("org.apache.hadoop.metrics2.util").setLevel(org.apache.log4j.Level.ERROR) + + val cluster = LocalCluster() + + def initialize() = cluster.initialize() + + override def beforeAll() { + cluster.synchronized { + initialize() + } + super.beforeAll() + } + + //TODO is there a way to buffer such that we see test results AFTER afterEach? 
Otherwise the results + // get lost in the logging + override def afterAll() { + try super.afterAll() + finally { + // Necessary because afterAll can be called from a different thread and we want to make sure that the state + // is visible. Note that this assumes there is no contention for LocalCluster (which LocalCluster ensures), + // otherwise there could be deadlock. + cluster.synchronized { + cluster.shutdown() + } + } + } +} diff --git a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/LocalCluster.scala b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/LocalCluster.scala index 92ce1c6fa1..bb83b6aaf1 100644 --- a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/LocalCluster.scala +++ b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/LocalCluster.scala @@ -17,24 +17,12 @@ package com.twitter.scalding.platform import com.twitter.scalding._ -import java.io.{ - BufferedInputStream, - BufferedReader, - BufferedWriter, - DataInputStream, - DataOutputStream, - File, - FileInputStream, - FileOutputStream, - FileReader, - FileWriter, - RandomAccessFile -} +import java.io.{ File, RandomAccessFile } import java.nio.channels.FileLock import org.apache.hadoop.conf.Configuration import org.apache.hadoop.filecache.DistributedCache -import org.apache.hadoop.fs.{ FileSystem, FileUtil, Path } +import org.apache.hadoop.fs.{ FileUtil, Path } import org.apache.hadoop.hdfs.MiniDFSCluster import org.apache.hadoop.mapred.{ JobConf, MiniMRCluster } import org.slf4j.LoggerFactory @@ -131,6 +119,7 @@ class LocalCluster(mutex: Boolean = true) { classOf[com.twitter.scalding.RichDate], classOf[cascading.tuple.TupleException], classOf[com.twitter.chill.Externalizer[_]], + classOf[com.twitter.chill.algebird.AveragedValueSerializer], classOf[com.twitter.algebird.Semigroup[_]], classOf[com.twitter.chill.KryoInstantiator], classOf[org.jgrapht.ext.EdgeNameProvider[_]], diff --git a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/JoinerFlowProcessTest.scala b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/JoinerFlowProcessTest.scala new file mode 100644 index 0000000000..cb2ea779e9 --- /dev/null +++ b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/JoinerFlowProcessTest.scala @@ -0,0 +1,93 @@ +package com.twitter.scalding + +import cascading.pipe.joiner.{ JoinerClosure, InnerJoin } +import cascading.tuple.Tuple +import com.twitter.scalding.platform.{ HadoopSharedPlatformTest, HadoopPlatformJobTest, HadoopPlatformTest } +import org.scalatest.{ Matchers, WordSpec } + +import java.util.{ Iterator => JIterator } + +import org.slf4j.{ LoggerFactory, Logger } + +class CheckFlowProcessJoiner(uniqueID: UniqueID) extends InnerJoin { + override def getIterator(joinerClosure: JoinerClosure): JIterator[Tuple] = { + println("CheckFlowProcessJoiner.getItertor") + + val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) + if (flowProcess == null) { + throw new NullPointerException("No active FlowProcess was available.") + } + + super.getIterator(joinerClosure) + } +} + +class CheckForFlowProcessInFieldsJob(args: Args) extends Job(args) { + val uniqueID = UniqueID.getIDFor(flowDef) + val stat = Stat("joins") + + val inA = Tsv("inputA", ('a, 'b)) + val inB = Tsv("inputB", ('x, 'y)) + + val p = inA.joinWithSmaller('a -> 'x, inB).map(('b, 'y) -> 'z) { args: (String, String) => + stat.inc + + val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) + if (flowProcess == null) { + throw new NullPointerException("No active 
FlowProcess was available.") + } + + s"${args._1},${args._2}" + } + + p.write(Tsv("output", ('b, 'y))) +} + +class CheckForFlowProcessInTypedJob(args: Args) extends Job(args) { + val uniqueID = UniqueID.getIDFor(flowDef) + val stat = Stat("joins") + + val inA = TypedPipe.from(TypedTsv[(String, String)]("inputA")) + val inB = TypedPipe.from(TypedTsv[(String, String)]("inputB")) + + inA.group.join(inB.group).forceToReducers.mapGroup((key, valuesIter) => { + stat.inc + + val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) + if (flowProcess == null) { + throw new NullPointerException("No active FlowProcess was available.") + } + + valuesIter.map({ case (a, b) => s"$a:$b" }) + }).toTypedPipe.write(TypedTsv[(String, String)]("output")) +} + +class JoinerFlowProcessTest extends WordSpec with Matchers with HadoopSharedPlatformTest { + "Methods called from a Joiner" should { + "have access to a FlowProcess from a join in the Fields-based API" in { + HadoopPlatformJobTest(new CheckForFlowProcessInFieldsJob(_), cluster) + .source(TypedTsv[(String, String)]("inputA"), Seq(("1", "alpha"), ("2", "beta"))) + .source(TypedTsv[(String, String)]("inputB"), Seq(("1", "first"), ("2", "second"))) + .sink(TypedTsv[(String, String)]("output")) { _ => + // The job will fail with an exception if the FlowProcess is unavailable. + } + .inspectCompletedFlow({ flow => + flow.getFlowStats.getCounterValue(Stats.ScaldingGroup, "joins") shouldBe 2 + }) + .run + } + + "have access to a FlowProcess from a join in the Typed API" in { + HadoopPlatformJobTest(new CheckForFlowProcessInTypedJob(_), cluster) + .source(TypedTsv[(String, String)]("inputA"), Seq(("1", "alpha"), ("2", "beta"))) + .source(TypedTsv[(String, String)]("inputB"), Seq(("1", "first"), ("2", "second"))) + .sink[(String, String)](TypedTsv[(String, String)]("output")) { _ => + // The job will fail with an exception if the FlowProcess is unavailable. 
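+ // The sink body is intentionally empty; the real verification happens in the + // .inspectCompletedFlow block below, which checks that the "joins" counter incremented + // inside mapGroup reached 2 (once per key).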
+ } + .inspectCompletedFlow({ flow => + flow.getFlowStats.getCounterValue(Stats.ScaldingGroup, "joins") shouldBe 2 + }) + .run + } + } +} diff --git a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/ordered_serialization/OrderedSerializationTest.scala b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/ordered_serialization/OrderedSerializationTest.scala new file mode 100644 index 0000000000..e06872c470 --- /dev/null +++ b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/ordered_serialization/OrderedSerializationTest.scala @@ -0,0 +1,60 @@ +package com.twitter.scalding.ordered_serialization + +import com.twitter.scalding._ +import com.twitter.scalding.platform.{ HadoopPlatformJobTest, HadoopPlatformTest } +import com.twitter.scalding.serialization.OrderedSerialization + +import org.scalacheck.{ Arbitrary, Gen } +import org.scalatest.FunSuite + +import scala.language.experimental.macros +import scala.math.Ordering + +object OrderedSerializationTest { + implicit val genASGK = Arbitrary { + for { + ts <- Arbitrary.arbitrary[Long] + b <- Gen.nonEmptyListOf(Gen.alphaNumChar).map (_.mkString) + } yield NestedCaseClass(RichDate(ts), (b, b)) + } + + def sample[T: Arbitrary]: T = Arbitrary.arbitrary[T].sample.get + val data = sample[List[NestedCaseClass]].take(1000) +} + +case class NestedCaseClass(day: RichDate, key: (String, String)) + +class OrderedSerializationTest extends FunSuite with HadoopPlatformTest { + import OrderedSerializationTest._ + test("A test job with a fork and join, had previously not had boxed serializations on all branches") { + val fn = (arg: Args) => new ComplexJob(data, arg) + HadoopPlatformJobTest(fn, cluster) + .arg("output1", "output1") + .arg("output2", "output2") + // Here we are just testing that we hit no exceptions in the course of this run + // the previous issue would have caused OOM or other exceptions. If we get to the end + // then we are good. + .sink[String](TypedTsv[String]("output2")) { x => () } + .sink[String](TypedTsv[String]("output1")) { x => () } + .run + } +} + +class ComplexJob(input: List[NestedCaseClass], args: Args) extends Job(args) { + implicit def primitiveOrderedBufferSupplier[T]: OrderedSerialization[T] = macro com.twitter.scalding.serialization.macros.impl.OrderedSerializationProviderImpl[T] + + val ds1 = TypedPipe.from(input).map(_ -> 1L).distinct.group + + val ds2 = TypedPipe.from(input).map(_ -> 1L).distinct.group + + ds2 + .keys + .map(s => s.toString) + .write(TypedTsv[String](args("output1"))) + + ds2.join(ds1) + .values + .map(_.toString) + .write(TypedTsv[String](args("output2"))) +} + diff --git a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala index 3b3d05e7f8..07a790ed93 100644 --- a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala +++ b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala @@ -69,7 +69,7 @@ class TsvNoCacheJob(args: Args) extends Job(args) { // Keeping all of the specifications in the same tests puts the result output all together at the end. // This is useful given that the Hadoop MiniMRCluster and MiniDFSCluster spew a ton of logging. 
-class PlatformTests extends WordSpec with Matchers with HadoopPlatformTest { +class PlatformTests extends WordSpec with Matchers with HadoopSharedPlatformTest { org.apache.log4j.Logger.getLogger("org.apache.hadoop").setLevel(org.apache.log4j.Level.ERROR) org.apache.log4j.Logger.getLogger("org.mortbay").setLevel(org.apache.log4j.Level.ERROR) @@ -153,3 +153,51 @@ class IterableSourceDistinctTest extends WordSpec with Matchers with HadoopPlatf } } } + +object MultipleGroupByJobData { + val data: List[String] = { + val rnd = new scala.util.Random(22) + (0 until 20).map { _ => rnd.nextLong.toString }.toList + }.distinct +} + +class MultipleGroupByJob(args: Args) extends Job(args) { + import com.twitter.scalding.serialization._ + import MultipleGroupByJobData._ + implicit val stringOrdSer = new StringOrderedSerialization() + implicit val stringTup2OrdSer = new OrderedSerialization2(stringOrdSer, stringOrdSer) + val otherStream = TypedPipe.from(data).map{ k => (k, k) }.group + + TypedPipe.from(data) + .map{ k => (k, 1L) } + .group[String, Long](implicitly, stringOrdSer) + .sum + .map { + case (k, _) => + ((k, k), 1L) + } + .sumByKey[(String, String), Long](implicitly, stringTup2OrdSer, implicitly) + .map(_._1._1) + .map { t => + (t.toString, t) + } + .group + .leftJoin(otherStream) + .map(_._1) + .write(TypedTsv("output")) + +} + +class MultipleGroupByJobTest extends WordSpec with Matchers with HadoopPlatformTest { + "A grouped job" should { + import MultipleGroupByJobData._ + + "do some ops and not stamp on each other ordered serializations" in { + HadoopPlatformJobTest(new MultipleGroupByJob(_), cluster) + .source[String]("input", data) + .sink[String]("output") { _.toSet shouldBe data.map(_.toString).toSet } + .run + } + + } +} diff --git a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorTest.scala b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorTest.scala index 4975d6d2e0..3d8833433b 100644 --- a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorTest.scala +++ b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorTest.scala @@ -1,8 +1,7 @@ package com.twitter.scalding.reducer_estimation import com.twitter.scalding._ -import com.twitter.scalding.platform.{ HadoopPlatformJobTest, HadoopPlatformTest, LocalCluster } -import org.apache.hadoop.mapred.JobConf +import com.twitter.scalding.platform.{ HadoopPlatformJobTest, HadoopPlatformTest } import org.scalatest.{ Matchers, WordSpec } import scala.collection.JavaConverters._ @@ -10,8 +9,8 @@ object HipJob { val inSrc = TextLine(getClass.getResource("/hipster.txt").toString) val inScores = TypedTsv[(String, Double)](getClass.getResource("/scores.tsv").toString) val out = TypedTsv[Double]("output") - val countsPath = "counts.tsv" - val counts = TypedTsv[(String, Int)](countsPath) + val counts = TypedTsv[(String, Int)]("counts.tsv") + val size = TypedTsv[Long]("size.tsv") val correct = Map("hello" -> 1, "goodbye" -> 1, "world" -> 2) } @@ -33,12 +32,11 @@ class HipJob(args: Args) extends Job(args) { wordCounts.leftJoin(scores) .mapValues{ case (count, score) => (count, score.getOrElse(0.0)) } - - // force another M/R step + // force another M/R step - should use reducer estimation .toTypedPipe .map{ case (word, (count, score)) => (count, score) } .group.sum - + // force another M/R step - this should force 1 reducer because it is essentially a groupAll .toTypedPipe.values.sum 
.write(out) @@ -50,10 +48,30 @@ class SimpleJob(args: Args) extends Job(args) { .flatMap(_.split("[^\\w]+")) .map(_.toLowerCase -> 1) .group + // force the number of reducers to two, to test with/without estimation + .withReducers(2) .sum .write(counts) } +class GroupAllJob(args: Args) extends Job(args) { + import HipJob._ + TypedPipe.from(inSrc) + .flatMap(_.split("[^\\w]+")) + .groupAll + .size + .values + .write(size) +} + +class SimpleMapOnlyJob(args: Args) extends Job(args) { + import HipJob._ + // simple job with no reduce phase + TypedPipe.from(inSrc) + .flatMap(_.split("[^\\w]+")) + .write(TypedTsv[String]("mapped_output")) +} + class ReducerEstimatorTestSingle extends WordSpec with Matchers with HadoopPlatformTest { import HipJob._ @@ -61,6 +79,29 @@ class ReducerEstimatorTestSingle extends WordSpec with Matchers with HadoopPlatf .addReducerEstimator(classOf[InputSizeReducerEstimator]) + (InputSizeReducerEstimator.BytesPerReducer -> (1L << 10).toString)) + "Single-step job with reducer estimator" should { + "run with correct number of reducers" in { + HadoopPlatformJobTest(new SimpleJob(_), cluster) + .inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 1 + + val conf = Config.fromHadoop(steps.head.getConfig) + conf.getNumReducers should contain (2) + } + .run + } + } +} + +class ReducerEstimatorTestSingleOverride extends WordSpec with Matchers with HadoopPlatformTest { + import HipJob._ + + override def initialize() = cluster.initialize(Config.empty + .addReducerEstimator(classOf[InputSizeReducerEstimator]) + + (InputSizeReducerEstimator.BytesPerReducer -> (1L << 10).toString) + + (Config.ReducerEstimatorOverride -> "true")) + "Single-step job with reducer estimator" should { "run with correct number of reducers" in { HadoopPlatformJobTest(new SimpleJob(_), cluster) @@ -75,12 +116,35 @@ class ReducerEstimatorTestSingle extends WordSpec with Matchers with HadoopPlatf } } } + +class ReducerEstimatorTestGroupAll extends WordSpec with Matchers with HadoopPlatformTest { + import HipJob._ + + override def initialize() = cluster.initialize(Config.empty + .addReducerEstimator(classOf[InputSizeReducerEstimator]) + + (InputSizeReducerEstimator.BytesPerReducer -> (1L << 10).toString)) + + "Group-all job with reducer estimator" should { + "run with correct number of reducers (i.e. 
1)" in { + HadoopPlatformJobTest(new GroupAllJob(_), cluster) + .inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 1 + + val conf = Config.fromHadoop(steps.head.getConfig) + conf.getNumReducers should contain (1) + } + .run + } + } +} + class ReducerEstimatorTestMulti extends WordSpec with Matchers with HadoopPlatformTest { import HipJob._ override def initialize() = cluster.initialize(Config.empty .addReducerEstimator(classOf[InputSizeReducerEstimator]) + - (InputSizeReducerEstimator.BytesPerReducer -> (1L << 16).toString)) + (InputSizeReducerEstimator.BytesPerReducer -> (1L << 10).toString)) "Multi-step job with reducer estimator" should { "run with correct number of reducers in each step" in { @@ -89,9 +153,32 @@ class ReducerEstimatorTestMulti extends WordSpec with Matchers with HadoopPlatfo .inspectCompletedFlow { flow => val steps = flow.getFlowSteps.asScala val reducers = steps.map(_.getConfig.getInt(Config.HadoopNumReducers, 0)).toList - reducers shouldBe List(1, 1, 2) + reducers shouldBe List(3, 1, 1) + } + .run + } + } +} + +class ReducerEstimatorTestMapOnly extends WordSpec with Matchers with HadoopPlatformTest { + import HipJob._ + + override def initialize() = cluster.initialize(Config.empty + .addReducerEstimator(classOf[InputSizeReducerEstimator]) + + (InputSizeReducerEstimator.BytesPerReducer -> (1L << 10).toString)) + + "Map-only job with reducer estimator" should { + "not set num reducers" in { + HadoopPlatformJobTest(new SimpleMapOnlyJob(_), cluster) + .inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 1 + + val conf = Config.fromHadoop(steps.head.getConfig) + conf.getNumReducers should contain (0) } .run } } } + diff --git a/scalding-json/src/main/scala/com/twitter/scalding/TypedJson.scala b/scalding-json/src/main/scala/com/twitter/scalding/TypedJson.scala index 9685059276..1dde3d1648 100644 --- a/scalding-json/src/main/scala/com/twitter/scalding/TypedJson.scala +++ b/scalding-json/src/main/scala/com/twitter/scalding/TypedJson.scala @@ -1,13 +1,10 @@ package com.twitter.scalding -import com.twitter.bijection._ import com.twitter.bijection.{ Injection, AbstractInjection } import com.twitter.bijection.Inversion._ -import com.twitter.scalding._ import com.twitter.elephantbird.cascading2.scheme.LzoTextLine import org.json4s._ -import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization._ import org.json4s.{ NoTypeHints, native } diff --git a/scalding-macros/src/main/scala/com/twitter/scalding/macros/Macros.scala b/scalding-macros/src/main/scala/com/twitter/scalding/macros/Macros.scala index 465a27ae5c..cd22bc0816 100644 --- a/scalding-macros/src/main/scala/com/twitter/scalding/macros/Macros.scala +++ b/scalding-macros/src/main/scala/com/twitter/scalding/macros/Macros.scala @@ -20,6 +20,7 @@ import scala.language.experimental.macros import com.twitter.scalding._ import com.twitter.scalding.macros.impl._ import cascading.tuple.Fields +import com.twitter.scalding.serialization.OrderedSerialization object Macros { @@ -44,5 +45,4 @@ object Macros { def caseClassTypeDescriptor[T]: TypeDescriptor[T] = macro TypeDescriptorProviderImpl.caseClassTypeDescriptorImpl[T] def caseClassTypeDescriptorWithUnknown[T]: TypeDescriptor[T] = macro TypeDescriptorProviderImpl.caseClassTypeDescriptorWithUnknownImpl[T] - } diff --git a/scalding-macros/src/test/scala/com/twitter/scalding/macros/MacrosUnitTests.scala 
b/scalding-macros/src/test/scala/com/twitter/scalding/macros/MacrosUnitTests.scala index 7fdb83c9ae..aad48410ce 100644 --- a/scalding-macros/src/test/scala/com/twitter/scalding/macros/MacrosUnitTests.scala +++ b/scalding-macros/src/test/scala/com/twitter/scalding/macros/MacrosUnitTests.scala @@ -18,7 +18,7 @@ package com.twitter.scalding.macros import cascading.tuple.{ Tuple => CTuple, TupleEntry } import org.scalatest.{ Matchers, WordSpec } - +import scala.language.experimental.{ macros => smacros } import com.twitter.scalding._ import com.twitter.scalding.macros._ import com.twitter.scalding.macros.impl._ diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/TypedParquet.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/TypedParquet.scala new file mode 100644 index 0000000000..860c3cb71d --- /dev/null +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/TypedParquet.scala @@ -0,0 +1,112 @@ +package com.twitter.scalding.parquet.tuple + +import _root_.parquet.filter2.predicate.FilterPredicate +import cascading.scheme.Scheme +import com.twitter.scalding._ +import com.twitter.scalding.parquet.HasFilterPredicate +import com.twitter.scalding.parquet.tuple.scheme.{ ParquetWriteSupport, ParquetReadSupport, TypedParquetTupleScheme } + +import scala.reflect.ClassTag + +/** + * Typed parquet tuple + * @author Jian Tang + */ +object TypedParquet { + /** + * Create readable typed parquet source. + * Here is an example: + * + * case class SampleClassB(string: String, int: Int, double: Option[Double], a: SampleClassA) + * + * class ReadSupport extends ParquetReadSupport[SampleClassB] { + * import com.twitter.scalding.parquet.tuple.macros.Macros._ + * override val tupleConverter: ParquetTupleConverter[SampleClassB] = caseClassParquetTupleConverter[SampleClassB] + * override val rootSchema: String = caseClassParquetSchema[SampleClassB] + * } + * + * val parquetTuple = TypedParquet[SampleClassB, ReadSupport](Seq(outputPath)) + * + * @param paths paths of parquet I/O + * @param t Read support type tag + * @tparam T Tuple type + * @tparam R Read support type + * @return a typed parquet source. + */ + def apply[T, R <: ParquetReadSupport[T]](paths: Seq[String])(implicit t: ClassTag[R]) = + new TypedFixedPathParquetTuple[T, R, ParquetWriteSupport[T]](paths, t.runtimeClass.asInstanceOf[Class[R]], null) + + /** + * Create readable typed parquet source with filter predicate. + */ + def apply[T, R <: ParquetReadSupport[T]](paths: Seq[String], fp: Option[FilterPredicate])(implicit t: ClassTag[R]) = + new TypedFixedPathParquetTuple[T, R, ParquetWriteSupport[T]](paths, t.runtimeClass.asInstanceOf[Class[R]], null) { + override def withFilter = fp + } + + /** + * Create typed parquet source supports both R/W. + * @param paths paths of parquet I/O + * @param r Read support type tag + * @param w Write support type tag + * @tparam T Tuple type + * @tparam R Read support type + * @return a typed parquet source. + */ + def apply[T, R <: ParquetReadSupport[T], W <: ParquetWriteSupport[T]](paths: Seq[String])(implicit r: ClassTag[R], + w: ClassTag[W]) = { + val readSupport = r.runtimeClass.asInstanceOf[Class[R]] + val writeSupport = w.runtimeClass.asInstanceOf[Class[W]] + new TypedFixedPathParquetTuple[T, R, W](paths, readSupport, writeSupport) + } + +} + +object TypedParquetSink { + /** + * Create typed parquet sink. 
+ * Here is an example: + * + * case class SampleClassB(string: String, int: Int, double: Option[Double], a: SampleClassA) + * + * class WriteSupport extends ParquetWriteSupport[SampleClassB] { + * import com.twitter.scalding.parquet.tuple.macros.Macros._ + * + * override def writeRecord(r: SampleClassB, rc: RecordConsumer, schema: MessageType): Unit = + * caseClassWriteSupport[SampleClassB](r, rc, schema) + * override val rootSchema: String = caseClassParquetSchema[SampleClassB] + * } + * + * val sink = TypedParquetSink[SampleClassB, WriteSupport](Seq(outputPath)) + * + * @param paths paths of parquet I/O + * @param t Read support type tag + * @tparam T Tuple type + * @tparam W Write support type + * @return a typed parquet source. + */ + def apply[T, W <: ParquetWriteSupport[T]](paths: Seq[String])(implicit t: ClassTag[W]) = + new TypedFixedPathParquetTuple[T, ParquetReadSupport[T], W](paths, null, t.runtimeClass.asInstanceOf[Class[W]]) +} + +/** + * Typed Parquet tuple source/sink. + */ +trait TypedParquet[T, R <: ParquetReadSupport[T], W <: ParquetWriteSupport[T]] extends FileSource with Mappable[T] + with TypedSink[T] with HasFilterPredicate { + + val readSupport: Class[R] + val writeSupport: Class[W] + + override def converter[U >: T] = TupleConverter.asSuperConverter[T, U](TupleConverter.singleConverter[T]) + + override def setter[U <: T] = TupleSetter.asSubSetter[T, U](TupleSetter.singleSetter[T]) + + override def hdfsScheme = { + val scheme = new TypedParquetTupleScheme[T](readSupport, writeSupport, withFilter) + HadoopSchemeInstance(scheme.asInstanceOf[Scheme[_, _, _, _, _]]) + } +} + +class TypedFixedPathParquetTuple[T, R <: ParquetReadSupport[T], W <: ParquetWriteSupport[T]](val paths: Seq[String], + val readSupport: Class[R], val writeSupport: Class[W]) extends FixedPathSource(paths: _*) with TypedParquet[T, R, W] diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/Macros.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/Macros.scala new file mode 100644 index 0000000000..295fd5ccaa --- /dev/null +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/Macros.scala @@ -0,0 +1,54 @@ +package com.twitter.scalding.parquet.tuple.macros + +import com.twitter.scalding.parquet.tuple.macros.impl.{ ParquetSchemaProvider, ParquetTupleConverterProvider, WriteSupportProvider } +import com.twitter.scalding.parquet.tuple.scheme.ParquetTupleConverter +import parquet.io.api.RecordConsumer +import parquet.schema.MessageType + +import scala.language.experimental.macros + +/** + * Macros used to generate parquet tuple read/write support. + * These macros support only case class that contains primitive fields or nested case classes and also collection fields + * like scala List, Set, and Map. + * @author Jian TANG + */ +object Macros { + /** + * Macro used to generate parquet schema for a given case class. For example if we have: + * + * case class SampleClassA(x: Int, y: String) + * case class SampleClassB(a: SampleClassA, y: String) + * + * The macro will generate a parquet message type like this: + * + * """ + * message SampleClassB { + * required group a { + * required int32 x; + * required binary y; + * } + * required binary y; + * } + * """ + * + * @tparam T Case class type that contains primitive fields or collection fields or nested case class. 
+ * @return Generated case class parquet message type string + */ + def caseClassParquetSchema[T]: String = macro ParquetSchemaProvider.toParquetSchemaImpl[T] + + /** + * Macro used to generate parquet tuple converter for a given case class. + * + * @tparam T Case class type that contains primitive or collection type fields or nested case class. + * @return Generated parquet converter + */ + def caseClassParquetTupleConverter[T]: ParquetTupleConverter[T] = macro ParquetTupleConverterProvider.toParquetTupleConverterImpl[T] + + /** + * Macro used to generate case class write support to parquet. + * @tparam T User defined case class tuple type. + * @return Generated case class tuple write support function. + */ + def caseClassWriteSupport[T]: (T, RecordConsumer, MessageType) => Unit = macro WriteSupportProvider.toWriteSupportImpl[T] +} diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetSchemaProvider.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetSchemaProvider.scala new file mode 100644 index 0000000000..012bd36d22 --- /dev/null +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetSchemaProvider.scala @@ -0,0 +1,81 @@ +package com.twitter.scalding.parquet.tuple.macros.impl + +import com.twitter.bijection.macros.impl.IsCaseClassImpl + +import scala.reflect.macros.Context + +object ParquetSchemaProvider { + def toParquetSchemaImpl[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[String] = { + import c.universe._ + + if (!IsCaseClassImpl.isCaseClassType(c)(T.tpe)) + c.abort(c.enclosingPosition, s"""We cannot enforce ${T.tpe} is a case class, either it is not a case class or this macro call is possibly enclosed in a class. 
+ This will mean the macro is operating on a non-resolved type.""") + + def matchField(fieldType: Type, fieldName: String, isOption: Boolean): Tree = { + val REPETITION_REQUIRED = q"_root_.parquet.schema.Type.Repetition.REQUIRED" + val REPETITION_OPTIONAL = q"_root_.parquet.schema.Type.Repetition.OPTIONAL" + val REPETITION_REPEATED = q"_root_.parquet.schema.Type.Repetition.REPEATED" + + def repetition: Tree = if (isOption) REPETITION_OPTIONAL else REPETITION_REQUIRED + + def createPrimitiveTypeField(primitiveType: Tree): Tree = + q"""new _root_.parquet.schema.PrimitiveType($repetition, $primitiveType, $fieldName)""" + + def createListGroupType(innerFieldsType: Tree): Tree = + q"""new _root_.parquet.schema.GroupType($REPETITION_REPEATED, "list", $innerFieldsType)""" + + fieldType match { + case tpe if tpe =:= typeOf[String] => + createPrimitiveTypeField(q"_root_.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY") + case tpe if tpe =:= typeOf[Boolean] => + createPrimitiveTypeField(q"_root_.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN") + case tpe if tpe =:= typeOf[Short] || tpe =:= typeOf[Int] || tpe =:= typeOf[Byte] => + createPrimitiveTypeField(q"_root_.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32") + case tpe if tpe =:= typeOf[Long] => + createPrimitiveTypeField(q"_root_.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64") + case tpe if tpe =:= typeOf[Float] => + createPrimitiveTypeField(q"_root_.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT") + case tpe if tpe =:= typeOf[Double] => + createPrimitiveTypeField(q"_root_.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE") + case tpe if tpe.erasure =:= typeOf[Option[Any]] => + val innerType = tpe.asInstanceOf[TypeRefApi].args.head + matchField(innerType, fieldName, isOption = true) + case tpe if tpe.erasure =:= typeOf[List[Any]] || tpe.erasure =:= typeOf[Set[_]] => + val innerType = tpe.asInstanceOf[TypeRefApi].args.head + val innerFieldsType = matchField(innerType, "element", isOption = false) + q"_root_.parquet.schema.ConversionPatterns.listType($repetition, $fieldName, ${createListGroupType(innerFieldsType)})" + case tpe if tpe.erasure =:= typeOf[Map[_, Any]] => + val List(keyType, valueType) = tpe.asInstanceOf[TypeRefApi].args + val keyFieldType = matchField(keyType, "key", isOption = false) + val valueFieldType = matchField(valueType, "value", isOption = false) + q"_root_.parquet.schema.ConversionPatterns.mapType($repetition, $fieldName, $keyFieldType, $valueFieldType)" + case tpe if IsCaseClassImpl.isCaseClassType(c)(tpe) => + q"new _root_.parquet.schema.GroupType($repetition, $fieldName, ..${expandMethod(tpe)})" + case _ => c.abort(c.enclosingPosition, s"Case class $T has unsupported field type : $fieldType ") + } + } + + def expandMethod(outerTpe: Type): List[Tree] = { + outerTpe + .declarations + .collect { case m: MethodSymbol if m.isCaseAccessor => m } + .map { accessorMethod => + val fieldName = accessorMethod.name.toTermName.toString + val fieldType = accessorMethod.returnType + matchField(fieldType, fieldName, isOption = false) + }.toList + } + + val expanded = expandMethod(T.tpe) + + if (expanded.isEmpty) + c.abort(c.enclosingPosition, s"Case class $T.tpe has no fields we were able to extract") + + val messageTypeName = s"${T.tpe}".split("\\.").last + val schema = q"""new _root_.parquet.schema.MessageType($messageTypeName, + _root_.scala.Array.apply[_root_.parquet.schema.Type](..$expanded):_*).toString""" + + c.Expr[String](schema) + } +} diff --git 
a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetTupleConverterProvider.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetTupleConverterProvider.scala new file mode 100644 index 0000000000..aed795ae67 --- /dev/null +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetTupleConverterProvider.scala @@ -0,0 +1,199 @@ +package com.twitter.scalding.parquet.tuple.macros.impl + +import com.twitter.bijection.macros.impl.IsCaseClassImpl +import com.twitter.scalding.parquet.tuple.scheme._ + +import scala.reflect.macros.Context + +object ParquetTupleConverterProvider { + private[this] sealed trait CollectionType + private[this] case object NOT_A_COLLECTION extends CollectionType + private[this] case object OPTION extends CollectionType + private[this] case object LIST extends CollectionType + private[this] case object SET extends CollectionType + private[this] case object MAP extends CollectionType + + def toParquetTupleConverterImpl[T](ctx: Context)(implicit T: ctx.WeakTypeTag[T]): ctx.Expr[ParquetTupleConverter[T]] = { + import ctx.universe._ + + if (!IsCaseClassImpl.isCaseClassType(ctx)(T.tpe)) + ctx.abort(ctx.enclosingPosition, + s"""We cannot enforce ${T.tpe} is a case class, + either it is not a case class or this macro call is possibly enclosed in a class. + This will mean the macro is operating on a non-resolved type.""") + + def buildGroupConverter(tpe: Type, converters: List[Tree], converterGetters: List[Tree], + converterResetCalls: List[Tree], valueBuilder: Tree): Tree = + q"""new _root_.com.twitter.scalding.parquet.tuple.scheme.ParquetTupleConverter[$tpe]{ + ..$converters + + override def currentValue: $tpe = $valueBuilder + + override def getConverter(i: Int): _root_.parquet.io.api.Converter = { + ..$converterGetters + throw new RuntimeException("invalid index: " + i) + } + + override def reset(): Unit = { + ..$converterResetCalls + } + + }""" + + def matchField(idx: Int, fieldType: Type, collectionType: CollectionType): (Tree, Tree, Tree, Tree) = { + def fieldConverter(converterName: TermName, converter: Tree, isPrimitive: Boolean = false): Tree = { + def primitiveCollectionElementConverter: Tree = + q"""override val child: _root_.com.twitter.scalding.parquet.tuple.scheme.TupleFieldConverter[$fieldType] = + new _root_.com.twitter.scalding.parquet.tuple.scheme.CollectionElementPrimitiveConverter[$fieldType](this) { + override val delegate: _root_.com.twitter.scalding.parquet.tuple.scheme.PrimitiveFieldConverter[$fieldType] = $converter + } + """ + + def caseClassFieldCollectionElementConverter: Tree = + q"""override val child: _root_.com.twitter.scalding.parquet.tuple.scheme.TupleFieldConverter[$fieldType] = + new _root_.com.twitter.scalding.parquet.tuple.scheme.CollectionElementGroupConverter[$fieldType](this) { + override val delegate: _root_.com.twitter.scalding.parquet.tuple.scheme.TupleFieldConverter[$fieldType] = $converter + } + """ + + collectionType match { + case OPTION => + val child = if (isPrimitive) primitiveCollectionElementConverter else caseClassFieldCollectionElementConverter + q""" + val $converterName = new _root_.com.twitter.scalding.parquet.tuple.scheme.OptionConverter[$fieldType] { + $child + } + """ + case LIST => + val child = if (isPrimitive) primitiveCollectionElementConverter else caseClassFieldCollectionElementConverter + q""" + val $converterName = new _root_.com.twitter.scalding.parquet.tuple.scheme.ListConverter[$fieldType] { + $child + 
} + """ + case SET => + val child = if (isPrimitive) primitiveCollectionElementConverter else caseClassFieldCollectionElementConverter + + q""" + val $converterName = new _root_.com.twitter.scalding.parquet.tuple.scheme.SetConverter[$fieldType] { + $child + } + """ + case MAP => converter + case _ => q"val $converterName = $converter" + } + + } + + def createMapFieldConverter(converterName: TermName, K: Type, V: Type, keyConverter: Tree, + valueConverter: Tree): Tree = + q"""val $converterName = new _root_.com.twitter.scalding.parquet.tuple.scheme.MapConverter[$K, $V] { + + override val child: _root_.com.twitter.scalding.parquet.tuple.scheme.TupleFieldConverter[($K, $V)] = + new _root_.com.twitter.scalding.parquet.tuple.scheme.MapKeyValueConverter[$K, $V](this) { + override val keyConverter: _root_.com.twitter.scalding.parquet.tuple.scheme.TupleFieldConverter[$K] = $keyConverter + override val valueConverter: _root_.com.twitter.scalding.parquet.tuple.scheme.TupleFieldConverter[$V] = $valueConverter + } + } + """ + + def createFieldMatchResult(converterName: TermName, converter: Tree): (Tree, Tree, Tree, Tree) = { + val converterGetter: Tree = q"if($idx == i) return $converterName" + val converterResetCall: Tree = q"$converterName.reset()" + val converterFieldValue: Tree = q"$converterName.currentValue" + (converter, converterGetter, converterResetCall, converterFieldValue) + } + + def matchPrimitiveField(converterType: Type): (Tree, Tree, Tree, Tree) = { + val converterName = newTermName(ctx.fresh(s"fieldConverter")) + val innerConverter: Tree = q"new $converterType()" + val converter: Tree = fieldConverter(converterName, innerConverter, isPrimitive = true) + createFieldMatchResult(converterName, converter) + } + + def matchCaseClassField(groupConverter: Tree): (Tree, Tree, Tree, Tree) = { + val converterName = newTermName(ctx.fresh(s"fieldConverter")) + val converter: Tree = fieldConverter(converterName, groupConverter) + createFieldMatchResult(converterName, converter) + } + + def matchMapField(K: Type, V: Type, keyConverter: Tree, valueConverter: Tree): (Tree, Tree, Tree, Tree) = { + val converterName = newTermName(ctx.fresh(s"fieldConverter")) + val mapConverter = createMapFieldConverter(converterName, K, V, keyConverter, valueConverter) + createFieldMatchResult(converterName, mapConverter) + } + + fieldType match { + case tpe if tpe =:= typeOf[String] => + matchPrimitiveField(typeOf[StringConverter]) + case tpe if tpe =:= typeOf[Boolean] => + matchPrimitiveField(typeOf[BooleanConverter]) + case tpe if tpe =:= typeOf[Byte] => + matchPrimitiveField(typeOf[ByteConverter]) + case tpe if tpe =:= typeOf[Short] => + matchPrimitiveField(typeOf[ShortConverter]) + case tpe if tpe =:= typeOf[Int] => + matchPrimitiveField(typeOf[IntConverter]) + case tpe if tpe =:= typeOf[Long] => + matchPrimitiveField(typeOf[LongConverter]) + case tpe if tpe =:= typeOf[Float] => + matchPrimitiveField(typeOf[FloatConverter]) + case tpe if tpe =:= typeOf[Double] => + matchPrimitiveField(typeOf[DoubleConverter]) + case tpe if tpe.erasure =:= typeOf[Option[Any]] => + val innerType = tpe.asInstanceOf[TypeRefApi].args.head + matchField(idx, innerType, OPTION) + case tpe if tpe.erasure =:= typeOf[List[Any]] => + val innerType = tpe.asInstanceOf[TypeRefApi].args.head + matchField(idx, innerType, LIST) + case tpe if tpe.erasure =:= typeOf[Set[_]] => + val innerType = tpe.asInstanceOf[TypeRefApi].args.head + matchField(idx, innerType, SET) + case tpe if tpe.erasure =:= typeOf[Map[_, Any]] => + val List(keyType, valueType) 
= tpe.asInstanceOf[TypeRefApi].args + val (keyConverter, _, _, _) = matchField(0, keyType, MAP) + val (valueConverter, _, _, _) = matchField(0, valueType, MAP) + matchMapField(keyType, valueType, keyConverter, valueConverter) + case tpe if IsCaseClassImpl.isCaseClassType(ctx)(tpe) => + val (innerConverters, innerConvertersGetters, innerConvertersResetCalls, innerFieldValues) = unzip(expandMethod(tpe)) + val innerValueBuilderTree = buildTupleValue(tpe, innerFieldValues) + val converterTree: Tree = buildGroupConverter(tpe, innerConverters, innerConvertersGetters, + innerConvertersResetCalls, innerValueBuilderTree) + matchCaseClassField(converterTree) + case _ => ctx.abort(ctx.enclosingPosition, s"Case class $T has unsupported field type : $fieldType ") + } + } + + def expandMethod(outerTpe: Type): List[(Tree, Tree, Tree, Tree)] = + outerTpe + .declarations + .collect { case m: MethodSymbol if m.isCaseAccessor => m } + .zipWithIndex + .map { + case (accessorMethod, idx) => + val fieldType = accessorMethod.returnType + matchField(idx, fieldType, NOT_A_COLLECTION) + }.toList + + def unzip(treeTuples: List[(Tree, Tree, Tree, Tree)]): (List[Tree], List[Tree], List[Tree], List[Tree]) = { + val emptyTreeList = List[Tree]() + treeTuples.foldRight(emptyTreeList, emptyTreeList, emptyTreeList, emptyTreeList) { + case ((t1, t2, t3, t4), (l1, l2, l3, l4)) => + (t1 :: l1, t2 :: l2, t3 :: l3, t4 :: l4) + } + } + + def buildTupleValue(tpe: Type, fieldValueBuilders: List[Tree]): Tree = { + if (fieldValueBuilders.isEmpty) + ctx.abort(ctx.enclosingPosition, s"Case class $tpe has no primitive types we were able to extract") + val companion = tpe.typeSymbol.companionSymbol + q"$companion(..$fieldValueBuilders)" + } + + val (converters, converterGetters, convertersResetCalls, fieldValues) = unzip(expandMethod(T.tpe)) + val groupConverter = buildGroupConverter(T.tpe, converters, converterGetters, convertersResetCalls, + buildTupleValue(T.tpe, fieldValues)) + + ctx.Expr[ParquetTupleConverter[T]](groupConverter) + } +} diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/WriteSupportProvider.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/WriteSupportProvider.scala new file mode 100644 index 0000000000..a1fadfcad3 --- /dev/null +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/WriteSupportProvider.scala @@ -0,0 +1,139 @@ +package com.twitter.scalding.parquet.tuple.macros.impl + +import com.twitter.bijection.macros.impl.IsCaseClassImpl +import parquet.io.api.RecordConsumer +import parquet.schema.MessageType + +import scala.reflect.macros.Context + +object WriteSupportProvider { + + def toWriteSupportImpl[T](ctx: Context)(implicit T: ctx.WeakTypeTag[T]): ctx.Expr[(T, RecordConsumer, MessageType) => Unit] = { + import ctx.universe._ + + if (!IsCaseClassImpl.isCaseClassType(ctx)(T.tpe)) + ctx.abort(ctx.enclosingPosition, + s"""We cannot enforce ${T.tpe} is a case class, + either it is not a case class or this macro call is possibly enclosed in a class. 
+ This will mean the macro is operating on a non-resolved type.""") + + def matchField(idx: Int, fieldType: Type, fValue: Tree, groupName: TermName): (Int, Tree) = { + def writePrimitiveField(wTree: Tree) = + (idx + 1, q"""rc.startField($groupName.getFieldName($idx), $idx) + $wTree + rc.endField($groupName.getFieldName($idx), $idx)""") + + def writeGroupField(subTree: Tree) = + q"""rc.startField($groupName.getFieldName($idx), $idx) + rc.startGroup() + $subTree + rc.endGroup() + rc.endField($groupName.getFieldName($idx), $idx) + """ + def writeCollectionField(elementGroupName: TermName, subTree: Tree) = + writeGroupField(q"""if(!$fValue.isEmpty) { + val $elementGroupName = $groupName.getType($idx).asGroupType.getType(0).asGroupType + $subTree + } + """) + + fieldType match { + case tpe if tpe =:= typeOf[String] => + writePrimitiveField(q"rc.addBinary(Binary.fromString($fValue))") + case tpe if tpe =:= typeOf[Boolean] => + writePrimitiveField(q"rc.addBoolean($fValue)") + case tpe if tpe =:= typeOf[Short] => + writePrimitiveField(q"rc.addInteger($fValue.toInt)") + case tpe if tpe =:= typeOf[Int] => + writePrimitiveField(q"rc.addInteger($fValue)") + case tpe if tpe =:= typeOf[Long] => + writePrimitiveField(q"rc.addLong($fValue)") + case tpe if tpe =:= typeOf[Float] => + writePrimitiveField(q"rc.addFloat($fValue)") + case tpe if tpe =:= typeOf[Double] => + writePrimitiveField(q"rc.addDouble($fValue)") + case tpe if tpe =:= typeOf[Byte] => + writePrimitiveField(q"rc.addInteger($fValue.toInt)") + case tpe if tpe.erasure =:= typeOf[Option[Any]] => + val cacheName = newTermName(ctx.fresh(s"optionIndex")) + val innerType = tpe.asInstanceOf[TypeRefApi].args.head + val (_, subTree) = matchField(idx, innerType, q"$cacheName", groupName) + (idx + 1, q"""if($fValue.isDefined) { + val $cacheName = $fValue.get + $subTree + } + """) + case tpe if tpe.erasure =:= typeOf[List[Any]] || tpe.erasure =:= typeOf[Set[_]] => + val innerType = tpe.asInstanceOf[TypeRefApi].args.head + val newGroupName = createGroupName() + val (_, subTree) = matchField(0, innerType, q"element", newGroupName) + (idx + 1, writeCollectionField(newGroupName, q""" + rc.startField("list", 0) + $fValue.foreach{ element => + rc.startGroup() + $subTree + rc.endGroup + } + rc.endField("list", 0)""")) + case tpe if tpe.erasure =:= typeOf[Map[_, Any]] => + val List(keyType, valueType) = tpe.asInstanceOf[TypeRefApi].args + val newGroupName = createGroupName() + val (_, keySubTree) = matchField(0, keyType, q"key", newGroupName) + val (_, valueSubTree) = matchField(1, valueType, q"value", newGroupName) + (idx + 1, writeCollectionField(newGroupName, q""" + rc.startField("map", 0) + $fValue.foreach{ case(key, value) => + rc.startGroup() + $keySubTree + $valueSubTree + rc.endGroup + } + rc.endField("map", 0)""")) + case tpe if IsCaseClassImpl.isCaseClassType(ctx)(tpe) => + val newGroupName = createGroupName() + val (_, subTree) = expandMethod(tpe, fValue, newGroupName) + (idx + 1, + q""" + val $newGroupName = $groupName.getType($idx).asGroupType() + ${writeGroupField(subTree)}""") + + case _ => ctx.abort(ctx.enclosingPosition, s"Case class $T has unsupported field type : $fieldType") + } + } + + def expandMethod(outerTpe: Type, pValueTree: Tree, groupName: TermName): (Int, Tree) = { + outerTpe + .declarations + .collect { case m: MethodSymbol if m.isCaseAccessor => m } + .foldLeft((0, q"")) { + case ((idx, existingTree), getter) => + val (newIdx, subTree) = matchField(idx, getter.returnType, q"$pValueTree.$getter", groupName) + (newIdx, q""" + 
$existingTree + $subTree + """) + } + } + + def createGroupName(): TermName = newTermName(ctx.fresh("group")) + + val rootGroupName = createGroupName() + + val (finalIdx, funcBody) = expandMethod(T.tpe, q"t", rootGroupName) + + if (finalIdx == 0) + ctx.abort(ctx.enclosingPosition, "Didn't consume any elements in the tuple, possibly empty case class?") + + val writeFunction: Tree = q""" + val writeFunc = (t: $T, rc: _root_.parquet.io.api.RecordConsumer, schema: _root_.parquet.schema.MessageType) => { + + var $rootGroupName: _root_.parquet.schema.GroupType = schema + rc.startMessage + $funcBody + rc.endMessage + } + writeFunc + """ + ctx.Expr[(T, RecordConsumer, MessageType) => Unit](writeFunction) + } +} diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/ParquetTupleConverter.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/ParquetTupleConverter.scala new file mode 100644 index 0000000000..a407eb5d36 --- /dev/null +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/ParquetTupleConverter.scala @@ -0,0 +1,330 @@ +package com.twitter.scalding.parquet.tuple.scheme + +import parquet.io.api.{ Binary, Converter, GroupConverter, PrimitiveConverter } +import scala.util.Try + +trait TupleFieldConverter[+T] extends Converter { + /** + * Current value read from parquet column + */ + def currentValue: T + + /** + * reset the converter state, make it ready for reading next column value. + */ + def reset(): Unit +} + +/** + * Parquet tuple converter used to create user defined tuple value from parquet column values + */ +abstract class ParquetTupleConverter[T] extends GroupConverter with TupleFieldConverter[T] { + override def start(): Unit = reset() + override def end(): Unit = () +} + +/** + * Primitive fields converter + * @tparam T primitive types (String, Double, Float, Long, Int, Short, Byte, Boolean) + */ +trait PrimitiveFieldConverter[T] extends PrimitiveConverter with TupleFieldConverter[T] { + val defaultValue: T + var value: T = defaultValue + + override def currentValue: T = value + + override def reset(): Unit = value = defaultValue +} + +class StringConverter extends PrimitiveFieldConverter[String] { + override val defaultValue: String = null + + override def addBinary(binary: Binary): Unit = value = binary.toStringUsingUTF8 +} + +class DoubleConverter extends PrimitiveFieldConverter[Double] { + override val defaultValue: Double = 0D + + override def addDouble(v: Double): Unit = value = v +} + +class FloatConverter extends PrimitiveFieldConverter[Float] { + override val defaultValue: Float = 0F + + override def addFloat(v: Float): Unit = value = v +} + +class LongConverter extends PrimitiveFieldConverter[Long] { + override val defaultValue: Long = 0L + + override def addLong(v: Long): Unit = value = v +} + +class IntConverter extends PrimitiveFieldConverter[Int] { + override val defaultValue: Int = 0 + + override def addInt(v: Int): Unit = value = v +} + +class ShortConverter extends PrimitiveFieldConverter[Short] { + override val defaultValue: Short = 0 + + override def addInt(v: Int): Unit = value = Try(v.toShort).getOrElse(0) +} + +class ByteConverter extends PrimitiveFieldConverter[Byte] { + override val defaultValue: Byte = 0 + + override def addInt(v: Int): Unit = value = Try(v.toByte).getOrElse(0) +} + +class BooleanConverter extends PrimitiveFieldConverter[Boolean] { + override val defaultValue: Boolean = false + + override def addBoolean(v: Boolean): Unit = value = v +} + +/** + * 
Collection field converter, such as list(Scala Option is also seen as a collection). + * @tparam T collection element type(can be primitive types or nested types) + */ +trait CollectionConverter[T] { + val child: TupleFieldConverter[T] + + def appendValue(v: T): Unit +} + +/** + * A wrapper of primitive converters for modeling primitive fields in a collection + * @tparam T primitive types (String, Double, Float, Long, Int, Short, Byte, Boolean) + */ +abstract class CollectionElementPrimitiveConverter[T](val parent: CollectionConverter[T]) extends PrimitiveConverter + with TupleFieldConverter[T] { + val delegate: PrimitiveFieldConverter[T] + + override def addBinary(v: Binary) = { + delegate.addBinary(v) + parent.appendValue(delegate.currentValue) + } + + override def addBoolean(v: Boolean) = { + delegate.addBoolean(v) + parent.appendValue(delegate.currentValue) + } + + override def addDouble(v: Double) = { + delegate.addDouble(v) + parent.appendValue(delegate.currentValue) + } + + override def addFloat(v: Float) = { + delegate.addFloat(v) + parent.appendValue(delegate.currentValue) + } + + override def addInt(v: Int) = { + delegate.addInt(v) + parent.appendValue(delegate.currentValue) + } + + override def addLong(v: Long) = { + delegate.addLong(v) + parent.appendValue(delegate.currentValue) + } + + override def currentValue: T = delegate.currentValue + + override def reset(): Unit = delegate.reset() +} + +/** + * A wrapper of group converters for modeling group type element in a collection + * @tparam T group tuple type(can be a collection type, such as list) + */ +abstract class CollectionElementGroupConverter[T](val parent: CollectionConverter[T]) extends GroupConverter + with TupleFieldConverter[T] { + + val delegate: TupleFieldConverter[T] + + override def getConverter(i: Int): Converter = delegate.asGroupConverter().getConverter(i) + + override def end(): Unit = { + parent.appendValue(delegate.currentValue) + delegate.asGroupConverter().end() + } + + override def start(): Unit = delegate.asGroupConverter().start() + + override def currentValue: T = delegate.currentValue + + override def reset(): Unit = delegate.reset() +} + +/** + * Option converter for modeling option field + * @tparam T option element type(can be primitive types or nested types) + */ +abstract class OptionConverter[T] extends TupleFieldConverter[Option[T]] with CollectionConverter[T] { + var value: Option[T] = None + + override def appendValue(v: T): Unit = value = Option(v) + + override def currentValue: Option[T] = value + + override def reset(): Unit = { + value = None + child.reset() + } + + override def isPrimitive: Boolean = child.isPrimitive + + override def asGroupConverter: GroupConverter = child.asGroupConverter() + + override def asPrimitiveConverter: PrimitiveConverter = child.asPrimitiveConverter() +} + +/** + * List in parquet is represented by 3-level structure. + * Check this https://github.com/apache/incubator-parquet-format/blob/master/LogicalTypes.md + * Helper class to wrap a converter for a list group converter + */ +object ListElement { + def wrapper(child: Converter): GroupConverter = new GroupConverter() { + override def getConverter(i: Int): Converter = { + if (i != 0) + throw new IllegalArgumentException("list have only one element field. 
can't reach " + i) + child + } + + override def end(): Unit = () + + override def start(): Unit = () + } +} +/** + * List converter for modeling list field + * @tparam T list element type(can be primitive types or nested types) + */ +abstract class ListConverter[T] extends GroupConverter with TupleFieldConverter[List[T]] with CollectionConverter[T] { + + var value: List[T] = Nil + + def appendValue(v: T): Unit = value = value :+ v + + lazy val listElement: GroupConverter = new GroupConverter() { + override def getConverter(i: Int): Converter = { + if (i != 0) + throw new IllegalArgumentException("lists have only one element field. can't reach " + i) + child + } + + override def end(): Unit = () + + override def start(): Unit = () + } + + override def getConverter(i: Int): Converter = { + if (i != 0) + throw new IllegalArgumentException("lists have only one element field. can't reach " + i) + listElement + } + + override def end(): Unit = () + + override def start(): Unit = reset() + + override def currentValue: List[T] = value + + override def reset(): Unit = { + value = Nil + child.reset() + } +} + +/** + * Set converter for modeling set field + * @tparam T list element type(can be primitive types or nested types) + */ +abstract class SetConverter[T] extends GroupConverter with TupleFieldConverter[Set[T]] with CollectionConverter[T] { + + var value: Set[T] = Set() + + def appendValue(v: T): Unit = value = value + v + + //in the back end set is stored as list + lazy val listElement: GroupConverter = ListElement.wrapper(child) + + override def getConverter(i: Int): Converter = { + if (i != 0) + throw new IllegalArgumentException("sets have only one element field. can't reach " + i) + listElement + } + + override def end(): Unit = () + + override def start(): Unit = reset() + + override def currentValue: Set[T] = value + + override def reset(): Unit = { + value = Set() + child.reset() + } +} + +/** + * Map converter for modeling map field + * @tparam K map key type + * @tparam V map value type + */ +abstract class MapConverter[K, V] extends GroupConverter with TupleFieldConverter[Map[K, V]] with CollectionConverter[(K, V)] { + + var value: Map[K, V] = Map() + + def appendValue(v: (K, V)): Unit = value = value + v + + override def getConverter(i: Int): Converter = { + if (i != 0) + throw new IllegalArgumentException("maps have only one element type key_value(0). 
can't reach " + i) + child + } + + override def end(): Unit = () + + override def start(): Unit = reset() + + override def currentValue: Map[K, V] = value + + override def reset(): Unit = { + value = Map() + child.reset() + } +} + +abstract class MapKeyValueConverter[K, V](parent: CollectionConverter[(K, V)]) + extends CollectionElementGroupConverter[(K, V)](parent) { + + val keyConverter: TupleFieldConverter[K] + + val valueConverter: TupleFieldConverter[V] + + override lazy val delegate: TupleFieldConverter[(K, V)] = new GroupConverter with TupleFieldConverter[(K, V)] { + override def currentValue: (K, V) = (keyConverter.currentValue, valueConverter.currentValue) + + override def reset(): Unit = { + keyConverter.reset() + valueConverter.reset() + } + + override def getConverter(i: Int): Converter = { + if (i == 0) keyConverter + else if (i == 1) valueConverter + else throw new IllegalArgumentException("key_value has only the key (0) and value (1) fields expected: " + i) + } + + override def end(): Unit = () + + override def start(): Unit = reset() + } +} + diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/TypedParquetTupleScheme.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/TypedParquetTupleScheme.scala new file mode 100644 index 0000000000..be2d632398 --- /dev/null +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/TypedParquetTupleScheme.scala @@ -0,0 +1,148 @@ +package com.twitter.scalding.parquet.tuple.scheme + +import java.util.{ HashMap => JHashMap, Map => JMap } + +import _root_.parquet.filter2.predicate.FilterPredicate +import _root_.parquet.hadoop.api.ReadSupport.ReadContext +import _root_.parquet.hadoop.api.WriteSupport.WriteContext +import _root_.parquet.hadoop.api.{ ReadSupport, WriteSupport } +import _root_.parquet.hadoop.mapred.{ Container, DeprecatedParquetInputFormat, DeprecatedParquetOutputFormat } +import _root_.parquet.io.api._ +import cascading.flow.FlowProcess +import cascading.scheme.{ Scheme, SinkCall, SourceCall } +import cascading.tap.Tap +import cascading.tuple.Tuple +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapred.{ JobConf, OutputCollector, RecordReader } +import parquet.hadoop.{ ParquetInputFormat, ParquetOutputFormat } +import parquet.schema._ + +/** + * Parquet tuple materializer permits to create user defined type record from parquet tuple values + * @param converter root converter + * @tparam T User defined value type + */ +class ParquetTupleMaterializer[T](val converter: ParquetTupleConverter[T]) extends RecordMaterializer[T] { + override def getCurrentRecord: T = converter.currentValue + + override def getRootConverter: GroupConverter = converter +} + +/** + * Parquet read support used by [[parquet.hadoop.ParquetInputFormat]] to read values from parquet input. 
+ * User must define record schema and parquet tuple converter that permits to convert parquet tuple to user defined type + * For case class types, we provide a macro to generate the schema and the tuple converter so that user + * can define a ParquetReadSupport like this: + * + * case class SampleClass(bool: Boolean, long: Long, float: Float) + * + * class SampleClassReadSupport extends ParquetReadSupport[SampleClass] { + * import com.twitter.scalding.parquet.tuple.macros.Macros._ + * override val tupleConverter: ParquetTupleConverter = caseClassParquetTupleConverter[SampleClass] + * override val rootSchema: String = caseClassParquetSchema[SampleClass] + * } + * + * @tparam T user defined value type + */ +trait ParquetReadSupport[T] extends ReadSupport[T] { + val tupleConverter: ParquetTupleConverter[T] + val rootSchema: String + + lazy val rootType: MessageType = MessageTypeParser.parseMessageType(rootSchema) + + override def init(configuration: Configuration, map: JMap[String, String], messageType: MessageType): ReadContext = + new ReadContext(rootType) + + override def prepareForRead(configuration: Configuration, map: JMap[String, String], messageType: MessageType, + readContext: ReadContext): RecordMaterializer[T] = + new ParquetTupleMaterializer(tupleConverter) +} + +/** + * Parquet write support used by [[parquet.hadoop.ParquetOutputFormat]] to write values to parquet output. + * User must provide record schema and a function which permits to write a used defined case class to parquet store with + * the record consumer and schema definition. + * + * For case class value types, we provide a macro to generate the write support function so that user + * can define a ParquetWriteSupport like this: + * + * class SampleClassWriteSupport extends TupleWriteSupport[SampleClassB] { + * import com.twitter.scalding.parquet.tuple.macros.Macros._ + * + * override def writeRecord(r: SampleClassB, rc: RecordConsumer, schema: MessageType):Unit = + * Macros.caseClassWriteSupport[SampleClassB](r, rc, schema) + * + * override val rootSchema: String = caseClassParquetSchema[SampleClassB] + * } + * + * @tparam T user defined value type + */ +trait ParquetWriteSupport[T] extends WriteSupport[T] { + + var recordConsumer: RecordConsumer = null + + val rootSchema: String + + lazy val rootType: MessageType = MessageTypeParser.parseMessageType(rootSchema) + + override def init(configuration: Configuration): WriteContext = + new WriteSupport.WriteContext(rootType, new JHashMap[String, String]) + + override def prepareForWrite(rc: RecordConsumer): Unit = recordConsumer = rc + + override def write(record: T): Unit = writeRecord(record, recordConsumer, rootType) + + def writeRecord(r: T, rc: RecordConsumer, schema: MessageType): Unit +} + +/** + * Typed parquet tuple scheme. 
+ * @param readSupport read support class + * @param writeSupport write support class + * @param fp filter predicate + * @tparam T tuple value type + */ +class TypedParquetTupleScheme[T](val readSupport: Class[_], val writeSupport: Class[_], + val fp: Option[FilterPredicate] = None) + extends Scheme[JobConf, RecordReader[AnyRef, Container[T]], OutputCollector[AnyRef, T], Array[AnyRef], Array[AnyRef]] { + + type Output = OutputCollector[AnyRef, T] + type Reader = RecordReader[AnyRef, Container[T]] + type TapType = Tap[JobConf, Reader, Output] + type SourceCallType = SourceCall[Array[AnyRef], Reader] + type SinkCallType = SinkCall[Array[AnyRef], Output] + + override def sourceConfInit(flowProcess: FlowProcess[JobConf], tap: TapType, jobConf: JobConf): Unit = { + fp.map(ParquetInputFormat.setFilterPredicate(jobConf, _)) + jobConf.setInputFormat(classOf[DeprecatedParquetInputFormat[T]]) + ParquetInputFormat.setReadSupportClass(jobConf, readSupport) + } + + override def source(flowProcess: FlowProcess[JobConf], sc: SourceCallType): Boolean = { + val value: Container[T] = sc.getInput.createValue() + + val hasNext = sc.getInput.next(null, value) + + if (!hasNext) false + else if (value == null) true + else { + val tuple = new Tuple(value.get.asInstanceOf[AnyRef]) + sc.getIncomingEntry.setTuple(tuple) + true + } + } + + override def sinkConfInit(flowProcess: FlowProcess[JobConf], tap: TapType, jobConf: JobConf): Unit = { + jobConf.setOutputFormat(classOf[DeprecatedParquetOutputFormat[T]]) + ParquetOutputFormat.setWriteSupportClass(jobConf, writeSupport) + } + + override def sink(flowProcess: FlowProcess[JobConf], sinkCall: SinkCallType): Unit = { + val tuple = sinkCall.getOutgoingEntry + require(tuple.size == 1, + "TypedParquetTupleScheme expects tuple with an arity of exactly 1, but found " + tuple.getFields) + val value = tuple.getObject(0).asInstanceOf[T] + val outputCollector = sinkCall.getOutput + outputCollector.collect(null, value) + } +} diff --git a/scalding-parquet/src/test/scala/com/twitter/scalding/parquet/tuple/TypedParquetTupleTest.scala b/scalding-parquet/src/test/scala/com/twitter/scalding/parquet/tuple/TypedParquetTupleTest.scala new file mode 100644 index 0000000000..0b7483818f --- /dev/null +++ b/scalding-parquet/src/test/scala/com/twitter/scalding/parquet/tuple/TypedParquetTupleTest.scala @@ -0,0 +1,105 @@ +package com.twitter.scalding.parquet.tuple + +import com.twitter.scalding.parquet.tuple.macros.Macros +import com.twitter.scalding.parquet.tuple.scheme.{ ParquetTupleConverter, ParquetReadSupport, ParquetWriteSupport } +import com.twitter.scalding.platform.{ HadoopPlatformJobTest, HadoopPlatformTest } +import com.twitter.scalding.typed.TypedPipe +import com.twitter.scalding.{ Args, Job, TypedTsv } +import org.scalatest.{ Matchers, WordSpec } +import parquet.filter2.predicate.FilterApi.binaryColumn +import parquet.filter2.predicate.{ FilterApi, FilterPredicate } +import parquet.io.api.{ RecordConsumer, Binary } +import parquet.schema.MessageType + +class TypedParquetTupleTest extends WordSpec with Matchers with HadoopPlatformTest { + "TypedParquetTuple" should { + + "read and write correctly" in { + import com.twitter.scalding.parquet.tuple.TestValues._ + HadoopPlatformJobTest(new WriteToTypedParquetTupleJob(_), cluster) + .arg("output", "output1") + .sink[SampleClassB](TypedParquet[SampleClassB, BReadSupport](Seq("output1"))) { _.toSet shouldBe values.toSet } + .run + + HadoopPlatformJobTest(new ReadWithFilterPredicateJob(_), cluster) + .arg("input", "output1") + 
.arg("output", "output2") + .sink[Boolean]("output2") { _.toSet shouldBe values.filter(_.string == "B1").map(_.a.bool).toSet } + .run + + } + } +} + +object TestValues { + val values = Seq( + SampleClassB("B1", Some(4.0D), SampleClassA(bool = true, 5, 1L, 1.2F, 1), List(1, 2), + List(SampleClassD(1, "1"), SampleClassD(2, "2")), Set(1D, 2D), Set(SampleClassF(1, 1F)), Map(1 -> "foo")), + SampleClassB("B2", Some(3.0D), SampleClassA(bool = false, 4, 2L, 2.3F, 2), List(3, 4), Nil, Set(3, 4), Set(), + Map(2 -> "bar"), Map(SampleClassD(0, "z") -> SampleClassF(0, 3), SampleClassD(0, "y") -> SampleClassF(2, 6))), + SampleClassB("B3", None, SampleClassA(bool = true, 6, 3L, 3.4F, 3), List(5, 6), + List(SampleClassD(3, "3"), SampleClassD(4, "4")), Set(5, 6), Set(SampleClassF(2, 2F))), + SampleClassB("B4", Some(5.0D), SampleClassA(bool = false, 7, 4L, 4.5F, 4), Nil, + List(SampleClassD(5, "5"), SampleClassD(6, "6")), Set(), Set(SampleClassF(3, 3F), SampleClassF(5, 4F)), + Map(3 -> "foo2"), Map(SampleClassD(0, "q") -> SampleClassF(4, 3)))) +} + +case class SampleClassA(bool: Boolean, short: Short, long: Long, float: Float, byte: Byte) + +case class SampleClassB(string: String, double: Option[Double], a: SampleClassA, intList: List[Int], + dList: List[SampleClassD], doubleSet: Set[Double], fSet: Set[SampleClassF], intStringMap: Map[Int, String] = Map(), + dfMap: Map[SampleClassD, SampleClassF] = Map()) + +case class SampleClassC(string: String, a: SampleClassA) +case class SampleClassD(x: Int, y: String) +case class SampleClassF(w: Byte, z: Float) + +object SampleClassB { + val schema: String = Macros.caseClassParquetSchema[SampleClassB] +} + +class BReadSupport extends ParquetReadSupport[SampleClassB] { + override val tupleConverter: ParquetTupleConverter[SampleClassB] = Macros.caseClassParquetTupleConverter[SampleClassB] + override val rootSchema: String = SampleClassB.schema +} + +class CReadSupport extends ParquetReadSupport[SampleClassC] { + override val tupleConverter: ParquetTupleConverter[SampleClassC] = Macros.caseClassParquetTupleConverter[SampleClassC] + override val rootSchema: String = Macros.caseClassParquetSchema[SampleClassC] +} + +class WriteSupport extends ParquetWriteSupport[SampleClassB] { + override val rootSchema: String = SampleClassB.schema + override def writeRecord(r: SampleClassB, rc: RecordConsumer, schema: MessageType): Unit = + Macros.caseClassWriteSupport[SampleClassB](r, rc, schema) +} + +/** + * Test job write a sequence of sample class values into a typed parquet tuple. 
+ * To test typed parquet tuple can be used as sink + */ +class WriteToTypedParquetTupleJob(args: Args) extends Job(args) { + import com.twitter.scalding.parquet.tuple.TestValues._ + + val outputPath = args.required("output") + + val sink = TypedParquetSink[SampleClassB, WriteSupport](Seq(outputPath)) + TypedPipe.from(values).write(sink) +} + +/** + * Test job read from a typed parquet source with filter predicate and push down(SampleClassC takes only part of + * SampleClassB's data) + * To test typed parquet tuple can bse used as source and apply filter predicate and push down correctly + */ +class ReadWithFilterPredicateJob(args: Args) extends Job(args) { + val fp: FilterPredicate = FilterApi.eq(binaryColumn("string"), Binary.fromString("B1")) + + val inputPath = args.required("input") + val outputPath = args.required("output") + + val input = TypedParquet[SampleClassC, CReadSupport](Seq(inputPath), Some(fp)) + + TypedPipe.from(input).map(_.a.bool).write(TypedTsv[Boolean](outputPath)) +} + diff --git a/scalding-parquet/src/test/scala/com/twitter/scalding/parquet/tuple/macros/MacroUnitTests.scala b/scalding-parquet/src/test/scala/com/twitter/scalding/parquet/tuple/macros/MacroUnitTests.scala new file mode 100644 index 0000000000..d7be8f650f --- /dev/null +++ b/scalding-parquet/src/test/scala/com/twitter/scalding/parquet/tuple/macros/MacroUnitTests.scala @@ -0,0 +1,618 @@ +package com.twitter.scalding.parquet.tuple.macros + +import org.scalatest.mock.MockitoSugar +import org.scalatest.{ Matchers, WordSpec } +import parquet.io.api.{ Binary, RecordConsumer } +import parquet.schema.MessageTypeParser + +case class SampleClassA(x: Int, y: String) + +case class SampleClassB(a: SampleClassA, y: String) + +case class SampleClassC(a: SampleClassA, b: SampleClassB) + +case class SampleClassD(a: String, b: Boolean, c: Option[Short], d: Int, e: Long, f: Float, g: Option[Double]) + +case class SampleClassE(a: Int, b: Long, c: Short, d: Boolean, e: Float, f: Double, g: String, h: Byte) + +case class SampleClassF(a: Int, b: Option[SampleClassB], c: Double) + +case class SampleClassG(a: Int, b: Option[List[Double]]) + +case class SampleClassH(a: Int, b: List[SampleClassA]) + +case class SampleClassI(a: Int, b: List[Option[Double]]) + +case class SampleClassJ(a: Map[Int, String]) + +case class SampleClassK(a: String, b: Map[SampleClassA, SampleClassB]) + +class MacroUnitTests extends WordSpec with Matchers with MockitoSugar { + + "Macro case class parquet schema generator" should { + + "Generate parquet schema for SampleClassA" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassA]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassA { + | required int32 x; + | required binary y; + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassB" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassB]) + + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassB { + | required group a { + | required int32 x; + | required binary y; + | } + | required binary y; + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassC" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassC]) + + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassC { + | required group a { + | required int32 x; + | required binary 
y; + | } + | required group b { + | required group a { + | required int32 x; + | required binary y; + | } + | required binary y; + | } + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassD" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassD]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassD { + | required binary a; + | required boolean b; + | optional int32 c; + | required int32 d; + | required int64 e; + | required float f; + | optional double g; + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassE" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassE]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassE { + | required int32 a; + | required int64 b; + | required int32 c; + | required boolean d; + | required float e; + | required double f; + | required binary g; + | required int32 h; + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassG" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassG]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassG { + | required int32 a; + | optional group b (LIST) { + | repeated group list { + | required double element; + | } + | } + |} + | + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassH" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassH]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassH { + | required int32 a; + | required group b (LIST) { + | repeated group list { + | required group element { + | required int32 x; + | required binary y; + | } + | } + | } + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassI" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassI]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassI { + | required int32 a; + | required group b (LIST) { + | repeated group list { + | optional double element; + | } + | } + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassJ" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassJ]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + |message SampleClassJ { + | required group a (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value; + | } + | } + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + + "Generate parquet schema for SampleClassK" in { + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassK]) + val expectedSchema = MessageTypeParser.parseMessageType(""" + message SampleClassK { + | required binary a; + | required group b (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required group key { + | required int32 x; + | required binary y; + | } + | required group value { + | required group a { + | required int32 x; + | required binary y; + | } + | required binary y; + | } + | } + | } + |} + """.stripMargin) + schema shouldEqual expectedSchema + } + } + + "Macro case class converters generator" should { + + 
"Generate converters for all primitive types" in { + val converter = Macros.caseClassParquetTupleConverter[SampleClassE] + converter.start() + val intConverter = converter.getConverter(0).asPrimitiveConverter() + intConverter.addInt(0) + + val longConverter = converter.getConverter(1).asPrimitiveConverter() + longConverter.addLong(1L) + + val shortConverter = converter.getConverter(2).asPrimitiveConverter() + shortConverter.addInt(2) + + val boolean = converter.getConverter(3).asPrimitiveConverter() + boolean.addBoolean(true) + + val float = converter.getConverter(4).asPrimitiveConverter() + float.addFloat(3F) + + val double = converter.getConverter(5).asPrimitiveConverter() + double.addDouble(4D) + + val string = converter.getConverter(6).asPrimitiveConverter() + string.addBinary(Binary.fromString("foo")) + + val byte = converter.getConverter(7).asPrimitiveConverter() + byte.addInt(1) + converter.end() + converter.currentValue shouldEqual SampleClassE(0, 1L, 2, d = true, 3F, 4D, "foo", 1) + } + + "Generate converters for case class with nested class" in { + val converter = Macros.caseClassParquetTupleConverter[SampleClassB] + converter.start() + val a = converter.getConverter(0).asGroupConverter() + + a.start() + val aInt = a.getConverter(0).asPrimitiveConverter() + aInt.addInt(2) + val aString = a.getConverter(1).asPrimitiveConverter() + aString.addBinary(Binary.fromString("foo")) + a.end() + + val bString = converter.getConverter(1).asPrimitiveConverter() + bString.addBinary(Binary.fromString("toto")) + converter.end() + converter.currentValue shouldEqual SampleClassB(SampleClassA(2, "foo"), "toto") + } + + "Generate converters for case class with optional nested class" in { + val converter = Macros.caseClassParquetTupleConverter[SampleClassF] + converter.start() + val a = converter.getConverter(0).asPrimitiveConverter() + a.addInt(0) + + val b = converter.getConverter(1).asGroupConverter() + b.start() + val ba = b.getConverter(0).asGroupConverter() + ba.start() + val baInt = ba.getConverter(0).asPrimitiveConverter() + baInt.addInt(2) + val baString = ba.getConverter(1).asPrimitiveConverter() + baString.addBinary(Binary.fromString("foo")) + ba.end() + + val bString = b.getConverter(1).asPrimitiveConverter() + bString.addBinary(Binary.fromString("b1")) + b.end() + + val c = converter.getConverter(2).asPrimitiveConverter() + c.addDouble(4D) + converter.end() + converter.currentValue shouldEqual SampleClassF(0, Some(SampleClassB(SampleClassA(2, "foo"), "b1")), 4D) + } + + "Generate converters for case class with list fields" in { + val converter = Macros.caseClassParquetTupleConverter[SampleClassF] + converter.start() + val a = converter.getConverter(0).asPrimitiveConverter() + a.addInt(0) + + val b = converter.getConverter(1).asGroupConverter() + b.start() + val ba = b.getConverter(0).asGroupConverter() + ba.start() + val baInt = ba.getConverter(0).asPrimitiveConverter() + baInt.addInt(2) + val baString = ba.getConverter(1).asPrimitiveConverter() + baString.addBinary(Binary.fromString("foo")) + ba.end() + + val bString = b.getConverter(1).asPrimitiveConverter() + bString.addBinary(Binary.fromString("b1")) + b.end() + + val c = converter.getConverter(2).asPrimitiveConverter() + c.addDouble(4D) + converter.end() + converter.currentValue shouldEqual SampleClassF(0, Some(SampleClassB(SampleClassA(2, "foo"), "b1")), 4D) + } + + "Generate converters for case class with map fields" in { + val converter = Macros.caseClassParquetTupleConverter[SampleClassK] + converter.start() + val a = 
converter.getConverter(0).asPrimitiveConverter() + a.addBinary(Binary.fromString("foo")) + + val keyValue = converter.getConverter(1).asGroupConverter().getConverter(0).asGroupConverter() + keyValue.start() + val key = keyValue.getConverter(0).asGroupConverter() + key.start() + val keyInt = key.getConverter(0).asPrimitiveConverter() + keyInt.addInt(2) + val keyString = key.getConverter(1).asPrimitiveConverter() + keyString.addBinary(Binary.fromString("bar")) + key.end() + + val value = keyValue.getConverter(1).asGroupConverter() + value.start() + val valueA = value.getConverter(0).asGroupConverter() + valueA.start() + val valueAInt = valueA.getConverter(0).asPrimitiveConverter() + valueAInt.addInt(2) + val valueAString = valueA.getConverter(1).asPrimitiveConverter() + valueAString.addBinary(Binary.fromString("bar")) + valueA.end() + val valueString = value.getConverter(1).asPrimitiveConverter() + valueString.addBinary(Binary.fromString("b1")) + value.end() + keyValue.end() + converter.end() + + converter.currentValue shouldEqual SampleClassK("foo", + Map(SampleClassA(2, "bar") -> SampleClassB(SampleClassA(2, "bar"), "b1"))) + } + } + + "Macro case class parquet write support generator" should { + "Generate write support for class with all the primitive type fields" in { + val writeSupportFn = Macros.caseClassWriteSupport[SampleClassE] + val e = SampleClassE(0, 1L, 2, d = true, 3F, 4D, "foo", 1) + val schema = Macros.caseClassParquetSchema[SampleClassE] + val rc = new StringBuilderRecordConsumer + writeSupportFn(e, rc, MessageTypeParser.parseMessageType(schema)) + + rc.writeScenario shouldEqual """start message + |start field a at 0 + |write INT32 0 + |end field a at 0 + |start field b at 1 + |write INT64 1 + |start message + |start field a at 0 + |write INT32 0 + |end field a at 0 + |start field b at 1 + |write INT64 1 + |end field b at 1 + |start field c at 2 + |write INT32 2 + |end field c at 2 + |start field d at 3 + |write BOOLEAN true + |end field d at 3 + |start field e at 4 + |write FLOAT 3.0 + |end field e at 4 + |start field f at 5 + |write DOUBLE 4.0 + |end field f at 5 + |start field g at 6 + |write BINARY foo + |end field g at 6 + |start field h at 7 + |write INT32 1 + |end field h at 7 + |end message""".stripMargin + + } + + "Generate write support for nested case class and optional fields" in { + val writeSupportFn = Macros.caseClassWriteSupport[SampleClassF] + + val f = SampleClassF(0, Some(SampleClassB(SampleClassA(2, "foo"), "b1")), 4D) + + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassF]) + val rc = new StringBuilderRecordConsumer + writeSupportFn(f, rc, schema) + + rc.writeScenario shouldEqual """start message + |start field a at 0 + |write INT32 0 + |end field a at 0 + |start field b at 1 + |start group + |start field a at 0 + |start group + |start field x at 0 + |write INT32 2 + |end field x at 0 + |start field y at 1 + |write BINARY foo + |end field y at 1 + |end group + |end field a at 0 + |start field y at 1 + |write BINARY b1 + |end field y at 1 + |end group + |end field b at 1 + |start field c at 2 + |write DOUBLE 4.0 + |end field c at 2 + |end message""".stripMargin + + //test write tuple with optional field = None + val f2 = SampleClassF(0, None, 4D) + val rc2 = new StringBuilderRecordConsumer + writeSupportFn(f2, rc2, schema) + rc2.writeScenario shouldEqual """start message + |start field a at 0 + |write INT32 0 + |end field a at 0 + |start field c at 2 + |write DOUBLE 4.0 + |end field c at 2 + |end 
message""".stripMargin + } + + "Generate write support for case class with LIST fields" in { + //test write tuple with list of primitive fields + val writeSupportFn = Macros.caseClassWriteSupport[SampleClassI] + val i = SampleClassI(0, List(None, Some(2))) + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassI]) + val rc = new StringBuilderRecordConsumer + writeSupportFn(i, rc, schema) + + rc.writeScenario shouldEqual """start message + |start field a at 0 + |write INT32 0 + |end field a at 0 + |start field b at 1 + |start group + |start field list at 0 + |start group + |end group + |start group + |start field element at 0 + |write DOUBLE 2.0 + |end field element at 0 + |end group + |end field list at 0 + |end group + |end field b at 1 + |end message""".stripMargin + //test write list of nested class field + val writeSupportFn2 = Macros.caseClassWriteSupport[SampleClassH] + val h = SampleClassH(0, List(SampleClassA(2, "foo"), SampleClassA(2, "bar"))) + val schema2 = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassH]) + val rc2 = new StringBuilderRecordConsumer + writeSupportFn2(h, rc2, schema2) + + rc2.writeScenario shouldEqual """start message + |start field a at 0 + |write INT32 0 + |end field a at 0 + |start field b at 1 + |start group + |start field list at 0 + |start group + |start field element at 0 + |start group + |start field x at 0 + |write INT32 2 + |end field x at 0 + |start field y at 1 + |write BINARY foo + |end field y at 1 + |end group + |end field element at 0 + |end group + |start group + |start field element at 0 + |start group + |start field x at 0 + |write INT32 2 + |end field x at 0 + |start field y at 1 + |write BINARY bar + |end field y at 1 + |end group + |end field element at 0 + |end group + |end field list at 0 + |end group + |end field b at 1 + |end message""".stripMargin + + } + + "Generate write support for case class with MAP fields" in { + //test write tuple with map of primitive fields + val writeSupportFn = Macros.caseClassWriteSupport[SampleClassJ] + val j = SampleClassJ(Map(1 -> "foo", 2 -> "bar")) + val schema = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassJ]) + val rc = new StringBuilderRecordConsumer + writeSupportFn(j, rc, schema) + rc.writeScenario shouldEqual """start message + |start field a at 0 + |start group + |start field map at 0 + |start group + |start field key at 0 + |write INT32 1 + |end field key at 0 + |start field value at 1 + |write BINARY foo + |end field value at 1 + |end group + |start group + |start field key at 0 + |write INT32 2 + |end field key at 0 + |start field value at 1 + |write BINARY bar + |end field value at 1 + |end group + |end field map at 0 + |end group + |end field a at 0 + |end message""".stripMargin + + //test write Map of case class field + val writeSupportFn2 = Macros.caseClassWriteSupport[SampleClassK] + val k = SampleClassK("foo", Map(SampleClassA(2, "foo") -> SampleClassB(SampleClassA(2, "foo"), "bar"))) + val schema2 = MessageTypeParser.parseMessageType(Macros.caseClassParquetSchema[SampleClassK]) + val rc2 = new StringBuilderRecordConsumer + writeSupportFn2(k, rc2, schema2) + + rc2.writeScenario shouldEqual """start message + |start field a at 0 + |write BINARY foo + |end field a at 0 + |start field b at 1 + |start group + |start field map at 0 + |start group + |start field key at 0 + |start group + |start field x at 0 + |write INT32 2 + |end field x at 0 + |start field y at 1 + |write BINARY foo + |end field 
y at 1 + |end group + |end field key at 0 + |start field value at 1 + |start group + |start field a at 0 + |start group + |start field x at 0 + |write INT32 2 + |end field x at 0 + |start field y at 1 + |write BINARY foo + |end field y at 1 + |end group + |end field a at 0 + |start field y at 1 + |write BINARY bar + |end field y at 1 + |end group + |end field value at 1 + |end group + |end field map at 0 + |end group + |end field b at 1 + |end message""".stripMargin + + } + } +} + +//class to simulate record consumer for testing +class StringBuilderRecordConsumer extends RecordConsumer { + val sb = new StringBuilder + + override def startMessage(): Unit = sb.append("start message\n") + + override def endMessage(): Unit = sb.append("end message") + + override def addFloat(v: Float): Unit = sb.append(s"write FLOAT $v\n") + + override def addBinary(binary: Binary): Unit = sb.append(s"write BINARY ${binary.toStringUsingUTF8}\n") + + override def addDouble(v: Double): Unit = sb.append(s"write DOUBLE $v\n") + + override def endGroup(): Unit = sb.append("end group\n") + + override def endField(s: String, i: Int): Unit = sb.append(s"end field $s at $i\n") + + override def startGroup(): Unit = sb.append("start group\n") + + override def startField(s: String, i: Int): Unit = sb.append(s"start field $s at $i\n") + + override def addBoolean(b: Boolean): Unit = sb.append(s"write BOOLEAN $b\n") + + override def addLong(l: Long): Unit = sb.append(sb.append(s"write INT64 $l\n")) + + override def addInteger(i: Int): Unit = sb.append(s"write INT32 $i\n") + + def writeScenario = sb.toString() +} \ No newline at end of file diff --git a/scalding-repl/src/main/scala-2.10/com/twitter/scalding/ILoopCompat.scala b/scalding-repl/src/main/scala-2.10/com/twitter/scalding/ILoopCompat.scala index dbf9df69c2..0069d91f11 100644 --- a/scalding-repl/src/main/scala-2.10/com/twitter/scalding/ILoopCompat.scala +++ b/scalding-repl/src/main/scala-2.10/com/twitter/scalding/ILoopCompat.scala @@ -1,5 +1,5 @@ package com.twitter.scalding -import scala.tools.nsc.interpreter.{ ILoop, IMain } +import scala.tools.nsc.interpreter.ILoop trait ILoopCompat extends ILoop diff --git a/scalding-repl/src/main/scala/com/twitter/scalding/ReplImplicits.scala b/scalding-repl/src/main/scala/com/twitter/scalding/ReplImplicits.scala index db851fe4df..6bdb89f9d5 100644 --- a/scalding-repl/src/main/scala/com/twitter/scalding/ReplImplicits.scala +++ b/scalding-repl/src/main/scala/com/twitter/scalding/ReplImplicits.scala @@ -27,6 +27,11 @@ import scala.concurrent.{ Future, ExecutionContext => ConcurrentExecutionContext */ object ReplImplicits extends FieldConversions { + /** required for switching to hdfs local mode */ + private val mr1Key = "mapred.job.tracker" + private val mr2Key = "mapreduce.framework.name" + private val mrLocal = "local" + /** Implicit flowDef for this Scalding shell session. */ var flowDef: FlowDef = getEmptyFlowDef /** Defaults to running in local mode if no mode is specified. 
*/ @@ -42,13 +47,25 @@ object ReplImplicits extends FieldConversions { def useStrictLocalMode() { mode = Local(true) } /** Switch to Hdfs mode */ - def useHdfsMode() { + private def useHdfsMode_() { storedHdfsMode match { case Some(hdfsMode) => mode = hdfsMode case None => println("To use HDFS/Hadoop mode, you must *start* the repl in hadoop mode to get the hadoop configuration from the hadoop command.") } } + def useHdfsMode() { + useHdfsMode_() + customConfig -= mr1Key + customConfig -= mr2Key + } + + def useHdfsLocalMode() { + useHdfsMode_() + customConfig += mr1Key -> mrLocal + customConfig += mr2Key -> mrLocal + } + /** * Configuration to use for REPL executions. * diff --git a/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingShell.scala b/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingShell.scala index c1febab692..e7f0dc03ca 100644 --- a/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingShell.scala +++ b/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingShell.scala @@ -23,7 +23,7 @@ import java.util.jar.JarOutputStream import org.apache.hadoop.util.GenericOptionsParser import org.apache.hadoop.conf.Configuration -import scala.tools.nsc.{ Settings, GenericRunnerCommand, MainGenericRunner } +import scala.tools.nsc.{ GenericRunnerCommand, MainGenericRunner } import scala.tools.nsc.interpreter.ILoop import scala.tools.nsc.io.VirtualDirectory diff --git a/scalding-repl/src/main/scala/com/twitter/scalding/ShellPipe.scala b/scalding-repl/src/main/scala/com/twitter/scalding/ShellPipe.scala index 7c3b62319d..f223625119 100644 --- a/scalding-repl/src/main/scala/com/twitter/scalding/ShellPipe.scala +++ b/scalding-repl/src/main/scala/com/twitter/scalding/ShellPipe.scala @@ -15,8 +15,6 @@ package com.twitter.scalding -import com.twitter.scalding.typed._ - /** * Enrichment on TypedPipes allowing them to be run locally, independent of the overall flow. * @param pipe to wrap diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/OrderedBufferableProviderImpl.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/OrderedBufferableProviderImpl.scala new file mode 100644 index 0000000000..3bb815ebf8 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/OrderedBufferableProviderImpl.scala @@ -0,0 +1,76 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
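// A sketch of what the ReplImplicits switch above amounts to in a REPL session,
// assuming the shell was started in hadoop mode so storedHdfsMode is populated.
import com.twitter.scalding.ReplImplicits._

// "HDFS data, local execution": mapred.job.tracker and mapreduce.framework.name
// are both forced to "local" in customConfig
useHdfsLocalMode()

// remove both overrides again and submit subsequent jobs to the cluster
useHdfsMode()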
+ */ +package com.twitter.scalding.serialization.macros.impl + +import scala.language.experimental.macros +import scala.reflect.macros.Context +import scala.util.Random + +import com.twitter.scalding.serialization.OrderedSerialization +import com.twitter.scalding.serialization.macros.impl.ordered_serialization._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers._ + +object OrderedSerializationProviderImpl { + def normalizedDispatcher(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if (!tpe.toString.contains(ImplicitOrderedBuf.macroMarker) && !(tpe.normalize == tpe)) => + buildDispatcher(tpe.normalize) + } + + def scaldingBasicDispatchers(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + + val primitiveDispatcher = PrimitiveOrderedBuf.dispatch(c) + val optionDispatcher = OptionOrderedBuf.dispatch(c)(buildDispatcher) + val eitherDispatcher = EitherOrderedBuf.dispatch(c)(buildDispatcher) + val caseClassDispatcher = CaseClassOrderedBuf.dispatch(c)(buildDispatcher) + val productDispatcher = ProductOrderedBuf.dispatch(c)(buildDispatcher) + val stringDispatcher = StringOrderedBuf.dispatch(c) + val traversablesDispatcher = TraversablesOrderedBuf.dispatch(c)(buildDispatcher) + val unitDispatcher = UnitOrderedBuf.dispatch(c) + val byteBufferDispatcher = ByteBufferOrderedBuf.dispatch(c) + + OrderedSerializationProviderImpl.normalizedDispatcher(c)(buildDispatcher) + .orElse(primitiveDispatcher) + .orElse(unitDispatcher) + .orElse(optionDispatcher) + .orElse(eitherDispatcher) + .orElse(stringDispatcher) + .orElse(byteBufferDispatcher) + .orElse(traversablesDispatcher) + .orElse(caseClassDispatcher) + .orElse(productDispatcher) + } + + def fallbackImplicitDispatcher(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = + ImplicitOrderedBuf.dispatch(c) + + private def dispatcher(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + import c.universe._ + def buildDispatcher: PartialFunction[c.Type, TreeOrderedBuf[c.type]] = OrderedSerializationProviderImpl.dispatcher(c) + + scaldingBasicDispatchers(c)(buildDispatcher).orElse(fallbackImplicitDispatcher(c)).orElse { + case tpe: Type => c.abort(c.enclosingPosition, s"""Unable to find OrderedSerialization for type ${tpe}""") + } + } + + def apply[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[OrderedSerialization[T]] = { + import c.universe._ + + val b: TreeOrderedBuf[c.type] = dispatcher(c)(T.tpe) + val res = TreeOrderedBuf.toOrderedSerialization[T](c)(b) + //println(res) + res + } +} diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/CompileTimeLengthTypes.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/CompileTimeLengthTypes.scala new file mode 100644 index 0000000000..7a2640c603 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/CompileTimeLengthTypes.scala @@ -0,0 +1,77 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
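// The provider above is only a chain of PartialFunctions tried in order, with the
// implicit-lookup provider as the last resort before c.abort. A minimal runtime
// analogue of that dispatch pattern (plain values instead of macro Types):
object DispatchSketch {
  val primitives: PartialFunction[Any, String] = { case _: Int | _: Long | _: String => "primitive buf" }
  val options: PartialFunction[Any, String] = { case _: Option[_] => "option buf" }
  val fallback: PartialFunction[Any, String] = { case _ => "look for an implicit OrderedSerialization" }

  // first matching provider wins, like scaldingBasicDispatchers.orElse(fallbackImplicitDispatcher)
  val dispatch: PartialFunction[Any, String] = primitives.orElse(options).orElse(fallback)

  def demo(): Unit = {
    println(dispatch(42))        // primitive buf
    println(dispatch(Some(1)))   // option buf
    println(dispatch(Seq(1, 2))) // look for an implicit OrderedSerialization
  }
}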
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +sealed trait CompileTimeLengthTypes[C <: Context] { + val ctx: C + def t: ctx.Tree +} +object CompileTimeLengthTypes { + + // Repesents an Int returning + object FastLengthCalculation { + def apply(c: Context)(tree: c.Tree): FastLengthCalculation[c.type] = + new FastLengthCalculation[c.type] { + override val ctx: c.type = c + override val t: c.Tree = tree + } + } + + trait FastLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] + + object MaybeLengthCalculation { + def apply(c: Context)(tree: c.Tree): MaybeLengthCalculation[c.type] = + new MaybeLengthCalculation[c.type] { + override val ctx: c.type = c + override val t: c.Tree = tree + } + } + + trait MaybeLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] + + object ConstantLengthCalculation { + def apply(c: Context)(intArg: Int): ConstantLengthCalculation[c.type] = + new ConstantLengthCalculation[c.type] { + override val toInt = intArg + override val ctx: c.type = c + override val t: c.Tree = { + import c.universe._ + q"$intArg" + } + } + } + + trait ConstantLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] { + def toInt: Int + } + + object NoLengthCalculationAvailable { + def apply(c: Context): NoLengthCalculationAvailable[c.type] = { + new NoLengthCalculationAvailable[c.type] { + override val ctx: c.type = c + override def t = { + import c.universe._ + q"""_root_.scala.sys.error("no length available")""" + } + } + } + } + + trait NoLengthCalculationAvailable[C <: Context] extends CompileTimeLengthTypes[C] +} diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/ProductLike.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/ProductLike.scala new file mode 100644 index 0000000000..23f874d6f1 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/ProductLike.scala @@ -0,0 +1,169 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
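// The four compile-time cases above correspond to the runtime_helpers.MaybeLength
// values referenced later in this diff (ConstLen / DynamicLen / NoLengthCalculation).
// A simplified stand-in showing how two field lengths combine, which is what the
// product length-folding code below does:
object LengthSketch {
  sealed trait MaybeLength
  case class ConstLen(bytes: Int) extends MaybeLength    // size known at macro-expansion time
  case class DynamicLen(bytes: Int) extends MaybeLength  // size cheap to compute per value
  case object NoLengthCalculation extends MaybeLength    // must serialize to learn the size

  // any unknown field makes the whole product unknown; otherwise sizes add, and the
  // result is only constant when every field was constant
  def combine(a: MaybeLength, b: MaybeLength): MaybeLength = (a, b) match {
    case (NoLengthCalculation, _) | (_, NoLengthCalculation) => NoLengthCalculation
    case (ConstLen(x), ConstLen(y)) => ConstLen(x + y)
    case (ConstLen(x), DynamicLen(y)) => DynamicLen(x + y)
    case (DynamicLen(x), ConstLen(y)) => DynamicLen(x + y)
    case (DynamicLen(x), DynamicLen(y)) => DynamicLen(x + y)
  }
}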
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ + +object ProductLike { + def compareBinary(c: Context)(inputStreamA: c.TermName, inputStreamB: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + + elementData.foldLeft(Option.empty[Tree]) { + case (existingTreeOpt, (tpe, accessorSymbol, tBuf)) => + existingTreeOpt match { + case Some(t) => + val lastCmp = freshT("lastCmp") + Some(q""" + val $lastCmp = $t + if($lastCmp != 0) { + $lastCmp + } else { + ${tBuf.compareBinary(inputStreamA, inputStreamB)} + } + """) + case None => + Some(tBuf.compareBinary(inputStreamA, inputStreamB)) + } + }.getOrElse(q"0") + } + + def hash(c: Context)(element: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + + val currentHash = freshT("last") + + val hashUpdates = elementData.map { + case (tpe, accessorSymbol, tBuf) => + val target = freshT("target") + q""" + val $target = $element.$accessorSymbol + _root_.com.twitter.scalding.serialization.MurmurHashUtils.mixH1($currentHash, ${tBuf.hash(target)}) + """ + } + + q""" + var $currentHash: Int = _root_.com.twitter.scalding.serialization.MurmurHashUtils.seed + ..${hashUpdates} + _root_.com.twitter.scalding.serialization.MurmurHashUtils.fmix($currentHash, ${elementData.size}) + """ + } + + def put(c: Context)(inputStream: c.TermName, element: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + val innerElement = freshT("innerElement") + + elementData.foldLeft(q"") { + case (existingTree, (tpe, accessorSymbol, tBuf)) => + q""" + $existingTree + val $innerElement = $element.$accessorSymbol + ${tBuf.put(inputStream, innerElement)} + """ + } + } + + def length(c: Context)(element: c.Tree)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): CompileTimeLengthTypes[c.type] = { + import c.universe._ + import CompileTimeLengthTypes._ + val (constSize, dynamicFunctions, maybeLength, noLength) = + elementData.foldLeft((0, Vector[c.Tree](), Vector[c.Tree](), 0)) { + case ((constantLength, dynamicLength, maybeLength, noLength), (tpe, accessorSymbol, tBuf)) => + + tBuf.length(q"$element.$accessorSymbol") match { + case const: ConstantLengthCalculation[_] => (constantLength + const.asInstanceOf[ConstantLengthCalculation[c.type]].toInt, dynamicLength, maybeLength, noLength) + case f: FastLengthCalculation[_] => (constantLength, dynamicLength :+ f.asInstanceOf[FastLengthCalculation[c.type]].t, maybeLength, noLength) + case m: MaybeLengthCalculation[_] => (constantLength, dynamicLength, maybeLength :+ m.asInstanceOf[MaybeLengthCalculation[c.type]].t, noLength) + case _: NoLengthCalculationAvailable[_] => (constantLength, dynamicLength, maybeLength, noLength + 1) + } + } + + val combinedDynamic = dynamicFunctions.foldLeft(q"""$constSize""") { + case (prev, t) => + q"$prev + $t" + } + + if (noLength > 0) { + NoLengthCalculationAvailable(c) + } else { + if (maybeLength.isEmpty && dynamicFunctions.isEmpty) { + ConstantLengthCalculation(c)(constSize) + } else { + if (maybeLength.isEmpty) { + 
FastLengthCalculation(c)(combinedDynamic) + } else { + + val const = q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen" + val dyn = q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen" + val noLen = q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation" + // Contains an MaybeLength + val combinedMaybe: Tree = maybeLength.reduce { (hOpt, nxtOpt) => q"""$hOpt + $nxtOpt""" } + if (dynamicFunctions.nonEmpty || constSize != 0) { + MaybeLengthCalculation(c) (q""" + $combinedMaybe match { + case $const(l) => $dyn(l + $combinedDynamic) + case $dyn(l) => $dyn(l + $combinedDynamic) + case $noLen => $noLen + } + """) + } else { + MaybeLengthCalculation(c)(combinedMaybe) + } + } + } + } + } + + def compare(c: Context)(elementA: c.TermName, elementB: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + import c.universe._ + + def freshT(id: String) = newTermName(c.fresh(id)) + + val innerElementA = freshT("innerElementA") + val innerElementB = freshT("innerElementB") + + elementData.map { + case (tpe, accessorSymbol, tBuf) => + val curCmp = freshT("curCmp") + val cmpTree = q""" + val $curCmp: Int = { + val $innerElementA = $elementA.$accessorSymbol + val $innerElementB = $elementB.$accessorSymbol + ${tBuf.compare(innerElementA, innerElementB)} + } + """ + (cmpTree, curCmp) + } + .reverse // go through last to first + .foldLeft(None: Option[Tree]) { + case (Some(rest), (tree, valname)) => + Some( + q"""$tree; + if ($valname != 0) $valname + else { + $rest + } + """) + case (None, (tree, valname)) => Some(q"""$tree; $valname""") + } + .getOrElse(q"""0""") // all 0 size products are equal + } +} diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala new file mode 100644 index 0000000000..d336e8d70a --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala @@ -0,0 +1,299 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
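// ProductLike.compare chains the per-field comparisons with an early exit on the
// first non-zero result. Hand-expanded for a hypothetical two-field case class,
// the generated code has this shape (a sketch, not actual macro output):
object ProductCompareSketch {
  case class Pair(a: Int, b: String)

  def comparePair(x: Pair, y: Pair): Int = {
    val curCmp0: Int = java.lang.Integer.compare(x.a, y.a)
    if (curCmp0 != 0) curCmp0
    else x.b.compareTo(y.b) // last field: its comparison is the final answer
  }
}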
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.OrderedSerialization +import com.twitter.scalding.serialization.JavaStreamEnrichments +import java.io.InputStream +import scala.reflect.macros.Context +import scala.language.experimental.macros +import scala.util.control.NonFatal + +object CommonCompareBinary { + import com.twitter.scalding.serialization.JavaStreamEnrichments._ + + // If the lengths are equal and greater than this number + // we will compare on all the containing bytes + val minSizeForFulBinaryCompare = 24 + + /** + * This method will compare two InputStreams of given lengths + * If the inputsteam supports mark/reset (such as those backed by Array[Byte]), + * and the lengths are equal and longer than minSizeForFulBinaryCompare we first + * check if they are byte-for-byte identical, which is a cheap way to avoid doing + * potentially complex logic in binary comparators + */ + final def earlyEqual(inputStreamA: InputStream, + lenA: Int, + inputStreamB: InputStream, + lenB: Int): Boolean = + (lenA > minSizeForFulBinaryCompare && + (lenA == lenB) && + inputStreamA.markSupported && + inputStreamB.markSupported) && { + inputStreamA.mark(lenA) + inputStreamB.mark(lenB) + + var pos: Int = 0 + while (pos < lenA) { + val a = inputStreamA.read + val b = inputStreamB.read + pos += 1 + if (a != b) { + inputStreamA.reset() + inputStreamB.reset() + // yeah, return sucks, but trying to optimize here + return false + } + // a == b, but may be eof + if (a < 0) return JavaStreamEnrichments.eof + } + // we consumed all the bytes, and they were all equal + true + } +} +object TreeOrderedBuf { + import CompileTimeLengthTypes._ + def toOrderedSerialization[T](c: Context)(t: TreeOrderedBuf[c.type])(implicit T: t.ctx.WeakTypeTag[T]): t.ctx.Expr[OrderedSerialization[T]] = { + import t.ctx.universe._ + def freshT(id: String) = newTermName(c.fresh(s"fresh_$id")) + val outputLength = freshT("outputLength") + + val innerLengthFn: Tree = { + val element = freshT("element") + + val fnBodyOpt = t.length(q"$element") match { + case _: NoLengthCalculationAvailable[_] => None + case const: ConstantLengthCalculation[_] => None + case f: FastLengthCalculation[_] => Some(q""" + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(${f.asInstanceOf[FastLengthCalculation[c.type]].t}) + """) + case m: MaybeLengthCalculation[_] => Some(m.asInstanceOf[MaybeLengthCalculation[c.type]].t) + } + + fnBodyOpt.map { fnBody => + q""" + @inline private[this] def payloadLength($element: $T): _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength = { + $fnBody + } + """ + }.getOrElse(q"()") + } + + def binaryLengthGen(typeName: Tree): (Tree, Tree) = { + val tempLen = freshT("tempLen") + val lensLen = freshT("lensLen") + val element = freshT("element") + val callDynamic = (q"""override def staticSize: Option[Int] = None""", + q""" + + override def dynamicSize($element: $typeName): Option[Int] = { + val $tempLen = payloadLength($element) match { + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation => None + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(l) => Some(l) + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(l) => Some(l) + } + (if 
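// earlyEqual above leans on mark/reset so the cheap byte-for-byte scan never
// consumes the streams when it bails out. The same trick in isolation, with plain
// java.io streams (a sketch; the real code also handles EOF via JavaStreamEnrichments):
object EarlyEqualSketch {
  import java.io.ByteArrayInputStream

  def sameBytes(a: ByteArrayInputStream, b: ByteArrayInputStream, len: Int): Boolean = {
    a.mark(len); b.mark(len)
    var pos = 0
    while (pos < len) {
      if (a.read() != b.read()) {
        a.reset(); b.reset() // rewind so a field-by-field comparison can still run
        return false
      }
      pos += 1
    }
    true // equal lengths and identical bytes: the values must compare as equal
  }
}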
($tempLen.isDefined) { + // Avoid a closure here while we are geeking out + val innerLen = $tempLen.get + val $lensLen = posVarIntSize(innerLen) + Some(innerLen + $lensLen) + } else None): Option[Int] + } + """) + t.length(q"$element") match { + case _: NoLengthCalculationAvailable[_] => (q""" + override def staticSize: Option[Int] = None""", q""" + override def dynamicSize($element: $typeName): Option[Int] = None""") + case const: ConstantLengthCalculation[_] => (q""" + override val staticSize: Option[Int] = Some(${const.toInt})""", q""" + override def dynamicSize($element: $typeName): Option[Int] = staticSize""") + case f: FastLengthCalculation[_] => callDynamic + case m: MaybeLengthCalculation[_] => callDynamic + } + } + + def putFnGen(outerbaos: TermName, element: TermName) = { + val baos = freshT("baos") + val len = freshT("len") + val oldPos = freshT("oldPos") + + /** + * This is the worst case: we have to serialize in a side buffer + * and then see how large it actually is. This happens for cases, like + * string, where the cost to see the serialized size is not cheaper than + * directly serializing. + */ + val noLenCalc = q""" + // Start with pretty big buffers because reallocation will be expensive + val $baos = new _root_.java.io.ByteArrayOutputStream(256) + ${t.put(baos, element)} + val $len = $baos.size + $outerbaos.writePosVarInt($len) + $baos.writeTo($outerbaos) + """ + + /** + * This is the case where the length is cheap to compute, either + * constant or easily computable from an instance. + */ + def withLenCalc(lenC: Tree) = q""" + val $len = $lenC + $outerbaos.writePosVarInt($len) + ${t.put(outerbaos, element)} + """ + + t.length(q"$element") match { + case _: NoLengthCalculationAvailable[_] => noLenCalc + case _: ConstantLengthCalculation[_] => + q"""${t.put(outerbaos, element)}""" + case f: FastLengthCalculation[_] => + withLenCalc(f.asInstanceOf[FastLengthCalculation[c.type]].t) + case m: MaybeLengthCalculation[_] => + val tmpLenRes = freshT("tmpLenRes") + q""" + @inline def noLenCalc = { + $noLenCalc + } + @inline def withLenCalc(cnt: Int) = { + ${withLenCalc(q"cnt")} + } + val $tmpLenRes: _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength = payloadLength($element) + $tmpLenRes match { + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation => noLenCalc + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(const) => withLenCalc(const) + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(s) => withLenCalc(s) + } + """ + } + } + + def readLength(inputStream: TermName) = { + t.length(q"e") match { + case const: ConstantLengthCalculation[_] => q"${const.toInt}" + case _ => q"$inputStream.readPosVarInt" + } + } + + def discardLength(inputStream: TermName) = { + t.length(q"e") match { + case const: ConstantLengthCalculation[_] => q"()" + case _ => q"$inputStream.readPosVarInt" + } + } + + val lazyVariables = t.lazyOuterVariables.map { + case (n, t) => + val termName = newTermName(n) + q"""lazy val $termName = $t""" + } + + val element = freshT("element") + + val inputStreamA = freshT("inputStreamA") + val inputStreamB = freshT("inputStreamB") + val posStreamA = freshT("posStreamA") + val posStreamB = freshT("posStreamB") + + val lenA = freshT("lenA") + val lenB = freshT("lenB") + + t.ctx.Expr[OrderedSerialization[T]](q""" + new 
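// putFnGen above picks between two framings: when the payload size is unknown it
// serializes into a side buffer first, otherwise it writes the length and then the
// payload directly. The same two strategies with plain streams (writeInt stands in
// for the positive-varint encoding from JavaStreamEnrichments):
object LengthPrefixSketch {
  import java.io.{ ByteArrayOutputStream, DataOutputStream }

  // unknown length (e.g. nested strings): pay for a temporary buffer to learn the size
  def writeUnknownLength(out: DataOutputStream)(payload: DataOutputStream => Unit): Unit = {
    val side = new ByteArrayOutputStream(256)
    payload(new DataOutputStream(side))
    out.writeInt(side.size)
    side.writeTo(out)
  }

  // known or cheaply computed length: no side buffer, just prefix and serialize in place
  def writeKnownLength(out: DataOutputStream, len: Int)(payload: DataOutputStream => Unit): Unit = {
    out.writeInt(len)
    payload(out)
  }
}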
_root_.com.twitter.scalding.serialization.OrderedSerialization[$T] { + import _root_.com.twitter.scalding.serialization.JavaStreamEnrichments._ + ..$lazyVariables + + override def compareBinary($inputStreamA: _root_.java.io.InputStream, $inputStreamB: _root_.java.io.InputStream): _root_.com.twitter.scalding.serialization.OrderedSerialization.Result = + try _root_.com.twitter.scalding.serialization.OrderedSerialization.resultFrom { + val $lenA = ${readLength(inputStreamA)} + val $lenB = ${readLength(inputStreamB)} + val $posStreamA = _root_.com.twitter.scalding.serialization.PositionInputStream($inputStreamA) + val initialPositionA = $posStreamA.position + val $posStreamB = _root_.com.twitter.scalding.serialization.PositionInputStream($inputStreamB) + val initialPositionB = $posStreamB.position + + val innerR = ${t.compareBinary(posStreamA, posStreamB)} + + $posStreamA.seekToPosition(initialPositionA + $lenA) + $posStreamB.seekToPosition(initialPositionB + $lenB) + innerR + } catch { + case _root_.scala.util.control.NonFatal(e) => + _root_.com.twitter.scalding.serialization.OrderedSerialization.CompareFailure(e) + } + + override def hash(passedInObjectToHash: $T): Int = { + ${t.hash(newTermName("passedInObjectToHash"))} + } + + // defines payloadLength private method + $innerLengthFn + + // static size: + ${binaryLengthGen(q"$T")._1} + + // dynamic size: + ${binaryLengthGen(q"$T")._2} + + override def read(from: _root_.java.io.InputStream): _root_.scala.util.Try[$T] = { + try { + ${discardLength(newTermName("from"))} + _root_.scala.util.Success(${t.get(newTermName("from"))}) + } catch { case _root_.scala.util.control.NonFatal(e) => + _root_.scala.util.Failure(e) + } + } + + override def write(into: _root_.java.io.OutputStream, e: $T): _root_.scala.util.Try[Unit] = { + try { + ${putFnGen(newTermName("into"), newTermName("e"))} + _root_.com.twitter.scalding.serialization.Serialization.successUnit + } catch { case _root_.scala.util.control.NonFatal(e) => + _root_.scala.util.Failure(e) + } + } + + override def compare(x: $T, y: $T): Int = { + ${t.compare(newTermName("x"), newTermName("y"))} + } + } + """) + } +} + +abstract class TreeOrderedBuf[C <: Context] { + val ctx: C + val tpe: ctx.Type + // Expected byte buffers to be in values a and b respestively, the tree has the value of the result + def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName): ctx.Tree + // expects the thing to be tested on in the indiciated TermName + def hash(element: ctx.TermName): ctx.Tree + + // Place input in param 1, tree to return result in param 2 + def get(inputStreamA: ctx.TermName): ctx.Tree + + // BB input in param 1 + // Other input of type T in param 2 + def put(inputStream: ctx.TermName, element: ctx.TermName): ctx.Tree + + def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree + + def lazyOuterVariables: Map[String, ctx.Tree] + // Return the constant size or a tree + def length(element: ctx.universe.Tree): CompileTimeLengthTypes[ctx.type] + +} diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ByteBufferOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ByteBufferOrderedBuf.scala new file mode 100644 index 0000000000..af26712f42 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ByteBufferOrderedBuf.scala 
@@ -0,0 +1,99 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ + +import java.nio.ByteBuffer +import com.twitter.scalding.serialization.OrderedSerialization + +object ByteBufferOrderedBuf { + def dispatch(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe =:= c.universe.typeOf[ByteBuffer] => ByteBufferOrderedBuf(c)(tpe) + } + + def apply(c: Context)(outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + + def freshT(id: String) = newTermName(c.fresh(id)) + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def hash(element: ctx.TermName): ctx.Tree = q"$element.hashCode" + + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = { + val lenA = freshT("lenA") + val lenB = freshT("lenB") + val queryLength = freshT("queryLength") + val incr = freshT("incr") + val state = freshT("state") + q""" + val $lenA: Int = $inputStreamA.readPosVarInt + val $lenB: Int = $inputStreamB.readPosVarInt + + val $queryLength = _root_.scala.math.min($lenA, $lenB) + var $incr = 0 + var $state = 0 + + while($incr < $queryLength && $state == 0) { + $state = _root_.java.lang.Byte.compare($inputStreamA.readByte, $inputStreamB.readByte) + $incr = $incr + 1 + } + if($state == 0) { + _root_.java.lang.Integer.compare($lenA, $lenB) + } else { + $state + } + """ + } + override def put(inputStream: ctx.TermName, element: ctx.TermName) = + q""" + $inputStream.writePosVarInt($element.remaining) + $inputStream.writeBytes($element.array, $element.arrayOffset + $element.position, $element.remaining) + """ + + override def get(inputStream: ctx.TermName): ctx.Tree = { + val lenA = freshT("lenA") + val bytes = freshT("bytes") + q""" + val $lenA = $inputStream.readPosVarInt + val $bytes = new Array[Byte]($lenA) + $inputStream.readFully($bytes) + _root_.java.nio.ByteBuffer.wrap($bytes) + """ + } + override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = q""" + $elementA.compareTo($elementB) + """ + override def length(element: Tree): CompileTimeLengthTypes[c.type] = { + val tmpLen = freshT("tmpLen") + FastLengthCalculation(c)(q""" + val $tmpLen = $element.remaining + posVarIntSize($tmpLen) + $tmpLen + """) + } + + def lazyOuterVariables: Map[String, ctx.Tree] = Map.empty + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseClassOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseClassOrderedBuf.scala 
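// The generated ByteBuffer comparison above is signed-byte lexicographic order with
// the shorter buffer winning ties. The same logic as an ordinary function, for
// reference (a sketch; the macro version reads length-prefixed bytes off streams):
object ByteBufferCompareSketch {
  import java.nio.ByteBuffer

  def compareBuffers(a: ByteBuffer, b: ByteBuffer): Int = {
    val lenA = a.remaining
    val lenB = b.remaining
    val common = math.min(lenA, lenB)
    var i = 0
    var state = 0
    while (i < common && state == 0) {
      state = java.lang.Byte.compare(a.get(a.position + i), b.get(b.position + i))
      i += 1
    }
    if (state == 0) java.lang.Integer.compare(lenA, lenB) else state
  }
}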
new file mode 100644 index 0000000000..4f2ab075f3 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseClassOrderedBuf.scala @@ -0,0 +1,86 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import com.twitter.scalding.serialization.OrderedSerialization + +object CaseClassOrderedBuf { + def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isCaseClass && !tpe.typeConstructor.takesTypeArgs => + CaseClassOrderedBuf(c)(buildDispatcher, tpe) + } + + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + + val dispatcher = buildDispatcher + val elementData: List[(c.universe.Type, TermName, TreeOrderedBuf[c.type])] = + outerType + .declarations + .collect { case m: MethodSymbol if m.isCaseAccessor => m } + .map { accessorMethod => + val fieldType = accessorMethod.returnType + val b: TreeOrderedBuf[c.type] = dispatcher(fieldType) + (fieldType, accessorMethod.name.toTermName, b) + }.toList + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = + ProductLike.compareBinary(c)(inputStreamA, inputStreamB)(elementData) + + override def hash(element: ctx.TermName): ctx.Tree = ProductLike.hash(c)(element)(elementData) + + override def put(inputStream: ctx.TermName, element: ctx.TermName) = + ProductLike.put(c)(inputStream, element)(elementData) + + override def get(inputStream: ctx.TermName): ctx.Tree = { + + val getValProcessor = elementData.map { + case (tpe, accessorSymbol, tBuf) => + val curR = freshT("curR") + val builderTree = q""" + val $curR = { + ${tBuf.get(inputStream)} + } + """ + (builderTree, curR) + } + q""" + ..${getValProcessor.map(_._1)} + ${outerType.typeSymbol.companionSymbol}(..${getValProcessor.map(_._2)}) + """ + } + override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = + ProductLike.compare(c)(elementA, elementB)(elementData) + + override val lazyOuterVariables: Map[String, ctx.Tree] = + elementData.map(_._3.lazyOuterVariables).reduce(_ ++ _) + + override def length(element: Tree) = + ProductLike.length(c)(element)(elementData) + } + } +} + diff --git 
a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/EitherOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/EitherOrderedBuf.scala new file mode 100644 index 0000000000..541302c2c0 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/EitherOrderedBuf.scala @@ -0,0 +1,180 @@ +/* + Copyright 2015 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import com.twitter.scalding.serialization.OrderedSerialization + +object EitherOrderedBuf { + def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe.erasure =:= c.universe.typeOf[Either[Any, Any]] => EitherOrderedBuf(c)(buildDispatcher, tpe) + } + + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + val dispatcher = buildDispatcher + + val leftType = outerType.asInstanceOf[TypeRefApi].args(0) + val rightType = outerType.asInstanceOf[TypeRefApi].args(1) + val leftBuf: TreeOrderedBuf[c.type] = dispatcher(leftType) + val rightBuf: TreeOrderedBuf[c.type] = dispatcher(rightType) + + def genBinaryCompare(inputStreamA: TermName, inputStreamB: TermName) = { + val valueOfA = freshT("valueOfA") + val valueOfB = freshT("valueOfB") + val tmpHolder = freshT("tmpHolder") + q""" + val $valueOfA = $inputStreamA.readByte + val $valueOfB = $inputStreamB.readByte + val $tmpHolder = _root_.java.lang.Byte.compare($valueOfA, $valueOfB) + if($tmpHolder != 0) { + //they are different, return comparison on type + $tmpHolder + } else if($valueOfA == (0: _root_.scala.Byte)) { + // they are both Left: + ${leftBuf.compareBinary(inputStreamA, inputStreamB)} + } else { + // they are both Right: + ${rightBuf.compareBinary(inputStreamA, inputStreamB)} + } + """ + } + + def genHashFn(element: TermName) = { + val innerValue = freshT("innerValue") + q""" + if($element.isLeft) { + val $innerValue = $element.left.get + val x = ${leftBuf.hash(innerValue)} + // x * (2^31 - 1) which is a mersenne prime + (x << 31) - x + } + else { + val $innerValue = $element.right.get + // x * (2^19 - 1) which is a mersenne prime + val x = ${rightBuf.hash(innerValue)} + (x << 19) - x + } + """ + } + + def genGetFn(inputStreamA: TermName) = { + val tmpGetHolder = freshT("tmpGetHolder") + q""" + val $tmpGetHolder = 
$inputStreamA.readByte + if($tmpGetHolder == (0: _root_.scala.Byte)) Left(${leftBuf.get(inputStreamA)}) + else Right(${rightBuf.get(inputStreamA)}) + """ + } + + def genPutFn(inputStream: TermName, element: TermName) = { + val tmpPutVal = freshT("tmpPutVal") + val innerValue = freshT("innerValue") + q""" + if($element.isRight) { + $inputStream.writeByte(1: _root_.scala.Byte) + val $innerValue = $element.right.get + ${rightBuf.put(inputStream, innerValue)} + } else { + $inputStream.writeByte(0: _root_.scala.Byte) + val $innerValue = $element.left.get + ${leftBuf.put(inputStream, innerValue)} + } + """ + } + + def genCompareFn(elementA: TermName, elementB: TermName) = { + val aIsRight = freshT("aIsRight") + val bIsRight = freshT("bIsRight") + val innerValueA = freshT("innerValueA") + val innerValueB = freshT("innerValueB") + q""" + val $aIsRight = $elementA.isRight + val $bIsRight = $elementB.isRight + if(!$aIsRight) { + if (!$bIsRight) { + val $innerValueA = $elementA.left.get + val $innerValueB = $elementB.left.get + ${leftBuf.compare(innerValueA, innerValueB)} + } + else -1 // Left(_) < Right(_) + } + else { + if(!$bIsRight) 1 // Right(_) > Left(_) + else { // both are right + val $innerValueA = $elementA.right.get + val $innerValueB = $elementB.right.get + ${rightBuf.compare(innerValueA, innerValueB)} + } + } + """ + } + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: TermName, inputStreamB: TermName) = genBinaryCompare(inputStreamA, inputStreamB) + override def hash(element: TermName): ctx.Tree = genHashFn(element) + override def put(inputStream: TermName, element: TermName) = genPutFn(inputStream, element) + override def get(inputStreamA: TermName): ctx.Tree = genGetFn(inputStreamA) + override def compare(elementA: TermName, elementB: TermName): ctx.Tree = genCompareFn(elementA, elementB) + override val lazyOuterVariables: Map[String, ctx.Tree] = + rightBuf.lazyOuterVariables ++ leftBuf.lazyOuterVariables + override def length(element: Tree): CompileTimeLengthTypes[c.type] = { + + def tree(ctl: CompileTimeLengthTypes[_]): c.Tree = ctl.asInstanceOf[CompileTimeLengthTypes[c.type]].t + val dyn = q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen""" + + (leftBuf.length(q"$element.left.get"), rightBuf.length(q"$element.right.get")) match { + case (lconst: ConstantLengthCalculation[_], rconst: ConstantLengthCalculation[_]) if lconst.toInt == rconst.toInt => + // We got lucky, they are the same size: + ConstantLengthCalculation(c)(1 + rconst.toInt) + case (_: NoLengthCalculationAvailable[_], _) => NoLengthCalculationAvailable(c) + case (_, _: NoLengthCalculationAvailable[_]) => NoLengthCalculationAvailable(c) + case (left: MaybeLengthCalculation[_], right: MaybeLengthCalculation[_]) => + MaybeLengthCalculation(c)(q""" + if ($element.isLeft) { ${tree(left)} + $dyn(1) } + else { ${tree(right)} + $dyn(1) } + """) + case (left: MaybeLengthCalculation[_], right) => + MaybeLengthCalculation(c)(q""" + if ($element.isLeft) { ${tree(left)} + $dyn(1) } + else { $dyn(${tree(right)}) + $dyn(1) } + """) + case (left, right: MaybeLengthCalculation[_]) => + MaybeLengthCalculation(c)(q""" + if ($element.isLeft) { $dyn(${tree(left)}) + $dyn(1) } + else { ${tree(right)} + $dyn(1) } + """) + // Rest are constant, but different values or fast. So the result is fast + case (left, right) => + // They are different sizes. 
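// The Either codec above is a one-byte tag (0 = Left, 1 = Right) followed by the
// payload, and the comparator orders every Left before every Right. The same
// ordering written directly (a sketch of the semantics, not the generated code):
object EitherOrderSketch {
  def compareEither[A, B](x: Either[A, B], y: Either[A, B])(
    implicit ordA: Ordering[A], ordB: Ordering[B]): Int =
    (x, y) match {
      case (Left(a1), Left(a2)) => ordA.compare(a1, a2)
      case (Right(b1), Right(b2)) => ordB.compare(b1, b2)
      case (Left(_), Right(_)) => -1 // Left(_) < Right(_)
      case (Right(_), Left(_)) => 1
    }
}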
:( + FastLengthCalculation(c)(q""" + if($element.isLeft) { 1 + ${tree(left)} } + else { 1 + ${tree(right)} } + """) + } + } + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ImplicitOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ImplicitOrderedBuf.scala new file mode 100644 index 0000000000..056d7e79b3 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ImplicitOrderedBuf.scala @@ -0,0 +1,86 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.OrderedSerialization +import com.twitter.scalding.serialization.macros.impl.ordered_serialization._ + +/* + A fall back ordered bufferable to look for the user to have an implicit in scope to satisfy the missing + type. This is for the case where its an opaque class to our macros where we can't figure out the fields +*/ +object ImplicitOrderedBuf { + val macroMarker = "MACROASKEDORDEREDSER" + + def dispatch(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + import c.universe._ + + val pf: PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if !tpe.toString.contains(macroMarker) => ImplicitOrderedBuf(c)(tpe) + } + pf + } + + def apply(c: Context)(outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + + val variableID = (outerType.typeSymbol.fullName.hashCode.toLong + Int.MaxValue.toLong).toString + val variableNameStr = s"orderedSer_$variableID" + val variableName = newTermName(variableNameStr) + val typeAlias = newTypeName(c.fresh("MACROASKEDORDEREDSER")) + val implicitInstanciator = q""" + type $typeAlias = $outerType + implicitly[_root_.com.twitter.scalding.serialization.OrderedSerialization[$typeAlias]]""" + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = + q"$variableName.compareBinary($inputStreamA, $inputStreamB).unsafeToInt" + override def hash(element: ctx.TermName): ctx.Tree = q"$variableName.hash($element)" + + override def put(inputStream: ctx.TermName, element: ctx.TermName) = + q"$variableName.write($inputStream, $element)" + + override def length(element: Tree) = + CompileTimeLengthTypes.MaybeLengthCalculation(c)(q""" + ($variableName.staticSize match { + case Some(s) => _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(s) + case None => + $variableName.dynamicSize($element) match { + case Some(s) => + 
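// The fallback above splices `implicitly[OrderedSerialization[T]]` into the generated
// code, so a type the macro cannot take apart still works if the user brings their
// own instance into scope. A sketch of what that requires at a call site (OpaqueId
// and its instance are hypothetical, not from this patch):
object ImplicitFallbackSketch {
  import com.twitter.scalding.serialization.OrderedSerialization

  final class OpaqueId(val bytes: Array[Byte])

  // user-supplied; any macro-generated serializer for a type containing OpaqueId
  // will defer to this value through the lazily bound orderedSer_* variable above
  implicit def opaqueIdOrdSer: OrderedSerialization[OpaqueId] =
    sys.error("sketch only: supply a real OrderedSerialization[OpaqueId] here")
}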
_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(s) + case None => + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation + } + }): _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength + """) + + override def get(inputStream: ctx.TermName): ctx.Tree = + q"$variableName.read($inputStream).get" + + override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = + q"$variableName.compare($elementA, $elementB)" + override val lazyOuterVariables = Map(variableNameStr -> implicitInstanciator) + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/OptionOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/OptionOrderedBuf.scala new file mode 100644 index 0000000000..7d2c1403b8 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/OptionOrderedBuf.scala @@ -0,0 +1,148 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import com.twitter.scalding.serialization.OrderedSerialization + +object OptionOrderedBuf { + def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe.erasure =:= c.universe.typeOf[Option[Any]] => OptionOrderedBuf(c)(buildDispatcher, tpe) + } + + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + val dispatcher = buildDispatcher + + val innerType = outerType.asInstanceOf[TypeRefApi].args.head + val innerBuf: TreeOrderedBuf[c.type] = dispatcher(innerType) + + def genBinaryCompare(inputStreamA: TermName, inputStreamB: TermName) = { + val valueOfA = freshT("valueOfA") + val valueOfB = freshT("valueOfB") + val tmpHolder = freshT("tmpHolder") + q""" + val $valueOfA = $inputStreamA.readByte + val $valueOfB = $inputStreamB.readByte + val $tmpHolder = _root_.java.lang.Byte.compare($valueOfA, $valueOfB) + if($tmpHolder != 0 || $valueOfA == (0: _root_.scala.Byte)) { + //either one is defined (different), or both are None (equal) + $tmpHolder + } else { + ${innerBuf.compareBinary(inputStreamA, inputStreamB)} + } + """ + } + + def genHashFn(element: TermName) = { + val innerValue = freshT("innerValue") + q""" + if($element.isEmpty) + 0 + else { + val $innerValue = $element.get + ${innerBuf.hash(innerValue)} + } + """ + } + + def genGetFn(inputStreamA: TermName) = { + val tmpGetHolder = freshT("tmpGetHolder") + q""" + val $tmpGetHolder = $inputStreamA.readByte + if($tmpGetHolder == (0: _root_.scala.Byte)) None + else Some(${innerBuf.get(inputStreamA)}) + """ + } + + def genPutFn(inputStream: TermName, element: TermName) = { + val tmpPutVal = freshT("tmpPutVal") + val innerValue = freshT("innerValue") + q""" + if($element.isDefined) { + $inputStream.writeByte(1: _root_.scala.Byte) + val $innerValue = $element.get + ${innerBuf.put(inputStream, innerValue)} + } else { + $inputStream.writeByte(0: _root_.scala.Byte) + } + """ + } + + def genCompareFn(elementA: TermName, elementB: TermName) = { + val aIsDefined = freshT("aIsDefined") + val bIsDefined = freshT("bIsDefined") + val innerValueA = freshT("innerValueA") + val innerValueB = freshT("innerValueB") + q""" + val $aIsDefined = $elementA.isDefined + val $bIsDefined = $elementB.isDefined + if(!$aIsDefined) { + if (!$bIsDefined) 0 // None == None + else -1 // None < Some(_) + } + else { + if(!$bIsDefined) 1 // Some > None + else { // both are defined + val $innerValueA = $elementA.get + val $innerValueB = $elementB.get + ${innerBuf.compare(innerValueA, innerValueB)} + } + } + """ + } + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: TermName, inputStreamB: TermName) = genBinaryCompare(inputStreamA, inputStreamB) + override def hash(element: TermName): ctx.Tree = genHashFn(element) + override def put(inputStream: TermName, element: TermName) = genPutFn(inputStream, element) + override def get(inputStreamA: TermName): ctx.Tree = 
genGetFn(inputStreamA) + override def compare(elementA: TermName, elementB: TermName): ctx.Tree = genCompareFn(elementA, elementB) + override val lazyOuterVariables: Map[String, ctx.Tree] = innerBuf.lazyOuterVariables + override def length(element: Tree): CompileTimeLengthTypes[c.type] = { + innerBuf.length(q"$element.get") match { + case const: ConstantLengthCalculation[_] => FastLengthCalculation(c)(q""" + if($element.isDefined) { 1 + ${const.toInt} } + else { 1 } + """) + case f: FastLengthCalculation[_] => + val t = f.asInstanceOf[FastLengthCalculation[c.type]].t + FastLengthCalculation(c)(q""" + if($element.isDefined) { 1 + $t } + else { 1 } + """) + case m: MaybeLengthCalculation[_] => + val t = m.asInstanceOf[MaybeLengthCalculation[c.type]].t + val dynlen = q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen""" + MaybeLengthCalculation(c)(q""" + if ($element.isDefined) { $t + $dynlen(1) } + else { $dynlen(1) } + """) + case _ => NoLengthCalculationAvailable(c) + } + } + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/PrimitiveOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/PrimitiveOrderedBuf.scala new file mode 100644 index 0000000000..41b7584763 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/PrimitiveOrderedBuf.scala @@ -0,0 +1,116 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import java.nio.ByteBuffer +import com.twitter.scalding.serialization.OrderedSerialization + +object PrimitiveOrderedBuf { + def dispatch(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe =:= c.universe.typeOf[Boolean] => + PrimitiveOrderedBuf(c)(tpe, "Boolean", 1, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Boolean] => + PrimitiveOrderedBuf(c)(tpe, "Boolean", 1, true) + case tpe if tpe =:= c.universe.typeOf[Byte] => + PrimitiveOrderedBuf(c)(tpe, "Byte", 1, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Byte] => + PrimitiveOrderedBuf(c)(tpe, "Byte", 1, true) + case tpe if tpe =:= c.universe.typeOf[Short] => + PrimitiveOrderedBuf(c)(tpe, "Short", 2, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Short] => + PrimitiveOrderedBuf(c)(tpe, "Short", 2, true) + case tpe if tpe =:= c.universe.typeOf[Char] => + PrimitiveOrderedBuf(c)(tpe, "Character", 2, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Character] => + PrimitiveOrderedBuf(c)(tpe, "Character", 2, true) + case tpe if tpe =:= c.universe.typeOf[Int] => + PrimitiveOrderedBuf(c)(tpe, "Integer", 4, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Integer] => + PrimitiveOrderedBuf(c)(tpe, "Integer", 4, true) + case tpe if tpe =:= c.universe.typeOf[Long] => + PrimitiveOrderedBuf(c)(tpe, "Long", 8, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Long] => + PrimitiveOrderedBuf(c)(tpe, "Long", 8, true) + case tpe if tpe =:= c.universe.typeOf[Float] => + PrimitiveOrderedBuf(c)(tpe, "Float", 4, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Float] => + PrimitiveOrderedBuf(c)(tpe, "Float", 4, true) + case tpe if tpe =:= c.universe.typeOf[Double] => + PrimitiveOrderedBuf(c)(tpe, "Double", 8, false) + case tpe if tpe =:= c.universe.typeOf[java.lang.Double] => + PrimitiveOrderedBuf(c)(tpe, "Double", 8, true) + } + + def apply(c: Context)(outerType: c.Type, + javaTypeStr: String, + lenInBytes: Int, + boxed: Boolean): TreeOrderedBuf[c.type] = { + import c.universe._ + val javaType = newTermName(javaTypeStr) + + def freshT(id: String) = newTermName(c.fresh(s"fresh_$id")) + + val shortName: String = Map("Integer" -> "Int", "Character" -> "Char") + .getOrElse(javaTypeStr, javaTypeStr) + + val bbGetter = newTermName("read" + shortName) + val bbPutter = newTermName("write" + shortName) + + def genBinaryCompare(inputStreamA: TermName, inputStreamB: TermName): Tree = + q"""_root_.java.lang.$javaType.compare($inputStreamA.$bbGetter, $inputStreamB.$bbGetter)""" + + def accessor(e: c.TermName): c.Tree = { + val primitiveAccessor = newTermName(shortName.toLowerCase + "Value") + if (boxed) q"$e.$primitiveAccessor" + else q"$e" + } + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = + genBinaryCompare(inputStreamA, inputStreamB) + override def hash(element: ctx.TermName): ctx.Tree = { + // This calls out the correctly named item in Hasher + val typeLowerCase = newTermName(javaTypeStr.toLowerCase) + 
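// For boxed primitives the accessor defined above unboxes first, e.g. a
// java.lang.Integer element becomes element.intValue before being passed
// to the Hasher call generated below.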
q"_root_.com.twitter.scalding.serialization.Hasher.$typeLowerCase.hash(${accessor(element)})" + } + override def put(inputStream: ctx.TermName, element: ctx.TermName) = + q"$inputStream.$bbPutter(${accessor(element)})" + + override def get(inputStream: ctx.TermName): ctx.Tree = { + val unboxed = q"$inputStream.$bbGetter" + if (boxed) q"_root_.java.lang.$javaType.valueOf($unboxed)" else unboxed + } + + override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = + if (boxed) q"""$elementA.compareTo($elementB)""" + else q"""_root_.java.lang.$javaType.compare($elementA, $elementB)""" + + override def length(element: Tree): CompileTimeLengthTypes[c.type] = + ConstantLengthCalculation(c)(lenInBytes) + + override val lazyOuterVariables: Map[String, ctx.Tree] = Map.empty + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ProductOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ProductOrderedBuf.scala new file mode 100644 index 0000000000..4c534bce11 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ProductOrderedBuf.scala @@ -0,0 +1,123 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import java.nio.ByteBuffer +import com.twitter.scalding.serialization.OrderedSerialization + +object ProductOrderedBuf { + def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + import c.universe._ + val validTypes: List[Type] = List(typeOf[Product1[Any]], + typeOf[Product2[Any, Any]], + typeOf[Product3[Any, Any, Any]], + typeOf[Product4[Any, Any, Any, Any]], + typeOf[Product5[Any, Any, Any, Any, Any]], + typeOf[Product6[Any, Any, Any, Any, Any, Any]], + typeOf[Product7[Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product8[Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product9[Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[Product22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]) + + def validType(curType: Type): Boolean = { + validTypes.find{ t => curType <:< t }.isDefined + } + + def symbolFor(subType: Type): Type = { + val superType = validTypes.find{ t => subType.erasure <:< t }.get + subType.baseType(superType.typeSymbol) + } + + val pf: PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if validType(tpe.erasure) => ProductOrderedBuf(c)(buildDispatcher, tpe, symbolFor(tpe)) + } + pf + } + + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], originalType: c.Type, outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(id)) + + val dispatcher = buildDispatcher + val elementData: List[(c.universe.Type, TermName, TreeOrderedBuf[c.type])] = + outerType + .declarations + .collect { case m: MethodSymbol => m } + .filter(m => m.name.toTermName.toString.startsWith("_")) + .map { accessorMethod => + val fieldType = accessorMethod.returnType.asSeenFrom(outerType, outerType.typeSymbol.asClass) + val b: TreeOrderedBuf[c.type] = dispatcher(fieldType) + (fieldType, 
accessorMethod.name.toTermName, b) + }.toList + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = + ProductLike.compareBinary(c)(inputStreamA, inputStreamB)(elementData) + + override def hash(element: ctx.TermName): ctx.Tree = ProductLike.hash(c)(element)(elementData) + + override def put(inputStream: ctx.TermName, element: ctx.TermName) = + ProductLike.put(c)(inputStream, element)(elementData) + + override def get(inputStream: ctx.TermName): ctx.Tree = { + + val getValProcessor = elementData.map { + case (tpe, accessorSymbol, tBuf) => + val curR = freshT("curR") + val builderTree = q""" + val $curR = { + ${tBuf.get(inputStream)} + } + """ + (builderTree, curR) + } + q""" + ..${getValProcessor.map(_._1)} + new ${originalType}(..${getValProcessor.map(_._2)}) + """ + } + override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = + ProductLike.compare(c)(elementA, elementB)(elementData) + + override val lazyOuterVariables: Map[String, ctx.Tree] = + elementData.map(_._3.lazyOuterVariables).reduce(_ ++ _) + + override def length(element: Tree) = + ProductLike.length(c)(element)(elementData) + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StringOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StringOrderedBuf.scala new file mode 100644 index 0000000000..ac00d69680 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StringOrderedBuf.scala @@ -0,0 +1,126 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import java.nio.ByteBuffer +import com.twitter.scalding.serialization.OrderedSerialization + +object StringOrderedBuf { + def dispatch(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe =:= c.universe.typeOf[String] => StringOrderedBuf(c)(tpe) + } + + def apply(c: Context)(outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + + def freshT(id: String) = newTermName(c.fresh(id)) + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = { + val lenA = freshT("lenA") + val lenB = freshT("lenB") + + q""" + val $lenA = $inputStreamA.readPosVarInt + val $lenB = $inputStreamB.readPosVarInt + _root_.com.twitter.scalding.serialization.StringOrderedSerialization.binaryIntCompare($lenA, + $inputStreamA, + $lenB, + $inputStreamB) + """ + } + + override def hash(element: ctx.TermName): ctx.Tree = q"_root_.com.twitter.scalding.serialization.Hasher.string.hash($element)" + + override def put(inputStream: ctx.TermName, element: ctx.TermName) = { + val bytes = freshT("bytes") + val charLen = freshT("charLen") + val len = freshT("len") + q""" + // Ascii is very common, so if the string is short, + // we check if it is ascii: + def isShortAscii(size: Int, str: String): Boolean = (size < 65) && { + var pos = 0 + var ascii: Boolean = true + while((pos < size) && ascii) { + ascii = (str.charAt(pos) < 128) + pos += 1 + } + ascii + } + + val $charLen = $element.length + if ($charLen == 0) { + $inputStream.writePosVarInt(0) + } + else if (isShortAscii($charLen, $element)) { + $inputStream.writePosVarInt($charLen) + val $bytes = new Array[Byte]($charLen) + // This deprecated gets ascii bytes out, but is incorrect + // for non-ascii data. + $element.getBytes(0, $charLen, $bytes, 0) + $inputStream.write($bytes) + } + else { + // Just use utf-8 + // TODO: investigate faster ways to encode UTF-8, if + // the bug that makes string Charsets faster than using Charset instances. 
+ // see for instance: + // http://psy-lob-saw.blogspot.com/2012/12/encode-utf-8-string-to-bytebuffer-faster.html + val $bytes = $element.getBytes("UTF-8") + val $len = $bytes.length + $inputStream.writePosVarInt($len) + $inputStream.write($bytes) + } + """ + } + override def get(inputStream: ctx.TermName): ctx.Tree = { + val len = freshT("len") + val strBytes = freshT("strBytes") + q""" + val $len = $inputStream.readPosVarInt + if($len > 0) { + val $strBytes = new Array[Byte]($len) + $inputStream.readFully($strBytes) + new String($strBytes, "UTF-8") + } else { + "" + } + """ + } + override def compare(elementA: ctx.TermName, elementB: ctx.TermName) = + q"""$elementA.compareTo($elementB)""" + + override val lazyOuterVariables: Map[String, ctx.Tree] = Map.empty + override def length(element: Tree): CompileTimeLengthTypes[c.type] = MaybeLengthCalculation(c)(q""" + if($element.isEmpty) { + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(1) + } else { + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation + } + """) + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala new file mode 100644 index 0000000000..b552645cb5 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala @@ -0,0 +1,303 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context +import java.io.InputStream + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import com.twitter.scalding.serialization.OrderedSerialization +import scala.reflect.ClassTag + +import scala.{ collection => sc } +import scala.collection.{ immutable => sci } + +sealed trait ShouldSort +case object DoSort extends ShouldSort +case object NoSort extends ShouldSort + +sealed trait MaybeArray +case object IsArray extends MaybeArray +case object NotArray extends MaybeArray + +object TraversablesOrderedBuf { + def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe.erasure =:= c.universe.typeOf[Iterable[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Iterable[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[List[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.List[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[Seq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sc.Seq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Seq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[Vector[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Vector[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[IndexedSeq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.IndexedSeq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Queue[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + // Arrays are special in that the erasure doesn't do anything + case tpe if tpe.typeSymbol == c.universe.typeOf[Array[Any]].typeSymbol => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, IsArray) + // The erasure of a non-covariant is Set[_], so we need that here for sets + case tpe if tpe.erasure =:= c.universe.typeOf[Set[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sc.Set[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Set[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.HashSet[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.ListSet[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + + case tpe 
if tpe.erasure =:= c.universe.typeOf[Map[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sc.Map[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Map[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.HashMap[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.ListMap[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + } + + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], + outerType: c.Type, + maybeSort: ShouldSort, + maybeArray: MaybeArray): TreeOrderedBuf[c.type] = { + + import c.universe._ + def freshT(id: String) = newTermName(c.fresh(s"fresh_$id")) + + val dispatcher = buildDispatcher + + val companionSymbol = outerType.typeSymbol.companionSymbol + + // When dealing with a map we have 2 type args, and need to generate the tuple type + // it would correspond to if we .toList the Map. + val innerType = if (outerType.asInstanceOf[TypeRefApi].args.size == 2) { + val (tpe1, tpe2) = (outerType.asInstanceOf[TypeRefApi].args(0), outerType.asInstanceOf[TypeRefApi].args(1)) + val containerType = typeOf[Tuple2[Any, Any]].asInstanceOf[TypeRef] + import compat._ + TypeRef.apply(containerType.pre, containerType.sym, List(tpe1, tpe2)) + } else { + outerType.asInstanceOf[TypeRefApi].args.head + } + + val innerTypes = outerType.asInstanceOf[TypeRefApi].args + + val innerBuf: TreeOrderedBuf[c.type] = dispatcher(innerType) + // TODO it would be nice to capture one instance of this rather + // than allocate in every call in the materialized class + val ioa = freshT("ioa") + val iob = freshT("iob") + val innerOrd = q""" + new _root_.scala.math.Ordering[${innerBuf.tpe}] { + def compare(a: ${innerBuf.tpe}, b: ${innerBuf.tpe}) = { + val $ioa = a + val $iob = b + ${innerBuf.compare(ioa, iob)} + } + } + """ + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = { + val innerCompareFn = freshT("innerCompareFn") + val a = freshT("a") + val b = freshT("b") + q""" + val $innerCompareFn = { (a: _root_.java.io.InputStream, b: _root_.java.io.InputStream) => + val $a = a + val $b = b + ${innerBuf.compareBinary(a, b)} + }; + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.TraversableHelpers.rawCompare($inputStreamA, $inputStreamB)($innerCompareFn) + """ + } + + override def put(inputStream: ctx.TermName, element: ctx.TermName) = { + val asArray = freshT("asArray") + val bytes = freshT("bytes") + val len = freshT("len") + val pos = freshT("pos") + val innerElement = freshT("innerElement") + val cmpRes = freshT("cmpRes") + + maybeSort match { + case DoSort => + q""" + val $len = $element.size + $inputStream.writePosVarInt($len) + + if($len > 0) { + val $asArray = $element.toArray[${innerBuf.tpe}] + // Sorting on the in-memory is the same as binary + _root_.scala.util.Sorting.quickSort[${innerBuf.tpe}]($asArray)($innerOrd) + var $pos = 0 + while($pos < $len) { + val $innerElement = $asArray($pos) + ${innerBuf.put(inputStream, innerElement)} + $pos += 1 + } + } + """ + case NoSort => + q""" + val $len: Int = $element.size + 
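// Unsorted collections (List, Vector, Seq, ...) are written in iteration order:
// the element count as a positive var-int, then each element in turn.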
$inputStream.writePosVarInt($len) + $element.foreach { case $innerElement => + ${innerBuf.put(inputStream, innerElement)} + } + """ + } + + } + override def hash(element: ctx.TermName): ctx.Tree = { + val currentHash = freshT("currentHash") + val len = freshT("len") + val target = freshT("target") + maybeSort match { + case NoSort => + q""" + var $currentHash: Int = _root_.com.twitter.scalding.serialization.MurmurHashUtils.seed + var $len = 0 + $element.foreach { t => + val $target = t + $currentHash = + _root_.com.twitter.scalding.serialization.MurmurHashUtils.mixH1($currentHash, ${innerBuf.hash(target)}) + // go ahead and compute the length so we don't traverse twice for lists + $len += 1 + } + _root_.com.twitter.scalding.serialization.MurmurHashUtils.fmix($currentHash, $len) + """ + case DoSort => + // We actually don't sort here, which would be expensive, but combine with a commutative operation + // so the order that we see items won't matter. For this we use XOR + q""" + var $currentHash: Int = _root_.com.twitter.scalding.serialization.MurmurHashUtils.seed + var $len = 0 + $element.foreach { t => + val $target = t + $currentHash = $currentHash ^ ${innerBuf.hash(target)} + $len += 1 + } + // Might as well be fancy when we mix in the length + _root_.com.twitter.scalding.serialization.MurmurHashUtils.fmix($currentHash, $len) + """ + } + } + + override def get(inputStream: ctx.TermName): ctx.Tree = { + val len = freshT("len") + val firstVal = freshT("firstVal") + val travBuilder = freshT("travBuilder") + val iter = freshT("iter") + val extractionTree = maybeArray match { + case IsArray => + q"""val $travBuilder = new Array[..$innerTypes]($len) + var $iter = 0 + while($iter < $len) { + $travBuilder($iter) = ${innerBuf.get(inputStream)} + $iter = $iter + 1 + } + $travBuilder : $outerType + """ + case NotArray => + q"""val $travBuilder = $companionSymbol.newBuilder[..$innerTypes] + $travBuilder.sizeHint($len) + var $iter = 0 + while($iter < $len) { + $travBuilder += ${innerBuf.get(inputStream)} + $iter = $iter + 1 + } + $travBuilder.result : $outerType + """ + } + q""" + val $len: Int = $inputStream.readPosVarInt + if($len > 0) { + if($len == 1) { + val $firstVal: $innerType = ${innerBuf.get(inputStream)} + $companionSymbol.apply($firstVal) : $outerType + } else { + $extractionTree : $outerType + } + } else { + $companionSymbol.empty : $outerType + } + """ + } + + override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = { + + val a = freshT("a") + val b = freshT("b") + val cmpFnName = freshT("cmpFnName") + maybeSort match { + case DoSort => + q""" + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.TraversableHelpers.sortedCompare[${innerBuf.tpe}]($elementA, $elementB)($innerOrd) + """ + + case NoSort => + q""" + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.TraversableHelpers.iteratorCompare[${innerBuf.tpe}]($elementA.iterator, $elementB.iterator)($innerOrd) + """ + } + + } + + override val lazyOuterVariables: Map[String, ctx.Tree] = innerBuf.lazyOuterVariables + + override def length(element: Tree): CompileTimeLengthTypes[c.type] = { + + innerBuf.length(q"$element.head") match { + case const: ConstantLengthCalculation[_] => + FastLengthCalculation(c)(q"""{ + posVarIntSize($element.size) + $element.size * ${const.toInt} + }""") + case m: MaybeLengthCalculation[_] => + val maybeRes = freshT("maybeRes") + MaybeLengthCalculation(c)(q""" + if($element.isEmpty) { + val sizeOfZero = 1 
// writing the constant 0, for length, takes 1 byte + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(sizeOfZero) + } else { + val maybeRes = ${m.asInstanceOf[MaybeLengthCalculation[c.type]].t} + maybeRes match { + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(constSize) => + val sizeOverhead = posVarIntSize($element.size) + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(constSize * $element.size + sizeOverhead) + + // todo maybe we should support this case + // where we can visit every member of the list relatively fast to ask + // its length. Should we care about sizes instead maybe? + case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(_) => + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation + case _ => _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation + } + } + """) + // Something we can't workout the size of ahead of time + case _ => MaybeLengthCalculation(c)(q""" + if($element.isEmpty) { + val sizeOfZero = 1 // writing the constant 0, for length, takes 1 byte + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(sizeOfZero) + } else { + _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation + } + """) + } + } + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/UnitOrderedBuf.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/UnitOrderedBuf.scala new file mode 100644 index 0000000000..e0cedb05a9 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/UnitOrderedBuf.scala @@ -0,0 +1,62 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers + +import scala.language.experimental.macros +import scala.reflect.macros.Context + +import com.twitter.scalding._ +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import CompileTimeLengthTypes._ +import java.nio.ByteBuffer +import com.twitter.scalding.serialization.OrderedSerialization + +object UnitOrderedBuf { + def dispatch(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if tpe =:= c.universe.typeOf[Unit] => UnitOrderedBuf(c)(tpe) + } + + def apply(c: Context)(outerType: c.Type): TreeOrderedBuf[c.type] = { + import c.universe._ + + new TreeOrderedBuf[c.type] { + override val ctx: c.type = c + override val tpe = outerType + + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = + q"0" + + override def hash(element: ctx.TermName): ctx.Tree = + q"0" + + override def put(inputStream: ctx.TermName, element: ctx.TermName) = + q"()" + + override def get(inputStreamA: ctx.TermName): ctx.Tree = + q"()" + + def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = + q"0" + + override def length(element: Tree): CompileTimeLengthTypes[c.type] = + ConstantLengthCalculation(c)(0) + + override val lazyOuterVariables: Map[String, ctx.Tree] = + Map.empty + } + } +} + diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/LengthCalculations.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/LengthCalculations.scala new file mode 100644 index 0000000000..137c82060b --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/LengthCalculations.scala @@ -0,0 +1,42 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers + +/** + * There is a Monoid on MaybeLength, with + * ConstLen(0) being the zero. 
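 *
 * An illustrative sketch (values made up) of how the three cases defined below
 * combine when per-field lengths are summed:
 * {{{
 * ConstLen(4) + ConstLen(8)            // ConstLen(12): total still known statically
 * ConstLen(4) + DynamicLen(3)          // DynamicLen(7): known only for this particular value
 * DynamicLen(7) + NoLengthCalculation  // NoLengthCalculation: one unknown poisons the sum
 * }}}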
+ */ +sealed trait MaybeLength { + def +(that: MaybeLength): MaybeLength +} + +case object NoLengthCalculation extends MaybeLength { + def +(that: MaybeLength): MaybeLength = this +} +case class ConstLen(toInt: Int) extends MaybeLength { + def +(that: MaybeLength): MaybeLength = that match { + case ConstLen(c) => ConstLen(toInt + c) + case DynamicLen(d) => DynamicLen(toInt + d) + case NoLengthCalculation => NoLengthCalculation + } +} +case class DynamicLen(toInt: Int) extends MaybeLength { + def +(that: MaybeLength): MaybeLength = that match { + case ConstLen(c) => DynamicLen(toInt + c) + case DynamicLen(d) => DynamicLen(toInt + d) + case NoLengthCalculation => NoLengthCalculation + } +} diff --git a/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/TraversableHelpers.scala b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/TraversableHelpers.scala new file mode 100644 index 0000000000..8e9961e9f4 --- /dev/null +++ b/scalding-serialization-macros/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/TraversableHelpers.scala @@ -0,0 +1,174 @@ +/* + Copyright 2014 Twitter, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers + +import java.io.InputStream +import scala.collection.mutable.Buffer + +object TraversableHelpers { + import com.twitter.scalding.serialization.JavaStreamEnrichments._ + + final def rawCompare(inputStreamA: InputStream, inputStreamB: InputStream)(consume: (InputStream, InputStream) => Int): Int = { + val lenA = inputStreamA.readPosVarInt + val lenB = inputStreamB.readPosVarInt + + val minLen = math.min(lenA, lenB) + var incr = 0 + var curIncr = 0 + while (incr < minLen && curIncr == 0) { + curIncr = consume(inputStreamA, inputStreamB) + incr = incr + 1 + } + + if (curIncr != 0) curIncr + else java.lang.Integer.compare(lenA, lenB) + } + + final def iteratorCompare[T](iteratorA: Iterator[T], iteratorB: Iterator[T])(implicit ord: Ordering[T]): Int = { + @annotation.tailrec + def result: Int = + if (iteratorA.isEmpty) { + if (iteratorB.isEmpty) 0 + else -1 // a is shorter + } else { + if (iteratorB.isEmpty) 1 // a is longer + else { + val cmp = ord.compare(iteratorA.next, iteratorB.next) + if (cmp != 0) cmp + else result + } + } + + result + } + + final def iteratorEquiv[T](iteratorA: Iterator[T], iteratorB: Iterator[T])(implicit eq: Equiv[T]): Boolean = { + @annotation.tailrec + def result: Boolean = + if (iteratorA.isEmpty) iteratorB.isEmpty + else if (iteratorB.isEmpty) false // not empty != empty + else eq.equiv(iteratorA.next, iteratorB.next) && result + + result + } + /** + * This returns the same result as + * + * implicit val o = ord + * Ordering[Iterable[T]].compare(travA.toList.sorted, travB.toList.sorted) + * + * but it does not do a full sort. 
Instead it uses a partial quicksort approach + * the complexity should be O(N + M) rather than O(N log N + M log M) for the full + * sort case + */ + final def sortedCompare[T](travA: Iterable[T], travB: Iterable[T])(implicit ord: Ordering[T]): Int = { + def compare(startA: Int, endA: Int, a: Buffer[T], startB: Int, endB: Int, b: Buffer[T]): Int = + if (startA == endA) { + if (startB == endB) 0 // both empty + else -1 // empty is smaller than non-empty + } else if (startB == endB) 1 // non-empty is bigger than empty + else { + @annotation.tailrec + def partition(pivot: T, pivotStart: Int, pivotEnd: Int, endX: Int, x: Buffer[T]): (Int, Int) = { + if (pivotEnd >= endX) (pivotStart, pivotEnd) + else { + val t = x(pivotEnd) + val cmp = ord.compare(t, pivot) + if (cmp == 0) { + // the pivot section grows by 1 to include test + partition(pivot, pivotStart, pivotEnd + 1, endX, x) + } else if (cmp > 0) { + // test is bigger, swap it with the end and move the end down: + val newEnd = endX - 1 + val end = x(newEnd) + x(newEnd) = t + x(pivotEnd) = end + // now try again: + partition(pivot, pivotStart, pivotEnd, newEnd, x) + } else { + // t < pivot so we need to push this value below the pivots: + val ps = x(pivotStart) // might not be pivot if the pivot size is 0 + x(pivotStart) = t + x(pivotEnd) = ps + partition(pivot, pivotStart + 1, pivotEnd + 1, endX, x) + } + } + } + val pivot = a(startA) + val (aps, ape) = partition(pivot, startA, startA + 1, endA, a) + val (bps, bpe) = partition(pivot, startB, startB, endB, b) + + val asublen = aps - startA + val bsublen = bps - startB + if (asublen != bsublen) { + // comparing to the longer is enough + // because one of them will then include pivots which are larger + val longer = math.max(asublen, bsublen) + def extend(s: Int, e: Int) = math.min(s + longer, e) + + if (asublen != 0) { + /* + * We can safely recurse because startA does not hold pivot, so we won't + * do the same algorithm + */ + compare(startA, extend(startA, endA), a, startB, extend(startB, endB), b) + } else { + /* + * We know that startB does not have the pivot, because if it did, bsublen == 0 + * and both are equal, which is not true in this branch. + * we can reverse the recursion to ensure we get a different pivot + */ + -compare(startB, extend(startB, endB), b, startA, extend(startA, endA), a) + } + } else { + // the prefixes are the same size + val cmp = compare(startA, aps, a, startB, bps, b) + if (cmp != 0) cmp + else { + // we need to look into the pivot area. We don't need to check + // for equality on the overlapped pivot range + val apsize = ape - aps + val bpsize = bpe - bps + val minpsize = math.min(apsize, bpsize) + val acheck = aps + minpsize + val bcheck = bps + minpsize + if (apsize != bpsize && + acheck < endA && + bcheck < endB) { + // exactly one of them has a pivot value + ord.compare(a(acheck), b(bcheck)) + } else { + // at least one of them or both is empty, and we pick it up above + compare(aps + minpsize, endA, a, bps + minpsize, endB, b) + } + } + } + } + + /** + * If we are equal unsorted, we are equal. + * this is useful because often scala will build identical sets + * exactly the same way, so this fast check will work. 
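 *
 * An illustrative sketch of the overall contract of sortedCompare (example values,
 * not taken from the tests):
 * {{{
 * sortedCompare(Set(3, 1, 2), Set(2, 3, 1))   // 0: same elements, iteration order ignored
 * sortedCompare(List(2, 1), List(1, 3))       // negative: sorted views are [1, 2] vs [1, 3]
 * }}}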
+ */ + if (iteratorEquiv(travA.iterator, travB.iterator)(ord)) 0 + else { + // Let's do the more expensive, potentially full sort, algorithm + val a = travA.toBuffer + val b = travB.toBuffer + compare(0, a.size, a, 0, b.size, b) + } + } +} diff --git a/scalding-serialization-macros/src/test/scala/com/twitter/scalding/serialization/macros/MacroOrderingProperties.scala b/scalding-serialization-macros/src/test/scala/com/twitter/scalding/serialization/macros/MacroOrderingProperties.scala new file mode 100644 index 0000000000..ee0bc5fc8f --- /dev/null +++ b/scalding-serialization-macros/src/test/scala/com/twitter/scalding/serialization/macros/MacroOrderingProperties.scala @@ -0,0 +1,533 @@ +/* +Copyright 2012 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.serialization.macros + +import org.scalatest.{ FunSuite, ShouldMatchers } +import org.scalatest.prop.Checkers +import org.scalatest.prop.PropertyChecks +import scala.language.experimental.macros +import com.twitter.scalding.serialization.{ OrderedSerialization, Law, Law1, Law2, Law3, Serialization } +import java.nio.ByteBuffer +import org.scalacheck.Arbitrary.{ arbitrary => arb } +import java.io.{ ByteArrayOutputStream, InputStream } + +import org.scalacheck.{ Arbitrary, Gen, Prop } +import com.twitter.scalding.serialization.JavaStreamEnrichments + +import scala.collection.immutable.Queue + +trait LowerPriorityImplicit { + implicit def primitiveOrderedBufferSupplier[T] = macro impl.OrderedSerializationProviderImpl[T] +} + +object LawTester { + def apply[T: Arbitrary](laws: Iterable[Law[T]]): Prop = + apply(implicitly[Arbitrary[T]].arbitrary, laws) + + def apply[T](g: Gen[T], laws: Iterable[Law[T]]): Prop = + laws.foldLeft(true: Prop) { + case (soFar, Law1(name, fn)) => soFar && Prop.forAll(g)(fn).label(name) + case (soFar, Law2(name, fn)) => soFar && Prop.forAll(g, g)(fn).label(name) + case (soFar, Law3(name, fn)) => soFar && Prop.forAll(g, g, g)(fn).label(name) + } +} + +object ByteBufferArb { + implicit def arbitraryTestTypes: Arbitrary[ByteBuffer] = Arbitrary { + for { + aBinary <- Gen.alphaStr.map(s => ByteBuffer.wrap(s.getBytes("UTF-8"))) + } yield aBinary + } +} +object TestCC { + import ByteBufferArb._ + implicit def arbitraryTestCC: Arbitrary[TestCC] = Arbitrary { + for { + aInt <- arb[Int] + aLong <- arb[Long] + aDouble <- arb[Double] + anOption <- arb[Option[Int]] + anStrOption <- arb[Option[String]] + anOptionOfAListOfStrings <- arb[Option[List[String]]] + aBB <- arb[ByteBuffer] + } yield TestCC(aInt, aLong, anOption, aDouble, anStrOption, anOptionOfAListOfStrings, aBB) + } +} +case class TestCC(a: Int, b: Long, c: Option[Int], d: Double, e: Option[String], f: Option[List[String]], aBB: ByteBuffer) + +object MyData { + implicit def arbitraryTestCC: Arbitrary[MyData] = Arbitrary { + for { + aInt <- arb[Int] + anOption <- arb[Option[Long]] + } yield new MyData(aInt, anOption) + } +} + +class MyData(override val _1: Int, override val _2: Option[Long]) extends Product2[Int, Option[Long]] { + override def 
canEqual(that: Any): Boolean = that match { + case o: MyData => true + case _ => false + } +} + +object MacroOpaqueContainer { + def getOrdSer[T]: OrderedSerialization[T] = macro impl.OrderedSerializationProviderImpl[T] + import java.io._ + implicit val myContainerOrderedSerializer = new OrderedSerialization[MacroOpaqueContainer] { + val intOrderedSerialization = getOrdSer[Int] + + override def hash(s: MacroOpaqueContainer) = intOrderedSerialization.hash(s.myField) ^ Int.MaxValue + override def compare(a: MacroOpaqueContainer, b: MacroOpaqueContainer) = intOrderedSerialization.compare(a.myField, b.myField) + + override def read(in: InputStream) = intOrderedSerialization.read(in).map(MacroOpaqueContainer(_)) + + override def write(b: OutputStream, s: MacroOpaqueContainer) = intOrderedSerialization.write(b, s.myField) + + override def compareBinary(lhs: InputStream, rhs: InputStream) = intOrderedSerialization.compareBinary(lhs, rhs) + override val staticSize = Some(4) + + override def dynamicSize(i: MacroOpaqueContainer) = staticSize + } + + implicit def arbitraryMacroOpaqueContainer: Arbitrary[MacroOpaqueContainer] = Arbitrary { + for { + aInt <- arb[Int] + } yield MacroOpaqueContainer(aInt) + } + + def apply(d: Int): MacroOpaqueContainer = new MacroOpaqueContainer(d) +} + +class MacroOpaqueContainer(val myField: Int) { +} + +object Container { + implicit def arbitraryInnerCaseClass: Arbitrary[InnerCaseClass] = Arbitrary { + for { + anOption <- arb[Set[Double]] + } yield InnerCaseClass(anOption) + } + + type SetAlias = Set[Double] + case class InnerCaseClass(e: SetAlias) +} +class MacroOrderingProperties extends FunSuite with PropertyChecks with ShouldMatchers with LowerPriorityImplicit { + type SetAlias = Set[Double] + + import ByteBufferArb._ + import Container.arbitraryInnerCaseClass + + import OrderedSerialization.{ compare => oBufCompare } + + def gen[T: Arbitrary]: Gen[T] = implicitly[Arbitrary[T]].arbitrary + + def arbMap[T: Arbitrary, U](fn: T => U): Arbitrary[U] = Arbitrary(gen[T].map(fn)) + + def collectionArb[C[_], T: Arbitrary](implicit cbf: collection.generic.CanBuildFrom[Nothing, T, C[T]]): Arbitrary[C[T]] = Arbitrary { + gen[List[T]].map { l => + val builder = cbf() + l.foreach { builder += _ } + builder.result + } + } + + def serialize[T](t: T)(implicit orderedBuffer: OrderedSerialization[T]): InputStream = + serializeSeq(List(t)) + + def serializeSeq[T](t: Seq[T])(implicit orderedBuffer: OrderedSerialization[T]): InputStream = { + import JavaStreamEnrichments._ + + val baos = new ByteArrayOutputStream + t.foreach({ e => + orderedBuffer.write(baos, e) + }) + baos.toInputStream + } + + def rt[T](t: T)(implicit orderedBuffer: OrderedSerialization[T]) = { + val buf = serialize[T](t) + orderedBuffer.read(buf).get + } + + def rawCompare[T](a: T, b: T)(implicit obuf: OrderedSerialization[T]): Int = + obuf.compareBinary(serialize(a), serialize(b)).unsafeToInt + + def checkManyExplicit[T](i: List[(T, T)])(implicit obuf: OrderedSerialization[T]) = { + val serializedA = serializeSeq(i.map(_._1)) + val serializedB = serializeSeq(i.map(_._2)) + i.foreach { + case (a, b) => + val compareBinary = obuf.compareBinary(serializedA, serializedB).unsafeToInt + val compareMem = obuf.compare(a, b) + if (compareBinary < 0) { + assert(compareMem < 0, s"Compare binary: $compareBinary, and compareMem : $compareMem must have the same sign") + } else if (compareBinary > 0) { + assert(compareMem > 0, s"Compare binary: $compareBinary, and compareMem : $compareMem must have the same sign") + } + } + } + + 
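// A minimal usage sketch (illustrative only): with the macro-derived
// OrderedSerialization[Int] supplied by LowerPriorityImplicit, a call such as
//   checkManyExplicit(List((1, 2), (2, 2), (3, 2)))
// verifies that the sign of compareBinary over the serialized streams agrees
// with the sign of the in-memory compare for each pair.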
def checkMany[T: Arbitrary](implicit obuf: OrderedSerialization[T]) = forAll { i: List[(T, T)] => + checkManyExplicit(i) + } + + def checkWithInputs[T](a: T, b: T)(implicit obuf: OrderedSerialization[T]) { + val rta = rt(a) // before we do anything ensure these don't throw + val rtb = rt(b) // before we do anything ensure these don't throw + val asize = Serialization.toBytes(a).length + if (obuf.dynamicSize(a).isDefined) { + assert(obuf.dynamicSize(a).get == asize, "dynamic size matches the correct value") + } + if (obuf.staticSize.isDefined) { + assert(obuf.dynamicSize(a).get == asize, "dynamic size matches the correct value") + assert(obuf.staticSize.get == asize, "dynamic size matches the correct value") + } + assert(oBufCompare(rta, a) === 0, s"A should be equal to itself after an RT -- ${rt(a)}") + assert(oBufCompare(rtb, b) === 0, s"B should be equal to itself after an RT-- ${rt(b)}") + assert(oBufCompare(a, b) + oBufCompare(b, a) === 0, "In memory comparasons make sense") + assert(rawCompare(a, b) + rawCompare(b, a) === 0, "When adding the raw compares in inverse order they should sum to 0") + assert(oBufCompare(rta, rtb) === oBufCompare(a, b), "Comparing a and b with ordered bufferables compare after a serialization RT") + } + + def checkAreSame[T](a: T, b: T)(implicit obuf: OrderedSerialization[T]) { + val rta = rt(a) // before we do anything ensure these don't throw + val rtb = rt(b) // before we do anything ensure these don't throw + assert(oBufCompare(rta, a) === 0, s"A should be equal to itself after an RT -- ${rt(a)}") + assert(oBufCompare(rtb, b) === 0, "B should be equal to itself after an RT-- ${rt(b)}") + assert(oBufCompare(a, b) === 0, "In memory comparasons make sense") + assert(oBufCompare(b, a) === 0, "In memory comparasons make sense") + assert(rawCompare(a, b) === 0, "When adding the raw compares in inverse order they should sum to 0") + assert(rawCompare(b, a) === 0, "When adding the raw compares in inverse order they should sum to 0") + assert(oBufCompare(rta, rtb) === 0, "Comparing a and b with ordered bufferables compare after a serialization RT") + } + + def check[T: Arbitrary](implicit obuf: OrderedSerialization[T]) = { + Checkers.check(LawTester(OrderedSerialization.allLaws)) + forAll(minSuccessful(500)) { (a: T, b: T) => checkWithInputs(a, b) } + } + + test("Test out Unit") { + primitiveOrderedBufferSupplier[Unit] + check[Unit] + checkMany[Unit] + } + test("Test out Boolean") { + primitiveOrderedBufferSupplier[Boolean] + check[Boolean] + } + test("Test out jl.Boolean") { + implicit val a = arbMap { b: Boolean => java.lang.Boolean.valueOf(b) } + check[java.lang.Boolean] + } + test("Test out Byte") { check[Byte] } + test("Test out jl.Byte") { + implicit val a = arbMap { b: Byte => java.lang.Byte.valueOf(b) } + check[java.lang.Byte] + } + test("Test out Short") { check[Short] } + test("Test out jl.Short") { + implicit val a = arbMap { b: Short => java.lang.Short.valueOf(b) } + check[java.lang.Short] + } + test("Test out Char") { check[Char] } + test("Test out jl.Char") { + implicit val a = arbMap { b: Char => java.lang.Character.valueOf(b) } + check[java.lang.Character] + } + test("Test out Int") { + primitiveOrderedBufferSupplier[Int] + check[Int] + checkMany[Int] + } + test("Test out jl.Integer") { + implicit val a = arbMap { b: Int => java.lang.Integer.valueOf(b) } + check[java.lang.Integer] + } + test("Test out Float") { check[Float] } + test("Test out jl.Float") { + implicit val a = arbMap { b: Float => java.lang.Float.valueOf(b) } + check[java.lang.Float] 
+ } + test("Test out Long") { check[Long] } + test("Test out jl.Long") { + implicit val a = arbMap { b: Long => java.lang.Long.valueOf(b) } + check[java.lang.Long] + } + test("Test out Double") { check[Double] } + test("Test out jl.Double") { + implicit val a = arbMap { b: Double => java.lang.Double.valueOf(b) } + check[java.lang.Double] + } + + test("Test out String") { + primitiveOrderedBufferSupplier[String] + + check[String] + checkMany[String] + } + + test("Test out ByteBuffer") { + primitiveOrderedBufferSupplier[ByteBuffer] + check[ByteBuffer] + } + + test("Test out List[Float]") { + primitiveOrderedBufferSupplier[List[Float]] + check[List[Float]] + } + test("Test out Queue[Int]") { + implicit val isa = collectionArb[Queue, Int] + primitiveOrderedBufferSupplier[Queue[Int]] + check[Queue[Int]] + } + test("Test out IndexedSeq[Int]") { + implicit val isa = collectionArb[IndexedSeq, Int] + primitiveOrderedBufferSupplier[IndexedSeq[Int]] + check[IndexedSeq[Int]] + } + test("Test out HashSet[Int]") { + import scala.collection.immutable.HashSet + implicit val isa = collectionArb[HashSet, Int] + primitiveOrderedBufferSupplier[HashSet[Int]] + check[HashSet[Int]] + } + test("Test out ListSet[Int]") { + import scala.collection.immutable.ListSet + implicit val isa = collectionArb[ListSet, Int] + primitiveOrderedBufferSupplier[ListSet[Int]] + check[ListSet[Int]] + } + + test("Test out List[String]") { + primitiveOrderedBufferSupplier[List[String]] + check[List[String]] + } + + test("Test out List[List[String]]") { + val oBuf = primitiveOrderedBufferSupplier[List[List[String]]] + assert(oBuf.dynamicSize(List(List("sdf"))) === None) + check[List[List[String]]] + } + + test("Test out List[Int]") { + primitiveOrderedBufferSupplier[List[Int]] + check[List[Int]] + } + + test("Test out SetAlias") { + primitiveOrderedBufferSupplier[SetAlias] + check[SetAlias] + } + + test("Container.InnerCaseClass") { + primitiveOrderedBufferSupplier[Container.InnerCaseClass] + check[Container.InnerCaseClass] + } + + test("Test out Seq[Int]") { + primitiveOrderedBufferSupplier[Seq[Int]] + check[Seq[Int]] + } + test("Test out scala.collection.Seq[Int]") { + primitiveOrderedBufferSupplier[scala.collection.Seq[Int]] + check[scala.collection.Seq[Int]] + } + + test("Test out Array[Byte]") { + primitiveOrderedBufferSupplier[Array[Byte]] + check[Array[Byte]] + } + + test("Test out Vector[Int]") { + primitiveOrderedBufferSupplier[Vector[Int]] + check[Vector[Int]] + } + + test("Test out Iterable[Int]") { + primitiveOrderedBufferSupplier[Iterable[Int]] + check[Iterable[Int]] + } + + test("Test out Set[Int]") { + primitiveOrderedBufferSupplier[Set[Int]] + check[Set[Int]] + } + + test("Test out Set[Double]") { + primitiveOrderedBufferSupplier[Set[Double]] + check[Set[Double]] + } + + test("Test out Map[Long, Set[Int]]") { + primitiveOrderedBufferSupplier[Map[Long, Set[Int]]] + check[Map[Long, Set[Int]]] + val c = List(Map(9223372036854775807L -> Set[Int]()), Map(-1L -> Set[Int](-2043106012))) + checkManyExplicit(c.map { i => (i, i) }) + checkMany[Map[Long, Set[Int]]] + } + + test("Test out Map[Long, Long]") { + primitiveOrderedBufferSupplier[Map[Long, Long]] + check[Map[Long, Long]] + } + test("Test out HashMap[Long, Long]") { + import scala.collection.immutable.HashMap + implicit val isa = Arbitrary(implicitly[Arbitrary[List[(Long, Long)]]].arbitrary.map(HashMap(_: _*))) + primitiveOrderedBufferSupplier[HashMap[Long, Long]] + check[HashMap[Long, Long]] + } + test("Test out ListMap[Long, Long]") { + import 
scala.collection.immutable.ListMap + implicit val isa = Arbitrary(implicitly[Arbitrary[List[(Long, Long)]]].arbitrary.map(ListMap(_: _*))) + primitiveOrderedBufferSupplier[ListMap[Long, Long]] + check[ListMap[Long, Long]] + } + + test("Test out comparing Maps(3->2, 2->3) and Maps(2->3, 3->2) ") { + val a = Map(3 -> 2, 2 -> 3) + val b = Map(2 -> 3, 3 -> 2) + checkWithInputs(a, b) + checkAreSame(a, b) + } + + test("Test out comparing Set(\"asdf\", \"jkl\") and Set(\"jkl\", \"asdf\")") { + val a = Set("asdf", "jkl") + val b = Set("jkl", "asdf") + checkWithInputs(a, b) + checkAreSame(a, b) + } + + test("Test known hard String Case") { + val a = "6" + val b = "곆" + val ord = Ordering.String + assert(rawCompare(a, b) === ord.compare(a, b).signum, "Raw and in memory compares match.") + + val c = List("榴㉕⊟풠湜ᙬ覹ꜻ裧뚐⠂覝쫨塢䇺楠谭픚ᐌ轮뺷Ⱟ洦擄黏著탅ﮓꆋ숷梸傠ァ蹵窥轲闇涡飽ꌳ䝞慙擃", + "堒凳媨쉏떽㶥⾽샣井ㆠᇗ裉깴辫࠷᤭塈䎙寫㸉ᶴ䰄똇䡷䥞㷗䷱赫懓䷏剆祲ᝯ졑쐯헢鷴ӕ秔㽰ퟡ㏉鶖奚㙰银䮌ᕗ膾买씋썴행䣈丶偝쾕鐗쇊ኋ넥︇瞤䋗噯邧⹆♣ἷ铆玼⪷沕辤ᠥ⥰箼䔄◗", + "騰쓢堷뛭ᣣﰩ嚲ﲯ㤑ᐜ檊೦⠩奯ᓩ윇롇러ᕰెꡩ璞﫼᭵礀閮䈦椄뾪ɔ믻䖔᪆嬽フ鶬曭꣍ᆏ灖㐸뗋ㆃ녵ퟸ겵晬礙㇩䫓ᘞ昑싨", + "좃ఱ䨻綛糔唄࿁劸酊᫵橻쩳괊筆ݓ淤숪輡斋靑耜঄骐冠㝑⧠떅漫곡祈䵾ᳺ줵됵↲搸虂㔢Ꝅ芆٠풐쮋炞哙⨗쾄톄멛癔짍避쇜畾㣕剼⫁়╢ꅢ澛氌ᄚ㍠ꃫᛔ匙㜗詇閦單錖⒅瘧崥", + "獌癚畇") + checkManyExplicit(c.map { i => (i, i) }) + + val c2 = List("聸", "") + checkManyExplicit(c2.map { i => (i, i) }) + } + + test("Test out Option[Int]") { + val oser = primitiveOrderedBufferSupplier[Option[Int]] + + assert(oser.staticSize === None, "can't get the size statically") + check[Option[Int]] + checkMany[Option[Int]] + } + + test("Test out Option[String]") { + primitiveOrderedBufferSupplier[Option[String]] + + check[Option[String]] + checkMany[Option[String]] + } + + test("Test Either[Int, Option[Int]]") { + val oser = primitiveOrderedBufferSupplier[Either[Int, Option[Int]]] + assert(oser.staticSize === None, "can't get the size statically") + check[Either[Int, Option[Int]]] + } + test("Test Either[Int, String]") { + val oser = primitiveOrderedBufferSupplier[Either[Int, String]] + assert(oser.staticSize === None, "can't get the size statically") + assert(Some(Serialization.toBytes[Either[Int, String]](Left(1)).length) === oser.dynamicSize(Left(1)), + "serialization size matches dynamic size") + check[Either[Int, String]] + } + test("Test Either[Int, Int]") { + val oser = primitiveOrderedBufferSupplier[Either[Int, Int]] + assert(oser.staticSize === Some(5), "can get the size statically") + check[Either[Int, Int]] + } + test("Test Either[String, Int]") { + primitiveOrderedBufferSupplier[Either[String, Int]] + check[Either[String, Int]] + } + test("Test Either[String, String]") { + primitiveOrderedBufferSupplier[Either[String, String]] + check[Either[String, String]] + } + + test("Test out Option[Option[Int]]") { + primitiveOrderedBufferSupplier[Option[Option[Int]]] + + check[Option[Option[Int]]] + } + + test("test product like TestCC") { + checkMany[(Int, Char, Long, Option[Int], Double, Option[String])] + } + + test("test specific tuple aa1") { + primitiveOrderedBufferSupplier[(String, Option[Int], String)] + + checkMany[(String, Option[Int], String)] + } + + test("test specific tuple 2") { + check[(String, Option[Int], String)] + } + + test("test specific tuple 3") { + val c = List(("", None, ""), + ("a", Some(1), "b")) + checkManyExplicit(c.map { i => (i, i) }) + } + + test("Test out TestCC") { + import TestCC._ + primitiveOrderedBufferSupplier[TestCC] + check[TestCC] + checkMany[TestCC] + } + + test("Test out (Int, Int)") { + primitiveOrderedBufferSupplier[(Int, Int)] + check[(Int, Int)] + } + + test("Test out (String, Option[Int], String)") { + primitiveOrderedBufferSupplier[(String, 
Option[Int], String)] + check[(String, Option[Int], String)] + } + + test("Test out MyData") { + import MyData._ + primitiveOrderedBufferSupplier[MyData] + check[MyData] + } + + test("Test out MacroOpaqueContainer") { + // This will test for things which our macros can't view themselves, so need to use an implicit to let the user provide instead. + // by itself should just work from its own implicits + implicitly[OrderedSerialization[MacroOpaqueContainer]] + + // Put inside a tuple2 to test that + primitiveOrderedBufferSupplier[(MacroOpaqueContainer, MacroOpaqueContainer)] + check[(MacroOpaqueContainer, MacroOpaqueContainer)] + check[Option[MacroOpaqueContainer]] + check[List[MacroOpaqueContainer]] + } +} diff --git a/scalding-serialization-macros/src/test/scala/com/twitter/scalding/serialization/macros/TraversableHelperLaws.scala b/scalding-serialization-macros/src/test/scala/com/twitter/scalding/serialization/macros/TraversableHelperLaws.scala new file mode 100644 index 0000000000..922c74f51f --- /dev/null +++ b/scalding-serialization-macros/src/test/scala/com/twitter/scalding/serialization/macros/TraversableHelperLaws.scala @@ -0,0 +1,50 @@ +/* +Copyright 2014 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.serialization.macros + +import org.scalatest.{ FunSuite, ShouldMatchers } +import org.scalatest.prop.Checkers +import org.scalatest.prop.PropertyChecks + +import impl.ordered_serialization.runtime_helpers.TraversableHelpers._ + +class TraversableHelperLaws extends FunSuite with PropertyChecks with ShouldMatchers { + test("Iterator ordering should be Iterable ordering") { + forAll { (l1: List[Int], l2: List[Int]) => + assert(iteratorCompare[Int](l1.iterator, l2.iterator) === + Ordering[Iterable[Int]].compare(l1, l2), "Matches scala's Iterable compare") + } + } + test("Iterator equiv should be Iterable ordering") { + forAll { (l1: List[Int], l2: List[Int]) => + assert(iteratorEquiv[Int](l1.iterator, l2.iterator) === + Ordering[Iterable[Int]].equiv(l1, l2), "Matches scala's Iterable compare") + } + } + test("sortedCompare matches sort followed by compare List[Int]") { + forAll(minSuccessful(1000)) { (l1: List[Int], l2: List[Int]) => + assert(sortedCompare[Int](l1, l2) === + Ordering[Iterable[Int]].compare(l1.sorted, l2.sorted), "Matches scala's Iterable compare") + } + } + test("sortedCompare matches sort followed by compare Set[Int]") { + forAll(minSuccessful(1000)) { (l1: Set[Int], l2: Set[Int]) => + assert(sortedCompare[Int](l1, l2) === + Ordering[Iterable[Int]].compare(l1.toList.sorted, l2.toList.sorted), "Matches scala's Iterable compare") + } + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Boxed.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Boxed.scala new file mode 100644 index 0000000000..5e57a3427e --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Boxed.scala @@ -0,0 +1,357 @@ +/* +Copyright 2015 Twitter, Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import java.util.concurrent.atomic.AtomicReference +import java.io.{ InputStream, OutputStream } + +/** + * This interface is a way of wrapping a value in a marker class + * whose class identity is used to control which serialization we + * use. This is an internal implementation detail about how we + * interact with cascading and hadoop. Users should never care. + */ +trait Boxed[+K] { + def get: K +} + +class Boxed0[K](override val get: K) extends Boxed[K] + +class Boxed1[K](override val get: K) extends Boxed[K] + +class Boxed2[K](override val get: K) extends Boxed[K] + +class Boxed3[K](override val get: K) extends Boxed[K] + +class Boxed4[K](override val get: K) extends Boxed[K] + +class Boxed5[K](override val get: K) extends Boxed[K] + +class Boxed6[K](override val get: K) extends Boxed[K] + +class Boxed7[K](override val get: K) extends Boxed[K] + +class Boxed8[K](override val get: K) extends Boxed[K] + +class Boxed9[K](override val get: K) extends Boxed[K] + +class Boxed10[K](override val get: K) extends Boxed[K] + +class Boxed11[K](override val get: K) extends Boxed[K] + +class Boxed12[K](override val get: K) extends Boxed[K] + +class Boxed13[K](override val get: K) extends Boxed[K] + +class Boxed14[K](override val get: K) extends Boxed[K] + +class Boxed15[K](override val get: K) extends Boxed[K] + +class Boxed16[K](override val get: K) extends Boxed[K] + +class Boxed17[K](override val get: K) extends Boxed[K] + +class Boxed18[K](override val get: K) extends Boxed[K] + +class Boxed19[K](override val get: K) extends Boxed[K] + +class Boxed20[K](override val get: K) extends Boxed[K] + +class Boxed21[K](override val get: K) extends Boxed[K] + +class Boxed22[K](override val get: K) extends Boxed[K] + +class Boxed23[K](override val get: K) extends Boxed[K] + +class Boxed24[K](override val get: K) extends Boxed[K] + +class Boxed25[K](override val get: K) extends Boxed[K] + +class Boxed26[K](override val get: K) extends Boxed[K] + +class Boxed27[K](override val get: K) extends Boxed[K] + +class Boxed28[K](override val get: K) extends Boxed[K] + +class Boxed29[K](override val get: K) extends Boxed[K] + +class Boxed30[K](override val get: K) extends Boxed[K] + +class Boxed31[K](override val get: K) extends Boxed[K] + +class Boxed32[K](override val get: K) extends Boxed[K] + +class Boxed33[K](override val get: K) extends Boxed[K] + +class Boxed34[K](override val get: K) extends Boxed[K] + +class Boxed35[K](override val get: K) extends Boxed[K] + +class Boxed36[K](override val get: K) extends Boxed[K] + +class Boxed37[K](override val get: K) extends Boxed[K] + +class Boxed38[K](override val get: K) extends Boxed[K] + +class Boxed39[K](override val get: K) extends Boxed[K] + +class Boxed40[K](override val get: K) extends Boxed[K] + +class Boxed41[K](override val get: K) extends Boxed[K] + +class Boxed42[K](override val get: K) extends Boxed[K] + +class Boxed43[K](override val get: K) extends Boxed[K] + +class 
Boxed44[K](override val get: K) extends Boxed[K] + +class Boxed45[K](override val get: K) extends Boxed[K] + +class Boxed46[K](override val get: K) extends Boxed[K] + +class Boxed47[K](override val get: K) extends Boxed[K] + +class Boxed48[K](override val get: K) extends Boxed[K] + +class Boxed49[K](override val get: K) extends Boxed[K] + +class Boxed50[K](override val get: K) extends Boxed[K] + +class Boxed51[K](override val get: K) extends Boxed[K] + +class Boxed52[K](override val get: K) extends Boxed[K] + +class Boxed53[K](override val get: K) extends Boxed[K] + +class Boxed54[K](override val get: K) extends Boxed[K] + +class Boxed55[K](override val get: K) extends Boxed[K] + +class Boxed56[K](override val get: K) extends Boxed[K] + +class Boxed57[K](override val get: K) extends Boxed[K] + +class Boxed58[K](override val get: K) extends Boxed[K] + +class Boxed59[K](override val get: K) extends Boxed[K] + +class Boxed60[K](override val get: K) extends Boxed[K] + +class Boxed61[K](override val get: K) extends Boxed[K] + +class Boxed62[K](override val get: K) extends Boxed[K] + +class Boxed63[K](override val get: K) extends Boxed[K] + +class Boxed64[K](override val get: K) extends Boxed[K] + +class Boxed65[K](override val get: K) extends Boxed[K] + +class Boxed66[K](override val get: K) extends Boxed[K] + +class Boxed67[K](override val get: K) extends Boxed[K] + +class Boxed68[K](override val get: K) extends Boxed[K] + +class Boxed69[K](override val get: K) extends Boxed[K] + +class Boxed70[K](override val get: K) extends Boxed[K] + +class Boxed71[K](override val get: K) extends Boxed[K] + +class Boxed72[K](override val get: K) extends Boxed[K] + +class Boxed73[K](override val get: K) extends Boxed[K] + +class Boxed74[K](override val get: K) extends Boxed[K] + +class Boxed75[K](override val get: K) extends Boxed[K] + +class Boxed76[K](override val get: K) extends Boxed[K] + +class Boxed77[K](override val get: K) extends Boxed[K] + +class Boxed78[K](override val get: K) extends Boxed[K] + +class Boxed79[K](override val get: K) extends Boxed[K] + +class Boxed80[K](override val get: K) extends Boxed[K] + +class Boxed81[K](override val get: K) extends Boxed[K] + +class Boxed82[K](override val get: K) extends Boxed[K] + +class Boxed83[K](override val get: K) extends Boxed[K] + +class Boxed84[K](override val get: K) extends Boxed[K] + +class Boxed85[K](override val get: K) extends Boxed[K] + +class Boxed86[K](override val get: K) extends Boxed[K] + +class Boxed87[K](override val get: K) extends Boxed[K] + +class Boxed88[K](override val get: K) extends Boxed[K] + +class Boxed89[K](override val get: K) extends Boxed[K] + +class Boxed90[K](override val get: K) extends Boxed[K] + +class Boxed91[K](override val get: K) extends Boxed[K] + +class Boxed92[K](override val get: K) extends Boxed[K] + +class Boxed93[K](override val get: K) extends Boxed[K] + +class Boxed94[K](override val get: K) extends Boxed[K] + +class Boxed95[K](override val get: K) extends Boxed[K] + +class Boxed96[K](override val get: K) extends Boxed[K] + +class Boxed97[K](override val get: K) extends Boxed[K] + +class Boxed98[K](override val get: K) extends Boxed[K] + +class Boxed99[K](override val get: K) extends Boxed[K] + +case class BoxedOrderedSerialization[K](box: K => Boxed[K], + ord: OrderedSerialization[K]) extends OrderedSerialization[Boxed[K]] { + + override def compare(a: Boxed[K], b: Boxed[K]) = ord.compare(a.get, b.get) + override def hash(k: Boxed[K]) = ord.hash(k.get) + override def compareBinary(a: InputStream, b: 
InputStream) = ord.compareBinary(a, b) + override def read(from: InputStream) = ord.read(from).map(box) + override def write(into: OutputStream, bk: Boxed[K]) = ord.write(into, bk.get) + override def staticSize = ord.staticSize + override def dynamicSize(k: Boxed[K]) = ord.dynamicSize(k.get) +} + +object Boxed { + private[this] val allBoxes = List( + ({ t: Any => new Boxed0(t) }, classOf[Boxed0[Any]]), + ({ t: Any => new Boxed1(t) }, classOf[Boxed1[Any]]), + ({ t: Any => new Boxed2(t) }, classOf[Boxed2[Any]]), + ({ t: Any => new Boxed3(t) }, classOf[Boxed3[Any]]), + ({ t: Any => new Boxed4(t) }, classOf[Boxed4[Any]]), + ({ t: Any => new Boxed5(t) }, classOf[Boxed5[Any]]), + ({ t: Any => new Boxed6(t) }, classOf[Boxed6[Any]]), + ({ t: Any => new Boxed7(t) }, classOf[Boxed7[Any]]), + ({ t: Any => new Boxed8(t) }, classOf[Boxed8[Any]]), + ({ t: Any => new Boxed9(t) }, classOf[Boxed9[Any]]), + ({ t: Any => new Boxed10(t) }, classOf[Boxed10[Any]]), + ({ t: Any => new Boxed11(t) }, classOf[Boxed11[Any]]), + ({ t: Any => new Boxed12(t) }, classOf[Boxed12[Any]]), + ({ t: Any => new Boxed13(t) }, classOf[Boxed13[Any]]), + ({ t: Any => new Boxed14(t) }, classOf[Boxed14[Any]]), + ({ t: Any => new Boxed15(t) }, classOf[Boxed15[Any]]), + ({ t: Any => new Boxed16(t) }, classOf[Boxed16[Any]]), + ({ t: Any => new Boxed17(t) }, classOf[Boxed17[Any]]), + ({ t: Any => new Boxed18(t) }, classOf[Boxed18[Any]]), + ({ t: Any => new Boxed19(t) }, classOf[Boxed19[Any]]), + ({ t: Any => new Boxed20(t) }, classOf[Boxed20[Any]]), + ({ t: Any => new Boxed21(t) }, classOf[Boxed21[Any]]), + ({ t: Any => new Boxed22(t) }, classOf[Boxed22[Any]]), + ({ t: Any => new Boxed23(t) }, classOf[Boxed23[Any]]), + ({ t: Any => new Boxed24(t) }, classOf[Boxed24[Any]]), + ({ t: Any => new Boxed25(t) }, classOf[Boxed25[Any]]), + ({ t: Any => new Boxed26(t) }, classOf[Boxed26[Any]]), + ({ t: Any => new Boxed27(t) }, classOf[Boxed27[Any]]), + ({ t: Any => new Boxed28(t) }, classOf[Boxed28[Any]]), + ({ t: Any => new Boxed29(t) }, classOf[Boxed29[Any]]), + ({ t: Any => new Boxed30(t) }, classOf[Boxed30[Any]]), + ({ t: Any => new Boxed31(t) }, classOf[Boxed31[Any]]), + ({ t: Any => new Boxed32(t) }, classOf[Boxed32[Any]]), + ({ t: Any => new Boxed33(t) }, classOf[Boxed33[Any]]), + ({ t: Any => new Boxed34(t) }, classOf[Boxed34[Any]]), + ({ t: Any => new Boxed35(t) }, classOf[Boxed35[Any]]), + ({ t: Any => new Boxed36(t) }, classOf[Boxed36[Any]]), + ({ t: Any => new Boxed37(t) }, classOf[Boxed37[Any]]), + ({ t: Any => new Boxed38(t) }, classOf[Boxed38[Any]]), + ({ t: Any => new Boxed39(t) }, classOf[Boxed39[Any]]), + ({ t: Any => new Boxed40(t) }, classOf[Boxed40[Any]]), + ({ t: Any => new Boxed41(t) }, classOf[Boxed41[Any]]), + ({ t: Any => new Boxed42(t) }, classOf[Boxed42[Any]]), + ({ t: Any => new Boxed43(t) }, classOf[Boxed43[Any]]), + ({ t: Any => new Boxed44(t) }, classOf[Boxed44[Any]]), + ({ t: Any => new Boxed45(t) }, classOf[Boxed45[Any]]), + ({ t: Any => new Boxed46(t) }, classOf[Boxed46[Any]]), + ({ t: Any => new Boxed47(t) }, classOf[Boxed47[Any]]), + ({ t: Any => new Boxed48(t) }, classOf[Boxed48[Any]]), + ({ t: Any => new Boxed49(t) }, classOf[Boxed49[Any]]), + ({ t: Any => new Boxed50(t) }, classOf[Boxed50[Any]]), + ({ t: Any => new Boxed51(t) }, classOf[Boxed51[Any]]), + ({ t: Any => new Boxed52(t) }, classOf[Boxed52[Any]]), + ({ t: Any => new Boxed53(t) }, classOf[Boxed53[Any]]), + ({ t: Any => new Boxed54(t) }, classOf[Boxed54[Any]]), + ({ t: Any => new Boxed55(t) }, classOf[Boxed55[Any]]), + ({ t: Any => new Boxed56(t) 
}, classOf[Boxed56[Any]]), + ({ t: Any => new Boxed57(t) }, classOf[Boxed57[Any]]), + ({ t: Any => new Boxed58(t) }, classOf[Boxed58[Any]]), + ({ t: Any => new Boxed59(t) }, classOf[Boxed59[Any]]), + ({ t: Any => new Boxed60(t) }, classOf[Boxed60[Any]]), + ({ t: Any => new Boxed61(t) }, classOf[Boxed61[Any]]), + ({ t: Any => new Boxed62(t) }, classOf[Boxed62[Any]]), + ({ t: Any => new Boxed63(t) }, classOf[Boxed63[Any]]), + ({ t: Any => new Boxed64(t) }, classOf[Boxed64[Any]]), + ({ t: Any => new Boxed65(t) }, classOf[Boxed65[Any]]), + ({ t: Any => new Boxed66(t) }, classOf[Boxed66[Any]]), + ({ t: Any => new Boxed67(t) }, classOf[Boxed67[Any]]), + ({ t: Any => new Boxed68(t) }, classOf[Boxed68[Any]]), + ({ t: Any => new Boxed69(t) }, classOf[Boxed69[Any]]), + ({ t: Any => new Boxed70(t) }, classOf[Boxed70[Any]]), + ({ t: Any => new Boxed71(t) }, classOf[Boxed71[Any]]), + ({ t: Any => new Boxed72(t) }, classOf[Boxed72[Any]]), + ({ t: Any => new Boxed73(t) }, classOf[Boxed73[Any]]), + ({ t: Any => new Boxed74(t) }, classOf[Boxed74[Any]]), + ({ t: Any => new Boxed75(t) }, classOf[Boxed75[Any]]), + ({ t: Any => new Boxed76(t) }, classOf[Boxed76[Any]]), + ({ t: Any => new Boxed77(t) }, classOf[Boxed77[Any]]), + ({ t: Any => new Boxed78(t) }, classOf[Boxed78[Any]]), + ({ t: Any => new Boxed79(t) }, classOf[Boxed79[Any]]), + ({ t: Any => new Boxed80(t) }, classOf[Boxed80[Any]]), + ({ t: Any => new Boxed81(t) }, classOf[Boxed81[Any]]), + ({ t: Any => new Boxed82(t) }, classOf[Boxed82[Any]]), + ({ t: Any => new Boxed83(t) }, classOf[Boxed83[Any]]), + ({ t: Any => new Boxed84(t) }, classOf[Boxed84[Any]]), + ({ t: Any => new Boxed85(t) }, classOf[Boxed85[Any]]), + ({ t: Any => new Boxed86(t) }, classOf[Boxed86[Any]]), + ({ t: Any => new Boxed87(t) }, classOf[Boxed87[Any]]), + ({ t: Any => new Boxed88(t) }, classOf[Boxed88[Any]]), + ({ t: Any => new Boxed89(t) }, classOf[Boxed89[Any]]), + ({ t: Any => new Boxed90(t) }, classOf[Boxed90[Any]]), + ({ t: Any => new Boxed91(t) }, classOf[Boxed91[Any]]), + ({ t: Any => new Boxed92(t) }, classOf[Boxed92[Any]]), + ({ t: Any => new Boxed93(t) }, classOf[Boxed93[Any]]), + ({ t: Any => new Boxed94(t) }, classOf[Boxed94[Any]]), + ({ t: Any => new Boxed95(t) }, classOf[Boxed95[Any]]), + ({ t: Any => new Boxed96(t) }, classOf[Boxed96[Any]]), + ({ t: Any => new Boxed97(t) }, classOf[Boxed97[Any]]), + ({ t: Any => new Boxed98(t) }, classOf[Boxed98[Any]]), + ({ t: Any => new Boxed99(t) }, classOf[Boxed99[Any]])) + + private[this] val boxes: AtomicReference[List[(Any => Boxed[Any], Class[_ <: Boxed[Any]])]] = + new AtomicReference(allBoxes) + + def allClasses: Seq[Class[_ <: Boxed[_]]] = allBoxes.map(_._2) + + def next[K]: (K => Boxed[K], Class[Boxed[K]]) = boxes.get match { + case list @ (h :: tail) if boxes.compareAndSet(list, tail) => + h.asInstanceOf[(K => Boxed[K], Class[Boxed[K]])] + case (h :: tail) => next[K] // Try again + case Nil => sys.error("Exhausted the boxed classes") + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Hasher.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Hasher.scala new file mode 100644 index 0000000000..52600598d5 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Hasher.scala @@ -0,0 +1,97 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
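For orientation, a minimal sketch of how the Boxed machinery above might be exercised internally; Boxed.next hands out one of the finitely many wrapper classes, and that class identity is what later selects the serialization. The helper name is hypothetical and user code is not expected to call this:

    // Hypothetical helper: pair a fresh Boxed class with an existing
    // OrderedSerialization[String] assumed to be available in scope.
    def boxedStringOrdSer(implicit ord: OrderedSerialization[String]): (String => Boxed[String], OrderedSerialization[Boxed[String]]) = {
      val (box, boxedClass) = Boxed.next[String] // boxedClass is what identifies the serialization later
      (box, BoxedOrderedSerialization(box, ord))
    }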
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +// Be careful using this, the product/array or similar will attempt to call system hash codes. +import scala.util.hashing.MurmurHash3 +/** + * This is a specialized typeclass to make it easier to implement Serializations. + * The specialization *should* mean that there is no boxing and if the JIT + * does its work, Hasher should compose well (via collections, Tuple2, Option, Either) + */ +trait Hasher[@specialized(Boolean, Byte, Char, Short, Int, Long, Float, Double) -T] { + @inline + def hash(i: T): Int +} + +object Hasher { + import MurmurHashUtils._ + final val seed = 0xf7ca7fd2 + + @inline + def hash[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) T]( + i: T)(implicit h: Hasher[T]): Int = h.hash(i) + + /* + * Instances below + */ + implicit val unit: Hasher[Unit] = new Hasher[Unit] { + @inline + def hash(i: Unit) = 0 + } + implicit val boolean: Hasher[Boolean] = new Hasher[Boolean] { + /** + * Here we use the two large primes as the hash codes. + * We use primes because we want the probability of collision when + * we mod with some size (to fit into hash-buckets stored in an array) + * to be low. The choice of prime numbers means that they have no factors + * in common with any size, but they could have the same remainder. + * We actually just use the exact same values as Java here. + */ + @inline + def hash(i: Boolean) = if (i) 1231 else 1237 + } + implicit val byte: Hasher[Byte] = new Hasher[Byte] { + @inline + def hash(i: Byte) = hashInt(i.toInt) + } + implicit val char: Hasher[Char] = new Hasher[Char] { + @inline + def hash(i: Char) = hashInt(i.toInt) + } + val character = char + + implicit val short: Hasher[Short] = new Hasher[Short] { + @inline + def hash(i: Short) = hashInt(i.toInt) + } + + implicit val int: Hasher[Int] = new Hasher[Int] { + @inline + def hash(i: Int) = hashInt(i) + } + + // java way to refer to int, alias in naming + val integer = int + + implicit val long: Hasher[Long] = new Hasher[Long] { + @inline + def hash(i: Long) = hashLong(i) + } + + implicit val float: Hasher[Float] = new Hasher[Float] { + @inline + def hash(i: Float) = hashInt(java.lang.Float.floatToIntBits(i)) + } + implicit val double: Hasher[Double] = new Hasher[Double] { + @inline + def hash(i: Double) = hashLong(java.lang.Double.doubleToLongBits(i)) + } + implicit val string: Hasher[String] = new Hasher[String] { + @inline + def hash(i: String) = MurmurHash3.stringHash(i) + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/JavaStreamEnrichments.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/JavaStreamEnrichments.scala new file mode 100644 index 0000000000..528807f824 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/JavaStreamEnrichments.scala @@ -0,0 +1,283 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
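A small sketch of how the Hasher instances above are meant to be summoned; with specialization these calls should not box their primitive arguments (the values are arbitrary examples):

    val hInt = Hasher.hash(42)          // resolves to Hasher.int (murmur-mixed)
    val hBool = Hasher.hash(true)       // resolves to Hasher.boolean (1231/1237, as in Java)
    val hStr = Hasher.hash("scalding")  // resolves to Hasher.string (MurmurHash3.stringHash)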
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import java.io._ + +object JavaStreamEnrichments { + def eof: Nothing = throw new EOFException() + + // We use this to avoid allocating a closure to make + // a lazy parameter to require + private def illegal(s: String): Nothing = + throw new IllegalArgumentException(s) + + /** + * Note this is only recommended for testing. + * You may want to use ByteArrayInputOutputStream for performance critical concerns + */ + implicit class RichByteArrayOutputStream(val baos: ByteArrayOutputStream) extends AnyVal { + def toInputStream: ByteArrayInputStream = new ByteArrayInputStream(baos.toByteArray) + } + + /** + * enrichment to treat an Array like an OutputStream + */ + implicit class RichByteArray(val bytes: Array[Byte]) extends AnyVal { + def wrapAsOutputStream: ArrayWrappingOutputStream = wrapAsOutputStreamAt(0) + def wrapAsOutputStreamAt(pos: Int): ArrayWrappingOutputStream = + new ArrayWrappingOutputStream(bytes, pos) + } + /** + * Wraps an Array so that you can write into it as a stream without reallocations + * or copying at the end. Useful if you know an upper bound on the number of bytes + * you will write + */ + class ArrayWrappingOutputStream(val buffer: Array[Byte], initPos: Int) extends OutputStream { + if (buffer.length < initPos) { + illegal(s"Initial position cannot be more than length: $initPos > ${buffer.length}") + } + private[this] var pos = initPos + def position: Int = pos + override def write(b: Int) { buffer(pos) = b.toByte; pos += 1 } + override def write(b: Array[Byte], off: Int, len: Int) { + Array.copy(b, off, buffer, pos, len) + pos += len + } + } + + def posVarIntSize(i: Int): Int = { + if (i < 0) illegal(s"negative numbers not allowed: $i") + if (i < ((1 << 8) - 1)) 1 + else { + if (i < ((1 << 16) - 1)) { + 3 + } else { + 7 + } + } + } + + /** + * This has a lot of methods from DataInputStream without + * having to allocate to get them + * This code is similar to those algorithms + */ + implicit class RichInputStream(val s: InputStream) extends AnyVal { + /** + * If s supports marking, we mark it. Otherwise we read the needed + * bytes out into a ByteArrayStream and return that. + * This is intended for the case where you need possibly + * read size bytes but may stop early, then skip this exact + * number of bytes. + * Intended use is: + * {code} + * val size = 100 + * val marked = s.markOrBuffer(size) + * val y = fn(marked) + * marked.reset + * marked.skipFully(size) + * {/code} + */ + def markOrBuffer(size: Int): InputStream = { + val ms = if (s.markSupported) s else { + val buf = new Array[Byte](size) + s.readFully(buf) + new ByteArrayInputStream(buf) + } + // Make sure we can reset after we read this many bytes + ms.mark(size) + ms + } + + def readBoolean: Boolean = (readUnsignedByte != 0) + + /** + * Like read, but throws eof on error + */ + def readByte: Byte = readUnsignedByte.toByte + + def readUnsignedByte: Int = { + // Note that Java, when you read a byte, returns a Int holding an unsigned byte. + // if the value is < 0, you hit EOF. 
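To make the array-wrapping enrichments above concrete, a short sketch of writing into a preallocated array when an upper bound on the size is known; it assumes the implicits in JavaStreamEnrichments are imported, and writeInt comes from the RichOutputStream enrichment defined further below:

    import com.twitter.scalding.serialization.JavaStreamEnrichments._

    val buf = new Array[Byte](4)
    val out = buf.wrapAsOutputStream // ArrayWrappingOutputStream starting at position 0
    out.writeInt(42)                 // writes directly into buf, no copy at the end
    assert(out.position == 4)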
+ val c1 = s.read + if (c1 < 0) eof else c1 + } + def readUnsignedShort: Int = { + val c1 = s.read + val c2 = s.read + if ((c1 | c2) < 0) eof else ((c1 << 8) | c2) + } + + final def readFully(bytes: Array[Byte]): Unit = readFully(bytes, 0, bytes.length) + + final def readFully(bytes: Array[Byte], offset: Int, len: Int): Unit = { + if (len < 0) throw new IndexOutOfBoundsException() + + @annotation.tailrec + def go(o: Int, l: Int): Unit = + if (l == 0) () + else { + val count = s.read(bytes, o, l) + if (count < 0) eof + else go(o + count, l - count) + } + go(offset, len) + } + + def readDouble: Double = java.lang.Double.longBitsToDouble(readLong) + def readFloat: Float = java.lang.Float.intBitsToFloat(readInt) + + /** + * This is the algorithm from DataInputStream + * it was also benchmarked against the approach + * used in readLong and found to be faster + */ + def readInt: Int = { + val c1 = s.read + val c2 = s.read + val c3 = s.read + val c4 = s.read + if ((c1 | c2 | c3 | c4) < 0) eof else ((c1 << 24) | (c2 << 16) | (c3 << 8) | c4) + } + /* + * This is the algorithm from DataInputStream + * it was also benchmarked against the same approach used + * in readInt (buffer-less) and found to be faster. + */ + def readLong: Long = { + val buf = new Array[Byte](8) + readFully(buf) + (buf(0).toLong << 56) + + ((buf(1) & 255).toLong << 48) + + ((buf(2) & 255).toLong << 40) + + ((buf(3) & 255).toLong << 32) + + ((buf(4) & 255).toLong << 24) + + ((buf(5) & 255) << 16) + + ((buf(6) & 255) << 8) + + (buf(7) & 255) + } + + def readChar: Char = { + val c1 = s.read + val c2 = s.read + // This is the algorithm from DataInputStream + if ((c1 | c2) < 0) eof else ((c1 << 8) | c2).toChar + } + + def readShort: Short = { + val c1 = s.read + val c2 = s.read + // This is the algorithm from DataInputStream + if ((c1 | c2) < 0) eof else ((c1 << 8) | c2).toShort + } + + /** + * This reads a varInt encoding that only encodes non-negative + * numbers. It uses: + * 1 byte for values 0 - 255, + * 3 bytes for 256 - 65535, + * 7 bytes for 65536 - Int.MaxValue + */ + final def readPosVarInt: Int = { + val c1 = readUnsignedByte + if (c1 < ((1 << 8) - 1)) c1 + else { + val c2 = readUnsignedShort + if (c2 < ((1 << 16) - 1)) c2 + else readInt + } + } + + final def skipFully(count: Long): Unit = { + @annotation.tailrec + def go(c: Long): Unit = { + val skipped = s.skip(c) + if (skipped == c) () + else if (skipped == 0L) throw new IOException(s"could not skipFully: count, c, skipped = ${(count, c, skipped)}") + else go(c - skipped) + } + if (count != 0L) go(count) else () + } + } + + implicit class RichOutputStream(val s: OutputStream) extends AnyVal { + def writeBoolean(b: Boolean): Unit = if (b) s.write(1: Byte) else s.write(0: Byte) + + def writeBytes(b: Array[Byte], off: Int, len: Int): Unit = { + s.write(b, off, len) + } + + def writeByte(b: Byte): Unit = s.write(b) + + def writeBytes(b: Array[Byte]): Unit = writeBytes(b, 0, b.length) + + /** + * This reads a varInt encoding that only encodes non-negative + * numbers. 
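A round-trip sketch of the posVarInt encoding described here: values below 255 fit in one byte, and a 0xff marker escapes to the wider 3- and 7-byte forms (assumes the stream enrichments in this file are imported):

    import java.io.ByteArrayOutputStream
    import com.twitter.scalding.serialization.JavaStreamEnrichments._

    val baos = new ByteArrayOutputStream
    baos.writePosVarInt(300)             // 300 needs the 3-byte form
    assert(posVarIntSize(300) == 3)
    assert(baos.toInputStream.readPosVarInt == 300)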
It uses: + * 1 byte for values 0 - 255, + * 3 bytes for 256 - 65535, + * 7 bytes for 65536 - Int.MaxValue + */ + def writePosVarInt(i: Int): Unit = { + if (i < 0) illegal(s"must be non-negative: ${i}") + if (i < ((1 << 8) - 1)) s.write(i) + else { + s.write(-1: Byte) + if (i < ((1 << 16) - 1)) { + s.write(i >> 8) + s.write(i) + } else { + s.write(-1) + s.write(-1) + writeInt(i) + } + } + } + + def writeDouble(d: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(d)) + + def writeFloat(f: Float): Unit = writeInt(java.lang.Float.floatToIntBits(f)) + + def writeLong(l: Long): Unit = { + s.write((l >>> 56).toInt) + s.write((l >>> 48).toInt) + s.write((l >>> 40).toInt) + s.write((l >>> 32).toInt) + s.write((l >>> 24).toInt) + s.write((l >>> 16).toInt) + s.write((l >>> 8).toInt) + s.write(l.toInt) + } + + def writeInt(i: Int): Unit = { + s.write(i >>> 24) + s.write(i >>> 16) + s.write(i >>> 8) + s.write(i) + } + + def writeChar(sh: Char): Unit = { + s.write(sh >>> 8) + s.write(sh.toInt) + } + + def writeShort(sh: Short): Unit = { + s.write(sh >>> 8) + s.write(sh.toInt) + } + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Laws.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Laws.scala new file mode 100644 index 0000000000..a173171480 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Laws.scala @@ -0,0 +1,27 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +/** + * This is a simple trait for describing laws on single parameter + * type classes (Serialization, Monoid, Ordering, etc...) + */ +sealed trait Law[T] { + def name: String +} +case class Law1[T](override val name: String, check: T => Boolean) extends Law[T] +case class Law2[T](override val name: String, check: (T, T) => Boolean) extends Law[T] +case class Law3[T](override val name: String, check: (T, T, T) => Boolean) extends Law[T] diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/MurmurHashUtils.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/MurmurHashUtils.scala new file mode 100644 index 0000000000..27d4c54ed3 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/MurmurHashUtils.scala @@ -0,0 +1,106 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +package com.twitter.scalding.serialization + +// Taking a few functions from: +// https://guava-libraries.googlecode.com/git/guava/src/com/google/common/hash/Murmur3_32HashFunction.java +object MurmurHashUtils { + final val seed = 0xf7ca7fd2 + private final val C1: Int = 0xcc9e2d51 + private final val C2: Int = 0x1b873593 + + final def hashInt(input: Int): Int = { + val k1 = mixK1(input) + val h1 = mixH1(seed, k1) + + fmix(h1, 4) // length of int is 4 bytes + } + + final def hashLong(input: Long): Int = { + val low = input.toInt + val high = (input >>> 32).toInt + + var k1 = mixK1(low) + var h1 = mixH1(seed, k1) + + k1 = mixK1(high) + h1 = mixH1(h1, k1) + + fmix(h1, 8) // 8 bytes + } + + final def hashUnencodedChars(input: CharSequence): Int = { + var h1 = seed; + + // step through the CharSequence 2 chars at a time + var i = 0 + while (i < input.length) { + var k1 = input.charAt(i - 1) | (input.charAt(i) << 16) + k1 = mixK1(k1) + h1 = mixH1(h1, k1) + i += 2 + } + + // deal with any remaining characters + if ((input.length() & 1) == 1) { + var k1: Int = input.charAt(input.length() - 1) + k1 = mixK1(k1) + h1 ^= k1 + } + + fmix(h1, (Character.SIZE / java.lang.Byte.SIZE) * input.length()) + } + + final def mixK1(k1Input: Int): Int = { + var k1 = k1Input + k1 *= C1 + k1 = Integer.rotateLeft(k1, 15) + k1 *= C2 + k1 + } + + final def mixH1(h1Input: Int, k1Input: Int): Int = { + var h1 = h1Input + var k1 = k1Input + h1 ^= k1 + h1 = Integer.rotateLeft(h1, 13) + h1 = h1 * 5 + 0xe6546b64 + h1 + } + + // Finalization mix - force all bits of a hash block to avalanche + final def fmix(h1Input: Int, length: Int): Int = { + var h1 = h1Input + h1 ^= length + h1 ^= h1 >>> 16 + h1 *= 0x85ebca6b + h1 ^= h1 >>> 13 + h1 *= 0xc2b2ae35 + h1 ^= h1 >>> 16 + h1 + } + + def iteratorHash[T](a: Iterator[T])(hashFn: T => Int): Int = { + var h1 = seed + var i = 0 + while (a.hasNext) { + var k1 = hashFn(a.next) + h1 = mixH1(h1, k1) + i += 1 + } + fmix(h1, i) + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/OrderedSerialization.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/OrderedSerialization.scala new file mode 100644 index 0000000000..708986a870 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/OrderedSerialization.scala @@ -0,0 +1,169 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.serialization + +import java.io.{ ByteArrayInputStream, InputStream, OutputStream } +import scala.util.{ Failure, Success, Try } +import scala.util.control.NonFatal + +/** + * In large-scale partitioning algorithms, we often use sorting. + * This typeclass represents something we can efficiently serialize + * with an added law: that we can (hopefully fast) compare the raw + * data. + */ +trait OrderedSerialization[T] extends Ordering[T] with Serialization[T] { + /** + * This compares two InputStreams. 
After this call, the position in + * the InputStreams is mutated to be the end of the record. + */ + def compareBinary(a: InputStream, b: InputStream): OrderedSerialization.Result +} + +object OrderedSerialization { + /** + * Represents the result of a comparison that might fail due + * to an error deserializing + */ + sealed trait Result { + /** + * Throws if the items cannot be compared + */ + def unsafeToInt: Int + def toTry: Try[Int] + } + /** + * Create a Result from an Int. + */ + def resultFrom(i: Int): Result = + if (i > 0) Greater + else if (i < 0) Less + else Equal + + def resultFrom(t: Try[Int]): Result = t match { + case Success(i) => resultFrom(i) + case Failure(e) => CompareFailure(e) + } + + final case class CompareFailure(ex: Throwable) extends Result { + def unsafeToInt = throw ex + def toTry = Failure(ex) + } + case object Greater extends Result { + val unsafeToInt = 1 + val toTry = Success(unsafeToInt) + } + case object Equal extends Result { + val unsafeToInt = 0 + val toTry = Success(unsafeToInt) + } + case object Less extends Result { + val unsafeToInt = -1 + val toTry = Success(unsafeToInt) + } + + def compare[T](a: T, b: T)(implicit ord: OrderedSerialization[T]): Int = + ord.compare(a, b) + + def compareBinary[T](a: InputStream, b: InputStream)(implicit ord: OrderedSerialization[T]): Result = + ord.compareBinary(a, b) + + def writeThenCompare[T](a: T, b: T)(implicit ordb: OrderedSerialization[T]): Result = { + val abytes = Serialization.toBytes(a) + val bbytes = Serialization.toBytes(b) + val ain = new ByteArrayInputStream(abytes) + val bin = new ByteArrayInputStream(bbytes) + ordb.compareBinary(ain, bin) + } + + /** + * This is slow, but always an option. Avoid this if you can, especially for large items + */ + def readThenCompare[T: OrderedSerialization](as: InputStream, bs: InputStream): Result = try resultFrom { + val a = Serialization.read[T](as) + val b = Serialization.read[T](bs) + compare(a.get, b.get) + } catch { + case NonFatal(e) => CompareFailure(e) + } + + /** + * The the serialized comparison matches the unserialized comparison + */ + def compareBinaryMatchesCompare[T](implicit ordb: OrderedSerialization[T]): Law2[T] = + Law2("compare(a, b) == compareBinary(aBin, bBin)", + { (a: T, b: T) => resultFrom(ordb.compare(a, b)) == writeThenCompare(a, b) }) + + /** + * ordering must be transitive. If this is not so, sort-based partitioning + * will be broken + */ + def orderingTransitive[T](implicit ordb: OrderedSerialization[T]): Law3[T] = + Law3("transitivity", + { (a: T, b: T, c: T) => + if (ordb.lteq(a, b) && ordb.lteq(b, c)) { ordb.lteq(a, c) } + else true + }) + /** + * ordering must be antisymmetric. If this is not so, sort-based partitioning + * will be broken + */ + def orderingAntisymmetry[T](implicit ordb: OrderedSerialization[T]): Law2[T] = + Law2("antisymmetry", + { (a: T, b: T) => + if (ordb.lteq(a, b) && ordb.lteq(b, a)) { ordb.equiv(a, b) } + else true + }) + /** + * ordering must be total. 
If this is not so, sort-based partitioning + * will be broken + */ + def orderingTotality[T](implicit ordb: OrderedSerialization[T]): Law2[T] = + Law2("totality", { (a: T, b: T) => (ordb.lteq(a, b) || ordb.lteq(b, a)) }) + + def allLaws[T: OrderedSerialization]: Iterable[Law[T]] = + Serialization.allLaws ++ List(compareBinaryMatchesCompare[T], + orderingTransitive[T], + orderingAntisymmetry[T], + orderingTotality[T]) +} + +/** + * This may be useful when a type is used deep in a tuple or case class, and in that case + * the earlier comparators will have likely already done the work. Be aware that avoiding + * deserialization on compare usually very helpful. + * + * Note: it is your responsibility that the hash in serialization is consistent + * with the ordering (if equivalent in the ordering, the hash must match). + */ +final case class DeserializingOrderedSerialization[T](serialization: Serialization[T], + ordering: Ordering[T]) extends OrderedSerialization[T] { + + final override def read(i: InputStream) = serialization.read(i) + final override def write(o: OutputStream, t: T) = serialization.write(o, t) + final override def hash(t: T) = serialization.hash(t) + final override def compare(a: T, b: T) = ordering.compare(a, b) + final override def compareBinary(a: InputStream, b: InputStream) = + try OrderedSerialization.resultFrom { + compare(read(a).get, read(b).get) + } + catch { + case NonFatal(e) => OrderedSerialization.CompareFailure(e) + } + final override def staticSize = serialization.staticSize + final override def dynamicSize(t: T) = serialization.dynamicSize(t) +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/PositionInputStream.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/PositionInputStream.scala new file mode 100644 index 0000000000..3407f61307 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/PositionInputStream.scala @@ -0,0 +1,88 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import java.io.InputStream +import JavaStreamEnrichments._ + +object PositionInputStream { + def apply(in: InputStream): PositionInputStream = in match { + case p: PositionInputStream => p + case nonPos => new PositionInputStream(nonPos) + } +} + +class PositionInputStream(val wraps: InputStream) extends InputStream { + private[this] var pos: Long = 0L + private[this] var markPos: Long = -1L + def position: Long = pos + + override def available = wraps.available + + override def close() { wraps.close() } + + override def mark(limit: Int) { + wraps.mark(limit) + markPos = pos + } + + override val markSupported: Boolean = wraps.markSupported + + override def read: Int = { + // returns -1 on eof or 0 to 255 store 1 byte. 
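For reference, a sketch of the fallback wrapper defined just above: when a type only has a Serialization and an Ordering, DeserializingOrderedSerialization yields an OrderedSerialization that deserializes before comparing (stringSer here is a hypothetical existing instance):

    def fallbackOrdSer(stringSer: Serialization[String]): OrderedSerialization[String] =
      DeserializingOrderedSerialization(stringSer, Ordering.String)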
+ val result = wraps.read + if (result >= 0) pos += 1 + result + } + override def read(bytes: Array[Byte]): Int = { + val count = wraps.read(bytes) + // Make this branch true as much as possible to improve branch prediction + if (count >= 0) pos += count + count + } + + override def read(bytes: Array[Byte], off: Int, len: Int): Int = { + val count = wraps.read(bytes, off, len) + // Make this branch true as much as possible to improve branch prediction + if (count >= 0) pos += count + count + } + + override def reset() { + wraps.reset() + pos = markPos + } + + private def illegal(s: String): Nothing = + throw new IllegalArgumentException(s) + + override def skip(n: Long): Long = { + if (n < 0) illegal("Must seek fowards") + val count = wraps.skip(n) + // Make this branch true as much as possible to improve branch prediction + if (count >= 0) pos += count + count + } + + /** + * This throws an exception if it can't set the position to what you give it. + */ + def seekToPosition(p: Long) { + if (p < pos) illegal(s"Can't seek backwards, at position $pos, trying to goto $p") + wraps.skipFully(p - pos) + pos = p + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Reader.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Reader.scala new file mode 100644 index 0000000000..0e1a4e6199 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Reader.scala @@ -0,0 +1,130 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import java.io.InputStream +import scala.reflect.ClassTag +import scala.collection.generic.CanBuildFrom + +/** + * This is a specialized typeclass to make it easier to implement Serializations. 
+ * The specialization *should* mean that there is no boxing and if the JIT + * does its work, Reader should compose well (via collections, Tuple2, Option, Either) + */ +trait Reader[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) +T] { + def read(is: InputStream): T +} + +object Reader { + import JavaStreamEnrichments._ + + def read[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) T]( + is: InputStream)(implicit r: Reader[T]): T = r.read(is) + /* + * Instances below + */ + implicit val unit: Reader[Unit] = new Reader[Unit] { + def read(is: InputStream) = () + } + implicit val boolean: Reader[Boolean] = new Reader[Boolean] { + def read(is: InputStream) = is.readBoolean + } + implicit val byte: Reader[Byte] = new Reader[Byte] { + def read(is: InputStream) = is.readByte + } + implicit val short: Reader[Short] = new Reader[Short] { + def read(is: InputStream) = is.readShort + } + implicit val int: Reader[Int] = new Reader[Int] { + def read(is: InputStream) = is.readInt + } + implicit val long: Reader[Long] = new Reader[Long] { + def read(is: InputStream) = is.readLong + } + implicit val float: Reader[Float] = new Reader[Float] { + def read(is: InputStream) = is.readFloat + } + implicit val double: Reader[Double] = new Reader[Double] { + def read(is: InputStream) = is.readDouble + } + implicit val string: Reader[String] = new Reader[String] { + def read(is: InputStream) = { + val size = is.readPosVarInt + val bytes = new Array[Byte](size) + is.readFully(bytes) + new String(bytes, "UTF-8") + } + } + + implicit def option[T: Reader]: Reader[Option[T]] = new Reader[Option[T]] { + val r = implicitly[Reader[T]] + def read(is: InputStream) = + if (is.readByte == (0: Byte)) None + else Some(r.read(is)) + } + + implicit def either[L: Reader, R: Reader]: Reader[Either[L, R]] = new Reader[Either[L, R]] { + val lRead = implicitly[Reader[L]] + val rRead = implicitly[Reader[R]] + def read(is: InputStream) = + if (is.readByte == (0: Byte)) Left(lRead.read(is)) + else Right(rRead.read(is)) + } + + implicit def tuple2[T1: Reader, T2: Reader]: Reader[(T1, T2)] = new Reader[(T1, T2)] { + val r1 = implicitly[Reader[T1]] + val r2 = implicitly[Reader[T2]] + def read(is: InputStream) = (r1.read(is), r2.read(is)) + } + + implicit def array[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) T: Reader: ClassTag]: Reader[Array[T]] = + new Reader[Array[T]] { + val readerT = implicitly[Reader[T]] + def read(is: InputStream) = { + val size = is.readPosVarInt + val res = new Array[T](size) + @annotation.tailrec + def go(p: Int): Unit = + if (p == size) () + else { + res(p) = readerT.read(is) + go(p + 1) + } + go(0) + res + } + } + + // Scala seems to have issues with this being implicit + def collection[T: Reader, C](implicit cbf: CanBuildFrom[Nothing, T, C]): Reader[C] = new Reader[C] { + val readerT = implicitly[Reader[T]] + def read(is: InputStream): C = { + val builder = cbf() + val size = is.readPosVarInt + builder.sizeHint(size) + @annotation.tailrec + def go(idx: Int): Unit = + if (idx == size) () + else { + builder += readerT.read(is) + go(idx + 1) + } + + go(0) + builder.result + } + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization.scala new file mode 100644 index 0000000000..b038253432 --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization.scala @@ -0,0 +1,167 @@ +/* +Copyright 
2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.serialization + +import java.io.{ ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream, Serializable } + +import scala.util.{ Success, Try } +import scala.util.hashing.Hashing + +/** + * This is a base Input/OutputStream-based serialization typeclass + * This is useful for value serialization in hadoop when we don't + * need to do key sorting for partitioning. + * + * This serialization typeclass must serialize equivalent objects + * identically to be lawful. Serialization should be the same + * on all JVMs run at any time, in other words, Serialization is a + * pure function. Given that constraint, we can always + * get an Equiv and Hashing from a Serialization (by doing byte-wise + * equivalence or byte-wise hashing). + * + * A serialization always gives a hash because one can just + * serialize and then hash the bytes. You might prefer another + * implementation. This must satisfy: + * (!equiv(a, b)) || (hash(a) == hash(b)) + */ +trait Serialization[T] extends Equiv[T] with Hashing[T] with Serializable { + def read(in: InputStream): Try[T] + def write(out: OutputStream, t: T): Try[Unit] + /** + * If all items have a static size, this returns Some, else None + * NOTE: lawful implementations that return Some here much return + * Some on dynamicSize so callers don't need to check both when + * they have an instance. + */ + def staticSize: Option[Int] + /** + * returns Some if the size is cheap to calculate. 
+ * otherwise the caller should just serialize into an ByteArrayOutputStream + */ + def dynamicSize(t: T): Option[Int] +} + +object Serialization { + import JavaStreamEnrichments._ + /** + * This is a constant for us to reuse in Serialization.write + */ + val successUnit: Try[Unit] = Success(()) + + def equiv[T](a: T, b: T)(implicit ser: Serialization[T]): Boolean = + ser.equiv(a, b) + + def hash[T](t: T)(implicit ser: Serialization[T]): Int = + ser.hash(t) + + def read[T](in: InputStream)(implicit ser: Serialization[T]): Try[T] = + ser.read(in) + + def write[T](out: OutputStream, t: T)(implicit ser: Serialization[T]): Try[Unit] = + ser.write(out, t) + + def toBytes[T](t: T)(implicit ser: Serialization[T]): Array[Byte] = { + ser.dynamicSize(t) match { + case None => + val baos = new ByteArrayOutputStream + write(baos, t).get // this should only throw on OOM + baos.toByteArray + case Some(size) => + // If we know the size, we can just write directly into a fixed + // size byte array + val bytes = new Array[Byte](size) + val os = bytes.wrapAsOutputStream + write(os, t).get // this should only throw on OOM + bytes + } + } + + def fromBytes[T: Serialization](b: Array[Byte]): Try[T] = + read(new ByteArrayInputStream(b)) + + /** + * This copies more than needed, but it is only for testing + */ + private def roundTrip[T](t: T)(implicit ser: Serialization[T]): T = { + val baos = new ByteArrayOutputStream + ser.write(baos, t).get // should never throw on a ByteArrayOutputStream + ser.read(baos.toInputStream).get + } + + /** + * Do these two items write equivalently? + */ + def writeEquiv[T: Serialization](a: T, b: T): Boolean = + java.util.Arrays.equals(toBytes(a), toBytes(b)) + + /** + * write followed by read should give an equivalent T + * + * This is a law that serialization must follow. It is here for + * documentation and for use within tests without any dependence on + * specific test frameworks. 
+   *
+   * forAll(roundTripLaw[T]) is a valid test in scalacheck style.
+   */
+  def roundTripLaw[T: Serialization]: Law1[T] =
+    Law1("roundTrip", { (t: T) => equiv(roundTrip(t), t) })
+
+  /**
+   * If two items are equal, they should serialize byte for byte equivalently
+   */
+  def serializationIsEquivalence[T: Serialization]: Law2[T] =
+    Law2("equiv(a, b) == (write(a) == write(b))", { (t1: T, t2: T) =>
+      equiv(t1, t2) == writeEquiv(t1, t2)
+    })
+
+  def hashCodeImpliesEquality[T: Serialization]: Law2[T] =
+    Law2("equiv(a, b) => hash(a) == hash(b)", { (t1: T, t2: T) =>
+      !equiv(t1, t2) || (hash(t1) == hash(t2))
+    })
+
+  def reflexivity[T: Serialization]: Law1[T] =
+    Law1("equiv(a, a) == true", { (t1: T) => equiv(t1, t1) })
+
+  /**
+   * The sizes must match and be correct if they are present
+   */
+  def sizeLaw[T: Serialization]: Law1[T] =
+    Law1("staticSize.orElse(dynamicSize(t)).map { _ == toBytes(t).length }",
+      { (t: T) =>
+        val ser = implicitly[Serialization[T]]
+        (ser.staticSize, ser.dynamicSize(t)) match {
+          case (Some(s), Some(d)) if d == s => toBytes(t).length == s
+          case (Some(s), _) => false // if static exists it must match dynamic
+          case (None, Some(d)) => toBytes(t).length == d
+          case (None, None) => true // can't tell
+        }
+      })
+
+  def transitivity[T: Serialization]: Law3[T] =
+    Law3("equiv(a, b) && equiv(b, c) => equiv(a, c)",
+      { (t1: T, t2: T, t3: T) =>
+        !(equiv(t1, t2) && equiv(t2, t3)) || equiv(t1, t3)
+      })
+
+  def allLaws[T: Serialization]: Iterable[Law[T]] =
+    List(roundTripLaw,
+      serializationIsEquivalence,
+      hashCodeImpliesEquality,
+      reflexivity,
+      sizeLaw,
+      transitivity)
+}
diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization2.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization2.scala
new file mode 100644
index 0000000000..e364476854
--- /dev/null
+++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization2.scala
@@ -0,0 +1,79 @@
+/*
+Copyright 2015 Twitter, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package com.twitter.scalding.serialization + +import java.io.{ InputStream, OutputStream } + +import scala.util.{ Failure, Success, Try } + +class Serialization2[A, B](val serA: Serialization[A], val serB: Serialization[B]) extends Serialization[(A, B)] { + override def hash(x: (A, B)) = { + import MurmurHashUtils._ + val h1 = mixH1(seed, serA.hash(x._1)) + val h2 = mixH1(h1, serB.hash(x._2)) + fmix(h2, 2) + } + override def equiv(x: (A, B), y: (A, B)): Boolean = + serA.equiv(x._1, y._1) && serB.equiv(x._2, y._2) + + override def read(in: InputStream): Try[(A, B)] = { + val a = serA.read(in) + val b = serB.read(in) + (a, b) match { + case (Success(a), Success(b)) => Success((a, b)) + case (Failure(e), _) => Failure(e) + case (_, Failure(e)) => Failure(e) + } + } + + override def write(out: OutputStream, a: (A, B)): Try[Unit] = { + val resA = serA.write(out, a._1) + if (resA.isSuccess) serB.write(out, a._2) + else resA + } + + override val staticSize = for { + a <- serA.staticSize + b <- serB.staticSize + } yield a + b + + override def dynamicSize(t: (A, B)) = if (staticSize.isDefined) staticSize + else for { + a <- serA.dynamicSize(t._1) + b <- serB.dynamicSize(t._2) + } yield a + b +} + +class OrderedSerialization2[A, B](val ordA: OrderedSerialization[A], + val ordB: OrderedSerialization[B]) extends Serialization2[A, B](ordA, ordB) with OrderedSerialization[(A, B)] { + override def compare(x: (A, B), y: (A, B)) = { + val ca = ordA.compare(x._1, y._1) + if (ca != 0) ca + else ordB.compare(x._2, y._2) + } + override def compareBinary(a: InputStream, b: InputStream) = { + // This mutates the buffers and advances them. Only keep reading if they are different + val cA = ordA.compareBinary(a, b) + // we have to read the second ones to skip + val cB = ordB.compareBinary(a, b) + cA match { + case OrderedSerialization.Equal => cB + case f @ OrderedSerialization.CompareFailure(_) => f + case _ => cA // the first is not equal + } + } +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/StringOrderedSerialization.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/StringOrderedSerialization.scala new file mode 100644 index 0000000000..23ab371c4d --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/StringOrderedSerialization.scala @@ -0,0 +1,115 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.serialization + +import java.io.{ InputStream, OutputStream } +import scala.util.{ Failure, Success } +import scala.util.control.NonFatal + +import JavaStreamEnrichments._ + +object StringOrderedSerialization { + final def binaryIntCompare(leftSize: Int, seekingLeft: InputStream, rightSize: Int, seekingRight: InputStream): Int = { + /* + * This algorithm only works if count in {0, 1, 2, 3}. Since we only + * call it that way below it is safe. 
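+   *
+   * For illustration: when comparing the 6-byte UTF-8 contents of "abcdef"
+   * and "abcdeg", the loop below first compares one 4-byte block as an
+   * unsigned int ("abcd" vs "abcd", equal), and then compareBytes(2) decides
+   * the result by comparing the remaining two bytes as an unsigned short
+   * ("ef" vs "eg").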
+ */ + + @inline + def compareBytes(count: Int): Int = + if ((count & 2) == 2) { + // there are 2 or 3 bytes to read + val cmp = Integer.compare(seekingLeft.readUnsignedShort, + seekingRight.readUnsignedShort) + if (cmp != 0) cmp + else if (count == 3) Integer.compare(seekingLeft.readUnsignedByte, + seekingRight.readUnsignedByte) + else 0 + } else { + // there are 0 or 1 bytes to read + if (count == 0) 0 + else Integer.compare(seekingLeft.readUnsignedByte, + seekingRight.readUnsignedByte) + } + + /** + * Now we start by comparing blocks of ints, then 0 - 3 bytes + */ + val toCheck = math.min(leftSize, rightSize) + val ints = toCheck / 4 + var counter = ints + var ic = 0 + while ((counter > 0) && (ic == 0)) { + // Unsigned compare of ints is cheaper than longs, because we can do it + // by upcasting to Long + ic = UnsignedComparisons.unsignedIntCompare(seekingLeft.readInt, seekingRight.readInt) + counter = counter - 1 + } + if (ic != 0) ic + else { + val bc = compareBytes(toCheck - 4 * ints) + if (bc != 0) bc + else { + // the size is the fallback when the prefixes match: + Integer.compare(leftSize, rightSize) + } + } + } +} + +class StringOrderedSerialization extends OrderedSerialization[String] { + import StringOrderedSerialization._ + override def hash(s: String) = s.hashCode + override def compare(a: String, b: String) = a.compareTo(b) + override def read(in: InputStream) = try { + val byteString = new Array[Byte](in.readPosVarInt) + in.readFully(byteString) + Success(new String(byteString, "UTF-8")) + } catch { case NonFatal(e) => Failure(e) } + + override def write(b: OutputStream, s: String) = try { + val bytes = s.getBytes("UTF-8") + b.writePosVarInt(bytes.length) + b.writeBytes(bytes) + Serialization.successUnit + } catch { case NonFatal(e) => Failure(e) } + + override def compareBinary(lhs: InputStream, rhs: InputStream) = try { + val leftSize = lhs.readPosVarInt + val rightSize = rhs.readPosVarInt + + val seekingLeft = PositionInputStream(lhs) + val seekingRight = PositionInputStream(rhs) + + val leftStart = seekingLeft.position + val rightStart = seekingRight.position + + val res = OrderedSerialization.resultFrom(binaryIntCompare(leftSize, seekingLeft, rightSize, seekingRight)) + seekingLeft.seekToPosition(leftStart + leftSize) + seekingRight.seekToPosition(rightStart + rightSize) + res + } catch { + case NonFatal(e) => OrderedSerialization.CompareFailure(e) + } + /** + * generally there is no way to see how big a utf-8 string is without serializing. + * We could scan looking for all ascii characters, but it's hard to see if + * we'd get the balance right. + */ + override def staticSize = None + override def dynamicSize(s: String) = None +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/UnsignedComparisons.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/UnsignedComparisons.scala new file mode 100644 index 0000000000..86c0839c7d --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/UnsignedComparisons.scala @@ -0,0 +1,36 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package com.twitter.scalding.serialization + +object UnsignedComparisons { + final def unsignedLongCompare(a: Long, b: Long): Int = if (a == b) 0 else { + val xor = (a ^ b) + // If xor >= 0, then a and b are on the same side of zero + if (xor >= 0L) java.lang.Long.compare(a, b) + else if (b >= 0L) 1 + else -1 + } + final def unsignedIntCompare(a: Int, b: Int): Int = + java.lang.Long.compare(a.toLong & 0xFFFFFFFFL, b.toLong & 0xFFFFFFFFL) + + final def unsignedShortCompare(a: Short, b: Short): Int = + Integer.compare(a & 0xFFFF, b & 0xFFFF) + + final def unsignedByteCompare(a: Byte, b: Byte): Int = + Integer.compare(a & 0xFF, b & 0xFF) +} + diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Writer.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Writer.scala new file mode 100644 index 0000000000..27436ed26c --- /dev/null +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Writer.scala @@ -0,0 +1,128 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import java.io.OutputStream + +/** + * This is a specialized typeclass to make it easier to implement Serializations. 
+ * The specialization *should* mean that there is no boxing and if the JIT + * does its work, Writer should compose well (via collections, Tuple2, Option, Either) + */ +trait Writer[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) -T] { + def write(os: OutputStream, t: T): Unit +} + +object Writer { + import JavaStreamEnrichments._ + + def write[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) T](os: OutputStream, + t: T)(implicit w: Writer[T]): Unit = + w.write(os, t) + /* + * Instances below + */ + implicit val unit: Writer[Unit] = new Writer[Unit] { + def write(os: OutputStream, u: Unit) = () + } + implicit val boolean: Writer[Boolean] = new Writer[Boolean] { + def write(os: OutputStream, b: Boolean) = os.writeBoolean(b) + } + implicit val byte: Writer[Byte] = new Writer[Byte] { + def write(os: OutputStream, b: Byte) = os.write(b) + } + implicit val short: Writer[Short] = new Writer[Short] { + def write(os: OutputStream, s: Short) = os.writeShort(s) + } + implicit val int: Writer[Int] = new Writer[Int] { + def write(os: OutputStream, s: Int) = os.writeInt(s) + } + implicit val long: Writer[Long] = new Writer[Long] { + def write(os: OutputStream, s: Long) = os.writeLong(s) + } + implicit val float: Writer[Float] = new Writer[Float] { + def write(os: OutputStream, s: Float) = os.writeFloat(s) + } + implicit val double: Writer[Double] = new Writer[Double] { + def write(os: OutputStream, s: Double) = os.writeDouble(s) + } + implicit val string: Writer[String] = new Writer[String] { + def write(os: OutputStream, s: String) = { + val bytes = s.getBytes("UTF-8") + os.writePosVarInt(bytes.length) + os.writeBytes(bytes) + } + } + + implicit def option[T: Writer]: Writer[Option[T]] = new Writer[Option[T]] { + val w = implicitly[Writer[T]] + def write(os: OutputStream, t: Option[T]) = + if (t.isDefined) { + os.write(1: Byte) + w.write(os, t.get) + } else os.write(0: Byte) + } + + implicit def either[L: Writer, R: Writer]: Writer[Either[L, R]] = new Writer[Either[L, R]] { + val lw = implicitly[Writer[L]] + val rw = implicitly[Writer[R]] + def write(os: OutputStream, e: Either[L, R]) = e match { + case Left(l) => + os.write(0: Byte) + lw.write(os, l) + case Right(r) => + os.write(1: Byte) + rw.write(os, r) + } + } + + implicit def tuple2[T1: Writer, T2: Writer]: Writer[(T1, T2)] = new Writer[(T1, T2)] { + val w1 = implicitly[Writer[T1]] + val w2 = implicitly[Writer[T2]] + def write(os: OutputStream, tup: (T1, T2)) = { + w1.write(os, tup._1) + w2.write(os, tup._2) + } + } + + implicit def array[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) T: Writer]: Writer[Array[T]] = + new Writer[Array[T]] { + val writerT = implicitly[Writer[T]] + def write(os: OutputStream, a: Array[T]) = { + val size = a.length + os.writePosVarInt(size) + @annotation.tailrec + def go(p: Int): Unit = + if (p == size) () + else { writerT.write(os, a(p)); go(p + 1) } + + go(0) + } + } + + // Scala has problems with this being implicit + def collection[T: Writer, C <: Iterable[T]]: Writer[C] = new Writer[C] { + val writerT = implicitly[Writer[T]] + def write(os: OutputStream, c: C) = { + val size = c.size + os.writePosVarInt(size) + c.foreach { t: T => + writerT.write(os, t) + } + } + } +} + diff --git a/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/JavaStreamEnrichmentsProperties.scala b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/JavaStreamEnrichmentsProperties.scala new file mode 100644 index 0000000000..8c95db542f --- /dev/null +++ 
b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/JavaStreamEnrichmentsProperties.scala @@ -0,0 +1,99 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import org.scalacheck.Arbitrary +import org.scalacheck.Properties +import org.scalacheck.Prop +import org.scalacheck.Prop.forAll +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import JavaStreamEnrichments._ +import java.io._ + +import scala.collection.generic.CanBuildFrom + +object JavaStreamEnrichmentsProperties extends Properties("JavaStreamEnrichmentsProperties") { + + def output = new ByteArrayOutputStream + + // The default Array[Equiv] is reference. WAT!? + implicit def aeq[T: Equiv]: Equiv[Array[T]] = new Equiv[Array[T]] { + def equiv(a: Array[T], b: Array[T]): Boolean = { + val teq = Equiv[T] + @annotation.tailrec + def go(pos: Int): Boolean = + if (pos == a.length) true + else { + teq.equiv(a(pos), b(pos)) && go(pos + 1) + } + + (a.length == b.length) && go(0) + } + } + implicit def teq[T1: Equiv, T2: Equiv]: Equiv[(T1, T2)] = new Equiv[(T1, T2)] { + def equiv(a: (T1, T2), b: (T1, T2)) = { + Equiv[T1].equiv(a._1, b._1) && + Equiv[T2].equiv(a._2, b._2) + } + } + + def writeRead[T: Equiv](g: Gen[T], w: (T, OutputStream) => Unit, r: InputStream => T): Prop = + forAll(g) { t => + val test = output + w(t, test) + Equiv[T].equiv(r(test.toInputStream), t) + } + def writeRead[T: Equiv: Arbitrary](w: (T, OutputStream) => Unit, r: InputStream => T): Prop = + writeRead(implicitly[Arbitrary[T]].arbitrary, w, r) + + property("Can (read/write)Size") = writeRead(Gen.chooseNum(0, Int.MaxValue), + { (i: Int, os) => os.writePosVarInt(i) }, { _.readPosVarInt }) + + property("Can (read/write)Float") = writeRead( + { (i: Float, os) => os.writeFloat(i) }, { _.readFloat }) + + property("Can (read/write)Array[Byte]") = writeRead( + // Use list because Array has a shitty toString + { (b: List[Byte], os) => os.writePosVarInt(b.size); os.writeBytes(b.toArray) }, + { is => + val bytes = new Array[Byte](is.readPosVarInt) + is.readFully(bytes) + bytes.toList + }) + + property("Can (read/write)Boolean") = writeRead( + { (i: Boolean, os) => os.writeBoolean(i) }, { _.readBoolean }) + + property("Can (read/write)Double") = writeRead( + { (i: Double, os) => os.writeDouble(i) }, { _.readDouble }) + + property("Can (read/write)Int") = writeRead(Gen.chooseNum(Int.MinValue, Int.MaxValue), + { (i: Int, os) => os.writeInt(i) }, { _.readInt }) + + property("Can (read/write)Long") = writeRead(Gen.chooseNum(Long.MinValue, Long.MaxValue), + { (i: Long, os) => os.writeLong(i) }, { _.readLong }) + + property("Can (read/write)Short") = writeRead(Gen.chooseNum(Short.MinValue, Short.MaxValue), + { (i: Short, os) => os.writeShort(i) }, { _.readShort }) + + property("Can (read/write)UnsignedByte") = writeRead(Gen.chooseNum(0, (1 << 8) - 1), + { (i: Int, os) => os.write(i.toByte) }, { _.readUnsignedByte }) + + property("Can (read/write)UnsignedShort") = writeRead(Gen.chooseNum(0, (1 << 
16) - 1), + { (i: Int, os) => os.writeShort(i.toShort) }, { _.readUnsignedShort }) +} diff --git a/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/SerializationProperties.scala b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/SerializationProperties.scala new file mode 100644 index 0000000000..c382e3c554 --- /dev/null +++ b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/SerializationProperties.scala @@ -0,0 +1,129 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import org.scalacheck.Arbitrary +import org.scalacheck.Properties +import org.scalacheck.Prop +import org.scalacheck.Prop.forAll +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import JavaStreamEnrichments._ +import java.io._ +import scala.util.Try + +object LawTester { + def apply[T: Arbitrary](base: String, laws: Iterable[Law[T]]): Properties = + new LawTester(implicitly[Arbitrary[T]].arbitrary, base, laws) {} +} + +abstract class LawTester[T](g: Gen[T], base: String, laws: Iterable[Law[T]]) extends Properties(base) { + laws.foreach { + case Law1(name, fn) => property(name) = forAll(g)(fn) + case Law2(name, fn) => property(name) = forAll(g, g)(fn) + case Law3(name, fn) => property(name) = forAll(g, g, g)(fn) + } +} + +object SerializationProperties extends Properties("SerializationProperties") { + + import OrderedSerialization.{ resultFrom, CompareFailure, readThenCompare } + + implicit val intOrderedSerialization: OrderedSerialization[Int] = new OrderedSerialization[Int] { + def read(in: InputStream) = Try(Reader.read[Int](in)) + def write(o: OutputStream, t: Int) = Try(Writer.write[Int](o, t)) + def hash(t: Int) = t.hashCode + def compare(a: Int, b: Int) = java.lang.Integer.compare(a, b) + def compareBinary(a: InputStream, b: InputStream) = + readThenCompare(a, b)(this) + val staticSize = Some(4) + def dynamicSize(i: Int) = staticSize + } + + implicit val stringOrdSer: OrderedSerialization[String] = new StringOrderedSerialization + + implicit def tuple[A: OrderedSerialization, B: OrderedSerialization]: OrderedSerialization[(A, B)] = + new OrderedSerialization2[A, B](implicitly, implicitly) + + def serializeSequenceCompare[T: OrderedSerialization](g: Gen[T]): Prop = forAll(Gen.listOf(g)) { list => + // make sure the list is even in size: + val pairList = (if (list.size % 2 == 1) list.tail else list).grouped(2) + val baos1 = new ByteArrayOutputStream + val baos2 = new ByteArrayOutputStream + pairList.foreach { + case Seq(a, b) => + Serialization.write(baos1, a) + Serialization.write(baos2, b) + case _ => sys.error("unreachable") + } + // now the compares must match: + val in1 = baos1.toInputStream + val in2 = baos2.toInputStream + pairList.forall { + case Seq(a, b) => + OrderedSerialization.compareBinary[T](in1, in2) == + OrderedSerialization.resultFrom(OrderedSerialization.compare(a, b)) + case _ => sys.error("unreachable") + } + } + + def serializeSequenceCompare[T: 
OrderedSerialization: Arbitrary]: Prop = + serializeSequenceCompare[T](implicitly[Arbitrary[T]].arbitrary) + + def serializeSequenceEquiv[T: Serialization](g: Gen[T]): Prop = forAll(Gen.listOf(g)) { list => + // make sure the list is even in size: + val pairList = (if (list.size % 2 == 1) list.tail else list).grouped(2) + val baos1 = new ByteArrayOutputStream + val baos2 = new ByteArrayOutputStream + pairList.foreach { + case Seq(a, b) => + Serialization.write(baos1, a) + Serialization.write(baos2, b) + case _ => sys.error("unreachable") + } + // now the compares must match: + val in1 = baos1.toInputStream + val in2 = baos2.toInputStream + pairList.forall { + case Seq(a, b) => + val rta = Serialization.read[T](in1).get + val rtb = Serialization.read[T](in2).get + Serialization.equiv(a, rta) && Serialization.equiv(b, rtb) + case _ => sys.error("unreachable") + } + } + def serializeSequenceEquiv[T: Serialization: Arbitrary]: Prop = + serializeSequenceEquiv[T](implicitly[Arbitrary[T]].arbitrary) + + property("sequences compare well [Int]") = serializeSequenceCompare[Int] + property("sequences equiv well [Int]") = serializeSequenceEquiv[Int] + property("sequences compare well [(Int, Int)]") = serializeSequenceCompare[(Int, Int)] + property("sequences equiv well [(Int, Int)]") = serializeSequenceEquiv[(Int, Int)] + + property("sequences compare well [String]") = serializeSequenceCompare[String] + property("sequences equiv well [String]") = serializeSequenceEquiv[String] + property("sequences compare well [(String, String)]") = serializeSequenceCompare[(String, String)] + property("sequences equiv well [(String, String)]") = serializeSequenceEquiv[(String, String)] + + // Test the independent, non-sequenced, laws as well + include(LawTester("Int Ordered", OrderedSerialization.allLaws[Int])) + include(LawTester("(Int, Int) Ordered", OrderedSerialization.allLaws[(Int, Int)])) + include(LawTester("String Ordered", OrderedSerialization.allLaws[String])) + include(LawTester("(String, Int) Ordered", OrderedSerialization.allLaws[(String, Int)])) + include(LawTester("(Int, String) Ordered", OrderedSerialization.allLaws[(Int, String)])) + include(LawTester("(String, String) Ordered", OrderedSerialization.allLaws[(String, String)])) +} diff --git a/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/UnsignedComparisonLaws.scala b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/UnsignedComparisonLaws.scala new file mode 100644 index 0000000000..bb258d61af --- /dev/null +++ b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/UnsignedComparisonLaws.scala @@ -0,0 +1,41 @@ +package com.twitter.scalding.serialization + +import org.scalacheck.Arbitrary +import org.scalacheck.Properties +import org.scalacheck.Prop.forAll +import org.scalacheck.Gen.choose +import org.scalacheck.Prop._ + +object UnsignedComparisonLaws extends Properties("UnsignedComparisonLaws") { + + property("UnsignedLongCompare works") = forAll { (l1: Long, l2: Long) => + val cmp = UnsignedComparisons.unsignedLongCompare(l1, l2) + (l1 >= 0, l2 >= 0) match { + case (true, true) => cmp == java.lang.Long.compare(l1, l2) + case (true, false) => cmp < 0 // negative is bigger + case (false, true) => cmp > 0 + case (false, false) => cmp == java.lang.Long.compare(l1 & Long.MaxValue, l2 & Long.MaxValue) + } + } + property("UnsignedIntCompare works") = forAll { (l1: Int, l2: Int) => + val cmp = UnsignedComparisons.unsignedIntCompare(l1, l2) + (l1 >= 0, l2 >= 0) match { + case (true, 
true) => cmp == java.lang.Integer.compare(l1, l2) + case (true, false) => cmp < 0 // negative is bigger + case (false, true) => cmp > 0 + case (false, false) => cmp == java.lang.Integer.compare(l1 & Int.MaxValue, l2 & Int.MaxValue) + } + } + property("UnsignedByteCompare works") = forAll { (l1: Byte, l2: Byte) => + def clamp(i: Int) = if (i > 0) 1 else if (i < 0) -1 else 0 + val cmp = clamp(UnsignedComparisons.unsignedByteCompare(l1, l2)) + (l1 >= 0, l2 >= 0) match { + case (true, true) => cmp == clamp(java.lang.Byte.compare(l1, l2)) + case (true, false) => cmp < 0 // negative is bigger + case (false, true) => cmp > 0 + // Convert to positive ints + case (false, false) => cmp == java.lang.Integer.compare(l1 & Byte.MaxValue, l2 & Byte.MaxValue) + } + } +} + diff --git a/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/WriterReaderProperties.scala b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/WriterReaderProperties.scala new file mode 100644 index 0000000000..a4e9fab9fd --- /dev/null +++ b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/WriterReaderProperties.scala @@ -0,0 +1,102 @@ +/* +Copyright 2015 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.serialization + +import org.scalacheck.Arbitrary +import org.scalacheck.Properties +import org.scalacheck.Prop +import org.scalacheck.Prop.forAll +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import JavaStreamEnrichments._ +import java.io._ + +import scala.collection.generic.CanBuildFrom + +object WriterReaderProperties extends Properties("WriterReaderProperties") { + + def output = new ByteArrayOutputStream + + // The default Array[Equiv] is reference. WAT!? 
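+  // Without the element-wise Equiv defined below, two structurally equal arrays
+  // would compare as unequal and the Array round-trip properties would fail spuriously.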
+ implicit def aeq[T: Equiv]: Equiv[Array[T]] = new Equiv[Array[T]] { + def equiv(a: Array[T], b: Array[T]): Boolean = { + val teq = Equiv[T] + @annotation.tailrec + def go(pos: Int): Boolean = + if (pos == a.length) true + else { + teq.equiv(a(pos), b(pos)) && go(pos + 1) + } + + (a.length == b.length) && go(0) + } + } + implicit def teq[T1: Equiv, T2: Equiv]: Equiv[(T1, T2)] = new Equiv[(T1, T2)] { + def equiv(a: (T1, T2), b: (T1, T2)) = { + Equiv[T1].equiv(a._1, b._1) && + Equiv[T2].equiv(a._2, b._2) + } + } + + def writerReader[T: Writer: Reader: Equiv](g: Gen[T]): Prop = + forAll(g) { t => + val test = output + Writer.write(test, t) + Equiv[T].equiv(Reader.read(test.toInputStream), t) + } + def writerReader[T: Writer: Reader: Equiv: Arbitrary]: Prop = + writerReader(implicitly[Arbitrary[T]].arbitrary) + + def writerReaderCollection[T: Writer: Reader, C <: Iterable[T]: Arbitrary: Equiv](implicit cbf: CanBuildFrom[Nothing, T, C]): Prop = + { + implicit val cwriter = Writer.collection[T, C] + implicit val creader = Reader.collection[T, C] + writerReader(implicitly[Arbitrary[C]].arbitrary) + } + + /* + * Test the Writer/Reader type-classes + */ + property("Unit Writer/Reader") = writerReader[Unit] + property("Boolean Writer/Reader") = writerReader[Boolean] + property("Byte Writer/Reader") = writerReader[Byte] + property("Short Writer/Reader") = writerReader[Short] + property("Int Writer/Reader") = writerReader[Int] + property("Long Writer/Reader") = writerReader[Long] + property("Float Writer/Reader") = writerReader[Float] + property("Double Writer/Reader") = writerReader[Double] + property("String Writer/Reader") = writerReader[String] + property("Array[Byte] Writer/Reader") = writerReader[Array[Byte]] + property("Array[Int] Writer/Reader") = writerReader[Array[Int]] + property("Array[String] Writer/Reader") = writerReader[Array[String]] + property("List[String] Writer/Reader") = + writerReaderCollection[String, List[String]] + property("(Int, Array[String]) Writer/Reader") = + writerReader[(Int, Array[String])] + + property("Option[(Int, Double)] Writer/Reader") = + writerReader[Option[(Int, Double)]] + + property("Option[Option[Unit]] Writer/Reader") = + writerReader[Option[Option[Unit]]] + + property("Either[Int, String] Writer/Reader") = + writerReader[Either[Int, String]] + + property("Map[Long, Byte] Writer/Reader") = + writerReaderCollection[(Long, Byte), Map[Long, Byte]] +} diff --git a/scripts/test_execution_tutorial.sh b/scripts/test_execution_tutorial.sh new file mode 100755 index 0000000000..a2178482e2 --- /dev/null +++ b/scripts/test_execution_tutorial.sh @@ -0,0 +1,24 @@ +set -e # first error should stop execution of this script + +# Identify the bin dir in the distribution, and source the common include script +BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )"/.. 
&& pwd )"
+source ${BASE_DIR}/scripts/common.sh
+SHORT_SCALA_VERSION=${TRAVIS_SCALA_VERSION%.*}
+SCALDING_VERSION=`cat ${BASE_DIR}/version.sbt`
+SCALDING_VERSION=${SCALDING_VERSION#*\"}
+SCALDING_VERSION=${SCALDING_VERSION%\"}
+
+
+# also trap errors, to re-enable terminal settings
+trap onExit ERR
+export CLASSPATH=tutorial/execution-tutorial/target/scala-${SHORT_SCALA_VERSION}/execution-tutorial-assembly-${SCALDING_VERSION}.jar
+time java -jar tutorial/execution-tutorial/target/scala-${SHORT_SCALA_VERSION}/execution-tutorial-assembly-${SCALDING_VERSION}.jar \
+    com.twitter.scalding.tutorial.MyExecJob --local \
+    --input tutorial/data/hello.txt \
+    --output tutorial/data/execution_output.txt
+
+# restore stty
+SCALA_EXIT_STATUS=0
+onExit
+
+
diff --git a/tutorial/execution-tutorial/ExecutionTutorial.scala b/tutorial/execution-tutorial/ExecutionTutorial.scala
new file mode 100644
index 0000000000..2cb786d077
--- /dev/null
+++ b/tutorial/execution-tutorial/ExecutionTutorial.scala
@@ -0,0 +1,66 @@
+/*
+Copyright 2012 Twitter, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package com.twitter.scalding.tutorial
+
+import java.io._
+import scala.util.{Failure, Success}
+
+import com.twitter.scalding._
+
+/**
+Tutorial on using Execution
+
+This tutorial gives an example of using Execution to do a MapReduce word count.
+Instead of writing the results in the reducers, it writes the data out at the submitter node.
+
+To test it, first build the assembly jar from the root directory:
+  ./sbt execution-tutorial/assembly
+
+Run:
+  scala -classpath tutorial/execution-tutorial/target/execution-tutorial-assembly-0.15.0.jar \
+    com.twitter.scalding.tutorial.MyExecJob --local \
+    --input tutorial/data/hello.txt \
+    --output tutorial/data/execution_output.txt
+**/
+
+object MyExecJob extends ExecutionApp {
+
+  override def job = Execution.getConfig.flatMap { config =>
+    val args = config.getArgs
+
+    TypedPipe.from(TextLine(args("input")))
+      .flatMap(_.split("\\s+"))
+      .map((_, 1L))
+      .sumByKey
+      .toIterableExecution
+      // toIterableExecution will materialize the output at the submitter node when the job finishes.
+      // We can also write the output to HDFS via .writeExecution(TypedTsv(args("output")))
+      .onComplete { t => t match {
+        case Success(iter) =>
+          val file = new PrintWriter(new File(args("output")))
+          iter.foreach { case (k, v) =>
+            file.write(s"$k\t$v\n")
+          }
+          file.close
+        case Failure(e) => println("Error: " + e.toString)
+      }
+      }
+      // use the result and map it to Unit; otherwise the onComplete call won't happen
+      .unit
+  }
+}
+
+
diff --git a/version.sbt b/version.sbt
index beb9953b43..b2771948a7 100644
--- a/version.sbt
+++ b/version.sbt
@@ -1 +1 @@
-version in ThisBuild := "0.13.1"
+version in ThisBuild := "0.15.0"
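For orientation, the sketch below shows how the Serialization typeclass introduced in this change is meant to be used: define or obtain instances, round-trip a value with toBytes/fromBytes, and compose pair instances with OrderedSerialization2, mirroring the patterns in SerializationProperties above. The object name SerializationUsageSketch and its main method are illustrative only and are not part of the diff.

import com.twitter.scalding.serialization._

object SerializationUsageSketch {
  // String instance from this change; pairs compose via OrderedSerialization2
  implicit val stringSer: OrderedSerialization[String] = new StringOrderedSerialization
  implicit val pairSer: OrderedSerialization[(String, String)] =
    new OrderedSerialization2(stringSer, stringSer)

  def main(args: Array[String]): Unit = {
    val pair = ("hello", "world")
    val bytes = Serialization.toBytes(pair)                        // uses the implicit pairSer
    val back = Serialization.fromBytes[(String, String)](bytes).get
    assert(Serialization.equiv(pair, back))                        // roundTripLaw in action
  }
}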