From a28728a9afcff94194147573e07f6f4d0463687e Mon Sep 17 00:00:00 2001 From: goldmedal Date: Fri, 15 Sep 2017 11:53:10 +0900 Subject: [PATCH 01/37] [SPARK-21513][SQL][FOLLOWUP] Allow UDF to_json support converting MapType to json for PySpark and SparkR ## What changes were proposed in this pull request? In previous work SPARK-21513, we has allowed `MapType` and `ArrayType` of `MapType`s convert to a json string but only for Scala API. In this follow-up PR, we will make SparkSQL support it for PySpark and SparkR, too. We also fix some little bugs and comments of the previous work in this follow-up PR. ### For PySpark ``` >>> data = [(1, {"name": "Alice"})] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'{"name":"Alice")'] >>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'[{"name":"Alice"},{"name":"Bob"}]')] ``` ### For SparkR ``` # Converts a map into a JSON object df2 <- sql("SELECT map('name', 'Bob')) as people") df2 <- mutate(df2, people_json = to_json(df2$people)) # Converts an array of maps into a JSON array df2 <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") df2 <- mutate(df2, people_json = to_json(df2$people)) ``` ## How was this patch tested? Add unit test cases. cc viirya HyukjinKwon Author: goldmedal Closes #19223 from goldmedal/SPARK-21513-fp-PySaprkAndSparkR. --- R/pkg/R/functions.R | 16 +++++++++++--- R/pkg/tests/fulltests/test_sparkSQL.R | 8 +++++++ python/pyspark/sql/functions.py | 22 ++++++++++++++----- .../expressions/jsonExpressions.scala | 8 +++---- .../sql/catalyst/json/JacksonGenerator.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 8 +++---- 6 files changed, 46 insertions(+), 18 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 5a46d737aeeb7..e92e1fd72bf10 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -176,7 +176,8 @@ NULL #' #' @param x Column to compute on. Note the difference in the following methods: #' \itemize{ -#' \item \code{to_json}: it is the column containing the struct or array of the structs. +#' \item \code{to_json}: it is the column containing the struct, array of the structs, +#' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains @@ -1700,8 +1701,9 @@ setMethod("to_date", }) #' @details -#' \code{to_json}: Converts a column containing a \code{structType} or array of \code{structType} -#' into a Column of JSON string. Resolving the Column can fail if an unsupported type is encountered. +#' \code{to_json}: Converts a column containing a \code{structType}, array of \code{structType}, +#' a \code{mapType} or array of \code{mapType} into a Column of JSON string. +#' Resolving the Column can fail if an unsupported type is encountered. #' #' @rdname column_collection_functions #' @aliases to_json to_json,Column-method @@ -1715,6 +1717,14 @@ setMethod("to_date", #' #' # Converts an array of structs into a JSON array #' df2 <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' +#' # Converts a map into a JSON object +#' df2 <- sql("SELECT map('name', 'Bob')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' +#' # Converts an array of maps into a JSON array +#' df2 <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") #' df2 <- mutate(df2, people_json = to_json(df2$people))} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7abc8720473c1..85a7e0819cff7 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1491,6 +1491,14 @@ test_that("column functions", { j <- collect(select(df, alias(to_json(df$people), "json"))) expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- sql("SELECT map('name', 'Bob') as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "{\"name\":\"Bob\"}") + + df <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- read.json(mapTypeJsonPath) j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0e76182e0e02d..399bef02d9cc4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1884,9 +1884,9 @@ def json_tuple(col, *fields): @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]] - of [[StructType]]s with the specified schema. Returns `null`, in the case of an unparseable - string. + Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` + of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an + unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. @@ -1921,10 +1921,12 @@ def from_json(col, schema, options={}): @since(2.1) def to_json(col, options={}): """ - Converts a column containing a [[StructType]] or [[ArrayType]] of [[StructType]]s into a - JSON string. Throws an exception, in the case of an unsupported type. + Converts a column containing a :class:`StructType`, :class:`ArrayType` of + :class:`StructType`\\s, a :class:`MapType` or :class:`ArrayType` of :class:`MapType`\\s + into a JSON string. Throws an exception, in the case of an unsupported type. - :param col: name of column containing the struct or array of the structs + :param col: name of column containing the struct, array of the structs, the map or + array of the maps. :param options: options to control converting. accepts the same options as the json datasource >>> from pyspark.sql import Row @@ -1937,6 +1939,14 @@ def to_json(col, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] + >>> data = [(1, {"name": "Alice"})] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'{"name":"Alice"}')] + >>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'[{"name":"Alice"},{"name":"Bob"}]')] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 134163187b7c6..18b4fed597447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -618,13 +618,13 @@ case class JsonToStructs( {"time":"26/08/2015"} > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); [{"a":1,"b":2}] - > SELECT _FUNC_(map('a',named_struct('b',1))); + > SELECT _FUNC_(map('a', named_struct('b', 1))); {"a":{"b":1}} - > SELECT _FUNC_(map(named_struct('a',1),named_struct('b',2))); + > SELECT _FUNC_(map(named_struct('a', 1),named_struct('b', 2))); {"[1]":{"b":2}} - > SELECT _FUNC_(map('a',1)); + > SELECT _FUNC_(map('a', 1)); {"a":1} - > SELECT _FUNC_(array((map('a',1)))); + > SELECT _FUNC_(array((map('a', 1)))); [{"a":1}] """, since = "2.2.0") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index dfe7e28121943..eb06e4f304f0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -43,7 +43,7 @@ private[sql] class JacksonGenerator( private type ValueWriter = (SpecializedGetters, Int) => Unit // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. - require(dataType.isInstanceOf[StructType] | dataType.isInstanceOf[MapType], + require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], "JacksonGenerator only supports to be initialized with a StructType " + s"or MapType but got ${dataType.simpleString}") diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index dcced79d315f3..d9dc728a18e8d 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -26,13 +26,13 @@ Extended Usage: {"time":"26/08/2015"} > SELECT to_json(array(named_struct('a', 1, 'b', 2)); [{"a":1,"b":2}] - > SELECT to_json(map('a',named_struct('b',1))); + > SELECT to_json(map('a', named_struct('b', 1))); {"a":{"b":1}} - > SELECT to_json(map(named_struct('a',1),named_struct('b',2))); + > SELECT to_json(map(named_struct('a', 1),named_struct('b', 2))); {"[1]":{"b":2}} - > SELECT to_json(map('a',1)); + > SELECT to_json(map('a', 1)); {"a":1} - > SELECT to_json(array((map('a',1)))); + > SELECT to_json(array((map('a', 1)))); [{"a":1}] Since: 2.2.0 From 88661747f506e73c79de36711daebb0330de7b0d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 14 Sep 2017 22:32:16 -0700 Subject: [PATCH 02/37] [SPARK-22018][SQL] Preserve top-level alias metadata when collapsing projects ## What changes were proposed in this pull request? If there are two projects like as follows. ``` Project [a_with_metadata#27 AS b#26] +- Project [a#0 AS a_with_metadata#27] +- LocalRelation , [a#0, b#1] ``` Child Project has an output column with a metadata in it, and the parent Project has an alias that implicitly forwards the metadata. So this metadata is visible for higher operators. Upon applying CollapseProject optimizer rule, the metadata is not preserved. ``` Project [a#0 AS b#26] +- LocalRelation , [a#0, b#1] ``` This is incorrect, as downstream operators that expect certain metadata (e.g. watermark in structured streaming) to identify certain fields will fail to do so. This PR fixes it by preserving the metadata of top-level aliases. ## How was this patch tested? New unit test Author: Tathagata Das Closes #19240 from tdas/SPARK-22018. --- .../sql/catalyst/analysis/Analyzer.scala | 5 +++- .../optimizer/CollapseProjectSuite.scala | 23 +++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0880bd66ea4c4..db276fbc9d53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2256,7 +2256,10 @@ object CleanupAliases extends Rule[LogicalPlan] { def trimNonTopLevelAliases(e: Expression): Expression = e match { case a: Alias => - a.withNewChildren(trimAliases(a.child) :: Nil) + a.copy(child = trimAliases(a.child))( + exprId = a.exprId, + qualifier = a.qualifier, + explicitMetadata = Some(a.metadata)) case other => trimAliases(other) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 587437e9aa81d..e7a5bcee420f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.expressions.{Alias, Rand} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.MetadataBuilder class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -119,4 +120,22 @@ class CollapseProjectSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("preserve top-level alias metadata while collapsing projects") { + def hasMetadata(logicalPlan: LogicalPlan): Boolean = { + logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key")) + } + + val metadata = new MetadataBuilder().putLong("key", 1).build() + val analyzed = + Project(Seq(Alias('a_with_metadata, "b")()), + Project(Seq(Alias('a, "a_with_metadata")(explicitMetadata = Some(metadata))), + testRelation.logicalPlan)).analyze + require(hasMetadata(analyzed)) + + val optimized = Optimize.execute(analyzed) + val projects = optimized.collect { case p: Project => p } + assert(projects.size === 1) + assert(hasMetadata(optimized)) + } } From 22b111ef9d10ebf3c285974fd8c5ea0804ca144a Mon Sep 17 00:00:00 2001 From: zhoukang Date: Fri, 15 Sep 2017 14:03:26 +0800 Subject: [PATCH 03/37] [SPARK-21902][CORE] Print root cause for BlockManager#doPut ## What changes were proposed in this pull request? As logging below, actually exception will be hidden when removeBlockInternal throw an exception. `2017-08-31,10:26:57,733 WARN org.apache.spark.storage.BlockManager: Putting block broadcast_110 failed due to an exception 2017-08-31,10:26:57,734 WARN org.apache.spark.broadcast.BroadcastManager: Failed to create a new broadcast in 1 attempts java.io.IOException: Failed to create local dir in /tmp/blockmgr-5bb5ac1e-c494-434a-ab89-bd1808c6b9ed/2e. at org.apache.spark.storage.DiskBlockManager.getFile(DiskBlockManager.scala:70) at org.apache.spark.storage.DiskStore.remove(DiskStore.scala:115) at org.apache.spark.storage.BlockManager.removeBlockInternal(BlockManager.scala:1339) at org.apache.spark.storage.BlockManager.doPut(BlockManager.scala:910) at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:948) at org.apache.spark.storage.BlockManager.putIterator(BlockManager.scala:726) at org.apache.spark.storage.BlockManager.putSingle(BlockManager.scala:1233) at org.apache.spark.broadcast.TorrentBroadcast.writeBlocks(TorrentBroadcast.scala:122) at org.apache.spark.broadcast.TorrentBroadcast.(TorrentBroadcast.scala:88) at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34) at org.apache.spark.broadcast.BroadcastManager$$anonfun$newBroadcast$1.apply$mcVI$sp(BroadcastManager.scala:60) at scala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160) at org.apache.spark.broadcast.BroadcastManager.newBroadcast(BroadcastManager.scala:58) at org.apache.spark.SparkContext.broadcast(SparkContext.scala:1415) at org.apache.spark.scheduler.DAGScheduler.submitMissingTasks(DAGScheduler.scala:1002) at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$submitStage(DAGScheduler.scala:924) at org.apache.spark.scheduler.DAGScheduler$$anonfun$submitWaitingChildStages$6.apply(DAGScheduler.scala:771) at org.apache.spark.scheduler.DAGScheduler$$anonfun$submitWaitingChildStages$6.apply(DAGScheduler.scala:770) at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33) at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186) at org.apache.spark.scheduler.DAGScheduler.submitWaitingChildStages(DAGScheduler.scala:770) at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:1235) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1662) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1620) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1609) at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)` In this pr i will print exception first make troubleshooting more conveniently. PS: This one split from [PR-19133](https://github.com/apache/spark/pull/19133) ## How was this patch tested? Exsist unit test Author: zhoukang Closes #19171 from caneGuy/zhoukang/print-rootcause. --- .../main/scala/org/apache/spark/storage/BlockManager.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index aaacabe79ace4..b4b5938c307e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -988,11 +988,16 @@ private[spark] class BlockManager( logWarning(s"Putting block $blockId failed") } res + } catch { + // Since removeBlockInternal may throw exception, + // we should print exception first to show root cause. + case NonFatal(e) => + logWarning(s"Putting block $blockId failed due to exception $e.") + throw e } finally { // This cleanup is performed in a finally block rather than a `catch` to avoid having to // catch and properly re-throw InterruptedException. if (exceptionWasThrown) { - logWarning(s"Putting block $blockId failed due to an exception") // If an exception was thrown then it's possible that the code in `putBody` has already // notified the master about the availability of this block, so we need to send an update // to remove this block location. From 4decedfdbd31525e66d33eecc9d38a747c98547b Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 14 Sep 2017 23:35:55 -0700 Subject: [PATCH 04/37] [SPARK-22002][SQL] Read JDBC table use custom schema support specify partial fields. ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18266 add a new feature to support read JDBC table use custom schema, but we must specify all the fields. For simplicity, this PR support specify partial fields. ## How was this patch tested? unit tests Author: Yuming Wang Closes #19231 from wangyum/SPARK-22002. --- docs/sql-programming-guide.md | 2 +- .../datasources/jdbc/JdbcUtils.scala | 36 +++++++------- .../datasources/jdbc/JdbcUtilsSuite.scala | 47 ++++++------------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 17 ++++--- 4 files changed, 40 insertions(+), 62 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 95d704014742c..5db60cc996e75 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1333,7 +1333,7 @@ the following case-insensitive options: customSchema - The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING"). The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading. + The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING". You can also specify partial fields, and the others use the default type mapping. For example, "id DECIMAL(38, 0)". The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 75327f0d38c2e..71133666b3249 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -301,12 +301,11 @@ object JdbcUtils extends Logging { } else { rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls } - val metadata = new MetadataBuilder() - .putLong("scale", fieldScale) + val metadata = new MetadataBuilder().putLong("scale", fieldScale) val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( getCatalystType(dataType, fieldSize, fieldScale, isSigned)) - fields(i) = StructField(columnName, columnType, nullable, metadata.build()) + fields(i) = StructField(columnName, columnType, nullable) i = i + 1 } new StructType(fields) @@ -768,31 +767,30 @@ object JdbcUtils extends Logging { } /** - * Parses the user specified customSchema option value to DataFrame schema, - * and returns it if it's all columns are equals to default schema's. + * Parses the user specified customSchema option value to DataFrame schema, and + * returns a schema that is replaced by the custom schema's dataType if column name is matched. */ def getCustomSchema( tableSchema: StructType, customSchema: String, nameEquality: Resolver): StructType = { - val userSchema = CatalystSqlParser.parseTableSchema(customSchema) + if (null != customSchema && customSchema.nonEmpty) { + val userSchema = CatalystSqlParser.parseTableSchema(customSchema) - SchemaUtils.checkColumnNameDuplication( - userSchema.map(_.name), "in the customSchema option value", nameEquality) - - val colNames = tableSchema.fieldNames.mkString(",") - val errorMsg = s"Please provide all the columns, all columns are: $colNames" - if (userSchema.size != tableSchema.size) { - throw new AnalysisException(errorMsg) - } + SchemaUtils.checkColumnNameDuplication( + userSchema.map(_.name), "in the customSchema option value", nameEquality) - // This is resolved by names, only check the column names. - userSchema.fieldNames.foreach { col => - tableSchema.find(f => nameEquality(f.name, col)).getOrElse { - throw new AnalysisException(errorMsg) + // This is resolved by names, use the custom filed dataType to replace the default dataType. + val newSchema = tableSchema.map { col => + userSchema.find(f => nameEquality(f.name, col.name)) match { + case Some(c) => col.copy(dataType = c.dataType) + case None => col + } } + StructType(newSchema) + } else { + tableSchema } - userSchema } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala index 1255f262bce94..7d277c1ffaffe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala @@ -30,57 +30,38 @@ class JdbcUtilsSuite extends SparkFunSuite { val caseInsensitive = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution test("Parse user specified column types") { - assert( - JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true)))) - assert(JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseSensitive) === - StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true)))) + assert(JdbcUtils.getCustomSchema(tableSchema, null, caseInsensitive) === tableSchema) + assert(JdbcUtils.getCustomSchema(tableSchema, "", caseInsensitive) === tableSchema) + + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseInsensitive) === + StructType(Seq(StructField("C1", DateType, false), StructField("C2", IntegerType, false)))) + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseSensitive) === + StructType(Seq(StructField("C1", StringType, false), StructField("C2", IntegerType, false)))) + assert( JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true)))) - assert(JdbcUtils.getCustomSchema( - tableSchema, "c1 DECIMAL(38, 0), C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c1", DecimalType(38, 0), true), - StructField("C2", StringType, true)))) + StructType(Seq(StructField("C1", DateType, false), StructField("C2", StringType, false)))) + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) === + StructType(Seq(StructField("C1", StringType, false), StructField("C2", StringType, false)))) // Throw AnalysisException val duplicate = intercept[AnalysisException]{ JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", caseInsensitive) === - StructType(Seq(StructField("c1", DateType, true), StructField("c1", StringType, true))) + StructType(Seq(StructField("c1", DateType, false), StructField("c1", StringType, false))) } assert(duplicate.getMessage.contains( "Found duplicate column(s) in the customSchema option value")) - val allColumns = intercept[AnalysisException]{ - JdbcUtils.getCustomSchema(tableSchema, "C1 STRING", caseSensitive) === - StructType(Seq(StructField("C1", DateType, true))) - } - assert(allColumns.getMessage.contains("Please provide all the columns,")) - - val caseSensitiveColumnNotFound = intercept[AnalysisException]{ - JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) === - StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true))) - } - assert(caseSensitiveColumnNotFound.getMessage.contains( - "Please provide all the columns, all columns are: C1,C2;")) - - val caseInsensitiveColumnNotFound = intercept[AnalysisException]{ - JdbcUtils.getCustomSchema(tableSchema, "c3 DATE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true))) - } - assert(caseInsensitiveColumnNotFound.getMessage.contains( - "Please provide all the columns, all columns are: C1,C2;")) - // Throw ParseException val dataTypeNotSupported = intercept[ParseException]{ JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true))) + StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false))) } assert(dataTypeNotSupported.getMessage.contains("DataType datee is not supported")) val mismatchedInput = intercept[ParseException]{ JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true))) + StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false))) } assert(mismatchedInput.getMessage.contains("mismatched input '.' expecting")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 40179261ab200..689f4106824aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand @@ -970,30 +971,28 @@ class JDBCSuite extends SparkFunSuite test("jdbc API support custom schema") { val parts = Array[String]("THEID < 2", "THEID >= 2") + val customSchema = "NAME STRING, THEID INT" val props = new Properties() - props.put("customSchema", "NAME STRING, THEID BIGINT") - val schema = StructType(Seq( - StructField("NAME", StringType, true), StructField("THEID", LongType, true))) + props.put("customSchema", customSchema) val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props) assert(df.schema.size === 2) - assert(df.schema === schema) + assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema)) assert(df.count() === 3) } test("jdbc API custom schema DDL-like strings.") { withTempView("people_view") { + val customSchema = "NAME STRING, THEID INT" sql( s""" |CREATE TEMPORARY VIEW people_view |USING org.apache.spark.sql.jdbc |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass', - |customSchema 'NAME STRING, THEID INT') + |customSchema '$customSchema') """.stripMargin.replaceAll("\n", " ")) - val schema = StructType( - Seq(StructField("NAME", StringType, true), StructField("THEID", IntegerType, true))) val df = sql("select * from people_view") - assert(df.schema.size === 2) - assert(df.schema === schema) + assert(df.schema.length === 2) + assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema)) assert(df.count() === 3) } } From 3c6198c86ef36e7e5814d74ede00672d0eeb7f32 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 15 Sep 2017 00:47:44 -0700 Subject: [PATCH 05/37] [SPARK-21987][SQL] fix a compatibility issue of sql event logs ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/18600 we removed the `metadata` field from `SparkPlanInfo`. This causes a problem when we replay event logs that are generated by older Spark versions. ## How was this patch tested? a regression test. Author: Wenchen Fan Closes #19237 from cloud-fan/event. --- .../spark/sql/execution/SparkPlanInfo.scala | 3 ++ .../sql/execution/SQLJsonProtocolSuite.scala | 52 +++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 2118b9118a22f..2a2315896831c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import com.fasterxml.jackson.annotation.JsonIgnoreProperties + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo @@ -26,6 +28,7 @@ import org.apache.spark.sql.execution.metric.SQLMetricInfo * Stores information about a SQL SparkPlan. */ @DeveloperApi +@JsonIgnoreProperties(Array("metadata")) // The metadata field was removed in Spark 2.3. class SparkPlanInfo( val nodeName: String, val simpleString: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala new file mode 100644 index 0000000000000..c2e62b987e0cc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.json4s.jackson.JsonMethods.parse + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart +import org.apache.spark.util.JsonProtocol + +class SQLJsonProtocolSuite extends SparkFunSuite { + + test("SparkPlanGraph backward compatibility: metadata") { + val SQLExecutionStartJsonString = + """ + |{ + | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", + | "executionId":0, + | "description":"test desc", + | "details":"test detail", + | "physicalPlanDescription":"test plan", + | "sparkPlanInfo": { + | "nodeName":"TestNode", + | "simpleString":"test string", + | "children":[], + | "metadata":{}, + | "metrics":[] + | }, + | "time":0 + |} + """.stripMargin + val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) + val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", + new SparkPlanInfo("TestNode", "test string", Nil, Nil), 0) + assert(reconstructedEvent == expectedEvent) + } +} From 79a4dab6297121c075a310a50d0fc0549a3c1e41 Mon Sep 17 00:00:00 2001 From: Travis Hegner Date: Fri, 15 Sep 2017 15:17:16 +0200 Subject: [PATCH 06/37] [SPARK-21958][ML] Word2VecModel save: transform data in the cluster ## What changes were proposed in this pull request? Change a data transformation while saving a Word2VecModel to happen with distributed data instead of local driver data. ## How was this patch tested? Unit tests for the ML sub-component still pass. Running this patch against v2.2.0 in a fully distributed production cluster allows a 4.0G model to save and load correctly, where it would not do so without the patch. Author: Travis Hegner Closes #19191 from travishegner/master. --- .../main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index f6095e26f435c..fe3306e1e50d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -337,14 +337,17 @@ object Word2VecModel extends MLReadable[Word2VecModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val wordVectors = instance.wordVectors.getVectors - val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) } val dataPath = new Path(path, "data").toString val bufferSizeInBytes = Utils.byteStringAsBytes( sc.conf.get("spark.kryoserializer.buffer.max", "64m")) val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions( bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize) - sparkSession.createDataFrame(dataSeq) + val spark = sparkSession + import spark.implicits._ + spark.createDataset[(String, Array[Float])](wordVectors.toSeq) .repartition(numPartitions) + .map { case (word, vector) => Data(word, vector) } + .toDF() .write .parquet(dataPath) } From c7307acdad881d98857f0b63328fe9c420ddf9c3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 15 Sep 2017 22:18:36 +0800 Subject: [PATCH 07/37] [SPARK-15689][SQL] data source v2 read path ## What changes were proposed in this pull request? This PR adds the infrastructure for data source v2, and implement features which Spark already have in data source v1, i.e. column pruning, filter push down, catalyst expression filter push down, InternalRow scan, schema inference, data size report. The write path is excluded to avoid making this PR growing too big, and will be added in follow-up PR. ## How was this patch tested? new tests Author: Wenchen Fan Closes #19136 from cloud-fan/data-source-v2. --- .../spark/sql/sources/v2/DataSourceV2.java | 31 +++ .../sql/sources/v2/DataSourceV2Options.java | 52 ++++ .../spark/sql/sources/v2/ReadSupport.java | 38 +++ .../sql/sources/v2/ReadSupportWithSchema.java | 47 ++++ .../sql/sources/v2/reader/DataReader.java | 40 +++ .../sources/v2/reader/DataSourceV2Reader.java | 67 +++++ .../spark/sql/sources/v2/reader/ReadTask.java | 48 ++++ .../sql/sources/v2/reader/Statistics.java | 32 +++ .../SupportsPushDownCatalystFilters.java | 43 ++++ .../v2/reader/SupportsPushDownFilters.java | 38 +++ .../SupportsPushDownRequiredColumns.java | 42 ++++ .../v2/reader/SupportsReportStatistics.java | 33 +++ .../v2/reader/SupportsScanUnsafeRow.java | 49 ++++ .../apache/spark/sql/DataFrameReader.scala | 46 +++- .../spark/sql/execution/SparkPlanner.scala | 2 + .../datasources/v2/DataSourceRDD.scala | 68 ++++++ .../datasources/v2/DataSourceV2Relation.scala | 40 +++ .../datasources/v2/DataSourceV2ScanExec.scala | 89 +++++++ .../datasources/v2/DataSourceV2Strategy.scala | 93 +++++++ .../sources/v2/JavaAdvancedDataSourceV2.java | 130 ++++++++++ .../v2/JavaSchemaRequiredDataSource.java | 54 +++++ .../sources/v2/JavaSimpleDataSourceV2.java | 86 +++++++ .../sources/v2/JavaUnsafeRowDataSourceV2.java | 88 +++++++ .../sources/v2/DataSourceV2OptionsSuite.scala | 40 +++ .../sql/sources/v2/DataSourceV2Suite.scala | 229 ++++++++++++++++++ 25 files changed, 1518 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java new file mode 100644 index 0000000000000..dbcbe326a7510 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * The base interface for data source v2. Implementations must have a public, no arguments + * constructor. + * + * Note that this is an empty interface, data source implementations should mix-in at least one of + * the plug-in interfaces like {@link ReadSupport}. Otherwise it's just a dummy data source which is + * un-readable/writable. + */ +@InterfaceStability.Evolving +public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java new file mode 100644 index 0000000000000..9a89c8193dd6e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An immutable string-to-string map in which keys are case-insensitive. This is used to represent + * data source options. + */ +@InterfaceStability.Evolving +public class DataSourceV2Options { + private final Map keyLowerCasedMap; + + private String toLowerCase(String key) { + return key.toLowerCase(Locale.ROOT); + } + + public DataSourceV2Options(Map originalMap) { + keyLowerCasedMap = new HashMap<>(originalMap.size()); + for (Map.Entry entry : originalMap.entrySet()) { + keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); + } + } + + /** + * Returns the option value to which the specified key is mapped, case-insensitively. + */ + public Optional get(String key) { + return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java new file mode 100644 index 0000000000000..ab5254a688d5a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability and scan the data from the data source. + */ +@InterfaceStability.Evolving +public interface ReadSupport { + + /** + * Creates a {@link DataSourceV2Reader} to scan the data from this data source. + * + * @param options the options for this data source reader, which is an immutable case-insensitive + * string-to-string map. + * @return a reader that implements the actual read logic. + */ + DataSourceV2Reader createReader(DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java new file mode 100644 index 0000000000000..c13aeca2ef36f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability and scan the data from the data source. + * + * This is a variant of {@link ReadSupport} that accepts user-specified schema when reading data. + * A data source can implement both {@link ReadSupport} and {@link ReadSupportWithSchema} if it + * supports both schema inference and user-specified schema. + */ +@InterfaceStability.Evolving +public interface ReadSupportWithSchema { + + /** + * Create a {@link DataSourceV2Reader} to scan the data from this data source. + * + * @param schema the full schema of this data source reader. Full schema usually maps to the + * physical schema of the underlying storage of this data source reader, e.g. + * CSV files, JSON files, etc, while this reader may not read data with full + * schema, as column pruning or other optimizations may happen. + * @param options the options for this data source reader, which is an immutable case-insensitive + * string-to-string map. + * @return a reader that implements the actual read logic. + */ + DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java new file mode 100644 index 0000000000000..cfafc1a576793 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Closeable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data + * for a RDD partition. + */ +@InterfaceStability.Evolving +public interface DataReader extends Closeable { + + /** + * Proceed to next record, returns false if there is no more records. + */ + boolean next(); + + /** + * Return the current record. This method should return same value until `next` is called. + */ + T get(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java new file mode 100644 index 0000000000000..48feb049c1de9 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; +import org.apache.spark.sql.types.StructType; + +/** + * A data source reader that is returned by + * {@link ReadSupport#createReader(DataSourceV2Options)} or + * {@link ReadSupportWithSchema#createReader(StructType, DataSourceV2Options)}. + * It can mix in various query optimization interfaces to speed up the data scan. The actual scan + * logic should be delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * + * There are mainly 3 kinds of query optimizations: + * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column + * pruning), etc. These push-down interfaces are named like `SupportsPushDownXXX`. + * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. These + * reporting interfaces are named like `SupportsReportingXXX`. + * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. These scan interfaces are named + * like `SupportsScanXXX`. + * + * Spark first applies all operator push-down optimizations that this data source supports. Then + * Spark collects information this data source reported for further optimizations. Finally Spark + * issues the scan request and does the actual data reading. + */ +@InterfaceStability.Evolving +public interface DataSourceV2Reader { + + /** + * Returns the actual schema of this data source reader, which may be different from the physical + * schema of the underlying storage, as column pruning or other optimizations may happen. + */ + StructType readSchema(); + + /** + * Returns a list of read tasks. Each task is responsible for outputting data for one RDD + * partition. That means the number of tasks returned here is same as the number of RDD + * partitions this scan outputs. + * + * Note that, this may not be a full scan if the data source reader mixes in other optimization + * interfaces like column pruning, filter push-down, etc. These optimizations are applied before + * Spark issues the scan request. + */ + List> createReadTasks(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java new file mode 100644 index 0000000000000..7885bfcdd49e4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A read task returned by {@link DataSourceV2Reader#createReadTasks()} and is responsible for + * creating the actual data reader. The relationship between {@link ReadTask} and {@link DataReader} + * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. + * + * Note that, the read task will be serialized and sent to executors, then the data reader will be + * created on executors and do the actual reading. + */ +@InterfaceStability.Evolving +public interface ReadTask extends Serializable { + + /** + * The preferred locations where this read task can run faster, but Spark does not guarantee that + * this task will always run on these locations. The implementations should make sure that it can + * be run on any location. The location is a string representing the host name of an executor. + */ + default String[] preferredLocations() { + return new String[0]; + } + + /** + * Returns a data reader to do the actual reading work for this read task. + */ + DataReader createReader(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java new file mode 100644 index 0000000000000..e8cd7adbca071 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.OptionalLong; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent statistics for a data source, which is returned by + * {@link SupportsReportStatistics#getStatistics()}. + */ +@InterfaceStability.Evolving +public interface Statistics { + OptionalLong sizeInBytes(); + OptionalLong numRows(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java new file mode 100644 index 0000000000000..19d706238ec8e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.expressions.Expression; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to push down arbitrary expressions as predicates to the data source. + * This is an experimental and unstable interface as {@link Expression} is not public and may get + * changed in the future Spark versions. + * + * Note that, if data source readers implement both this interface and + * {@link SupportsPushDownFilters}, Spark will ignore {@link SupportsPushDownFilters} and only + * process this interface. + */ +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsPushDownCatalystFilters { + + /** + * Pushes down filters, and returns unsupported filters. + */ + Expression[] pushCatalystFilters(Expression[] filters); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java new file mode 100644 index 0000000000000..d4b509e7080f2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.Filter; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to push down filters to the data source and reduce the size of the data to be read. + * + * Note that, if data source readers implement both this interface and + * {@link SupportsPushDownCatalystFilters}, Spark will ignore this interface and only process + * {@link SupportsPushDownCatalystFilters}. + */ +@InterfaceStability.Evolving +public interface SupportsPushDownFilters { + + /** + * Pushes down filters, and returns unsupported filters. + */ + Filter[] pushFilters(Filter[] filters); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java new file mode 100644 index 0000000000000..fe0ac8ee0ee32 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to push down required columns to the data source and only read these columns during + * scan to reduce the size of the data to be read. + */ +@InterfaceStability.Evolving +public interface SupportsPushDownRequiredColumns { + + /** + * Applies column pruning w.r.t. the given requiredSchema. + * + * Implementation should try its best to prune the unnecessary columns or nested fields, but it's + * also OK to do the pruning partially, e.g., a data source may not be able to prune nested + * fields, and only prune top-level columns. + * + * Note that, data source readers should update {@link DataSourceV2Reader#readSchema()} after + * applying column pruning. + */ + void pruneColumns(StructType requiredSchema); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java new file mode 100644 index 0000000000000..c019d2f819ab7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to report statistics to Spark. + */ +@InterfaceStability.Evolving +public interface SupportsReportStatistics { + + /** + * Returns the basic statistics of this data source. + */ + Statistics getStatistics(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java new file mode 100644 index 0000000000000..829f9a078760b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.ReadTask; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. + * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get + * changed in the future Spark versions. + */ +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsScanUnsafeRow extends DataSourceV2Reader { + + @Override + default List> createReadTasks() { + throw new IllegalStateException("createReadTasks should not be called with SupportsScanUnsafeRow."); + } + + /** + * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns data in unsafe row format. + */ + List> createUnsafeRowReadTasks(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c69acc413e87f..78b668c04fd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -180,13 +182,43 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { "read files of Hive data source directly.") } - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + val cls = DataSource.lookupDataSource(source) + if (classOf[DataSourceV2].isAssignableFrom(cls)) { + val dataSource = cls.newInstance() + val options = new DataSourceV2Options(extraOptions.asJava) + + val reader = (cls.newInstance(), userSpecifiedSchema) match { + case (ds: ReadSupportWithSchema, Some(schema)) => + ds.createReader(schema, options) + + case (ds: ReadSupport, None) => + ds.createReader(options) + + case (_: ReadSupportWithSchema, None) => + throw new AnalysisException(s"A schema needs to be specified when using $dataSource.") + + case (ds: ReadSupport, Some(schema)) => + val reader = ds.createReader(options) + if (reader.readSchema() != schema) { + throw new AnalysisException(s"$ds does not allow user-specified schemas.") + } + reader + + case _ => + throw new AnalysisException(s"$cls does not support data reading.") + } + + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } else { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 4e718d609c921..b143d44eae17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy import org.apache.spark.sql.internal.SQLConf class SparkPlanner( @@ -35,6 +36,7 @@ class SparkPlanner( def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + DataSourceV2Strategy :: FileSourceStrategy :: DataSourceStrategy(conf) :: SpecialLimits :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala new file mode 100644 index 0000000000000..b8fe5ac8e3d94 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.ReadTask + +class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) + extends Partition with Serializable + +class DataSourceRDD( + sc: SparkContext, + @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readTasks.asScala.zipWithIndex.map { + case (readTask, index) => new DataSourceRDDPartition(index, readTask) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createReader() + context.addTaskCompletionListener(_ => reader.close()) + val iter = new Iterator[UnsafeRow] { + private[this] var valuePrepared = false + + override def hasNext: Boolean = { + if (!valuePrepared) { + valuePrepared = reader.next() + } + valuePrepared + } + + override def next(): UnsafeRow = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + valuePrepared = false + reader.get() + } + } + new InterruptibleIterator(context, iter) + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala new file mode 100644 index 0000000000000..3c9b598fd07c9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, SupportsReportStatistics} + +case class DataSourceV2Relation( + output: Seq[AttributeReference], + reader: DataSourceV2Reader) extends LeafNode { + + override def computeStats(): Statistics = reader match { + case r: SupportsReportStatistics => + Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } +} + +object DataSourceV2Relation { + def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala new file mode 100644 index 0000000000000..7999c0ceb5749 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.types.StructType + +case class DataSourceV2ScanExec( + fullOutput: Array[AttributeReference], + @transient reader: DataSourceV2Reader, + // TODO: these 3 parameters are only used to determine the equality of the scan node, however, + // the reader also have this information, and ideally we can just rely on the equality of the + // reader. The only concern is, the reader implementation is outside of Spark and we have no + // control. + readSchema: StructType, + @transient filters: ExpressionSet, + hashPartitionKeys: Seq[String]) extends LeafExecNode { + + def output: Seq[Attribute] = readSchema.map(_.name).map { name => + fullOutput.find(_.name == name).get + } + + override def references: AttributeSet = AttributeSet.empty + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override protected def doExecute(): RDD[InternalRow] = { + val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + case _ => + reader.createReadTasks().asScala.map { + new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + }.asJava + } + + val inputRDD = new DataSourceRDD(sparkContext, readTasks) + .asInstanceOf[RDD[InternalRow]] + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } + } +} + +class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) + extends ReadTask[UnsafeRow] { + + override def preferredLocations: Array[String] = rowReadTask.preferredLocations + + override def createReader: DataReader[UnsafeRow] = { + new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema)) + } +} + +class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) + extends DataReader[UnsafeRow] { + + override def next: Boolean = rowReader.next + + override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow] + + override def close(): Unit = rowReader.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala new file mode 100644 index 0000000000000..b80f695b2a87f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.reader._ + +object DataSourceV2Strategy extends Strategy { + // TODO: write path + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projects, filters, DataSourceV2Relation(output, reader)) => + val stayUpFilters: Seq[Expression] = reader match { + case r: SupportsPushDownCatalystFilters => + r.pushCatalystFilters(filters.toArray) + + case r: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, Filter] = filters.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet + val unhandledPredicates = translatedMap.filter { case (_, f) => + unhandledFilters.contains(f) + }.keys + + nonConvertiblePredicates ++ unhandledPredicates + + case _ => filters + } + + val attrMap = AttributeMap(output.zip(output)) + val projectSet = AttributeSet(projects.flatMap(_.references)) + val filterSet = AttributeSet(stayUpFilters.flatMap(_.references)) + + // Match original case of attributes. + // TODO: nested fields pruning + val requiredColumns = (projectSet ++ filterSet).toSeq.map(attrMap) + reader match { + case r: SupportsPushDownRequiredColumns => + r.pruneColumns(requiredColumns.toStructType) + case _ => + } + + val scan = DataSourceV2ScanExec( + output.toArray, + reader, + reader.readSchema(), + ExpressionSet(filters), + Nil) + + val filterCondition = stayUpFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + + val withProject = if (projects == withFilter.output) { + withFilter + } else { + ProjectExec(projects, withFilter) + } + + withProject :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java new file mode 100644 index 0000000000000..50900e98dedb6 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.*; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { + private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + private Filter[] filters = new Filter[0]; + + @Override + public StructType readSchema() { + return requiredSchema; + } + + @Override + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + this.filters = filters; + return new Filter[0]; + } + + @Override + public List> createReadTasks() { + List> res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaAdvancedReadTask(0, 5, requiredSchema)); + res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + } else if (lowerBound < 4) { + res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + } else if (lowerBound < 9) { + res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema)); + } + + return res; + } + } + + static class JavaAdvancedReadTask implements ReadTask, DataReader { + private int start; + private int end; + private StructType requiredSchema; + + JavaAdvancedReadTask(int start, int end, StructType requiredSchema) { + this.start = start; + this.end = end; + this.requiredSchema = requiredSchema; + } + + @Override + public DataReader createReader() { + return new JavaAdvancedReadTask(start - 1, end, requiredSchema); + } + + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public Row get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = start; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -start; + } + } + return new GenericRow(values); + } + + @Override + public void close() throws IOException { + + } + } + + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java new file mode 100644 index 0000000000000..a174bd8092cbd --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.types.StructType; + +public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { + + class Reader implements DataSourceV2Reader { + private final StructType schema; + + Reader(StructType schema) { + this.schema = schema; + } + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Collections.emptyList(); + } + } + + @Override + public DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options) { + return new Reader(schema); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java new file mode 100644 index 0000000000000..08469f14c257a --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.types.StructType; + +public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Arrays.asList( + new JavaSimpleReadTask(0, 5), + new JavaSimpleReadTask(5, 10)); + } + } + + static class JavaSimpleReadTask implements ReadTask, DataReader { + private int start; + private int end; + + JavaSimpleReadTask(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader createReader() { + return new JavaSimpleReadTask(start - 1, end); + } + + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public Row get() { + return new GenericRow(new Object[] {start, -start}); + } + + @Override + public void close() throws IOException { + + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java new file mode 100644 index 0000000000000..9efe7c791a936 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsScanUnsafeRow { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createUnsafeRowReadTasks() { + return java.util.Arrays.asList( + new JavaUnsafeRowReadTask(0, 5), + new JavaUnsafeRowReadTask(5, 10)); + } + } + + static class JavaUnsafeRowReadTask implements ReadTask, DataReader { + private int start; + private int end; + private UnsafeRow row; + + JavaUnsafeRowReadTask(int start, int end) { + this.start = start; + this.end = end; + this.row = new UnsafeRow(2); + row.pointTo(new byte[8 * 3], 8 * 3); + } + + @Override + public DataReader createReader() { + return new JavaUnsafeRowReadTask(start - 1, end); + } + + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public UnsafeRow get() { + row.setInt(0, start); + row.setInt(1, -start); + return row; + } + + @Override + public void close() throws IOException { + + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala new file mode 100644 index 0000000000000..933f4075bcc8a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite + +/** + * A simple test suite to verify `DataSourceV2Options`. + */ +class DataSourceV2OptionsSuite extends SparkFunSuite { + + test("key is case-insensitive") { + val options = new DataSourceV2Options(Map("foo" -> "bar").asJava) + assert(options.get("foo").get() == "bar") + assert(options.get("FoO").get() == "bar") + assert(!options.get("abc").isPresent) + } + + test("value is case-sensitive") { + val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) + assert(options.get("foo").get == "bAr") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala new file mode 100644 index 0000000000000..9ce93d7ae926c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.util.{ArrayList, List => JList} + +import test.org.apache.spark.sql.sources.v2._ + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class DataSourceV2Suite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("simplest implementation") { + Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + } + } + } + + test("advanced implementation") { + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i))) + checkAnswer(df.select('i).filter('i > 10), Nil) + } + } + } + + test("unsafe row implementation") { + Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + } + } + } + + test("schema required data source") { + Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => + withClue(cls.getName) { + val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) + assert(e.message.contains("A schema needs to be specified")) + + val schema = new StructType().add("i", "int").add("s", "string") + val df = spark.read.format(cls.getName).schema(schema).load() + + assert(df.schema == schema) + assert(df.collect().isEmpty) + } + } + } +} + +class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createReadTasks(): JList[ReadTask[Row]] = { + java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { + private var current = start - 1 + + override def createReader(): DataReader[Row] = new SimpleReadTask(start, end) + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): Row = Row(current, -current) + + override def close(): Unit = {} +} + + + +class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { + + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + this.filters = filters + Array.empty + } + + override def readSchema(): StructType = { + requiredSchema + } + + override def createReadTasks(): JList[ReadTask[Row]] = { + val lowerBound = filters.collect { + case GreaterThan("i", v: Int) => v + }.headOption + + val res = new ArrayList[ReadTask[Row]] + + if (lowerBound.isEmpty) { + res.add(new AdvancedReadTask(0, 5, requiredSchema)) + res.add(new AdvancedReadTask(5, 10, requiredSchema)) + } else if (lowerBound.get < 4) { + res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedReadTask(5, 10, requiredSchema)) + } else if (lowerBound.get < 9) { + res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema)) + } + + res + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) + extends ReadTask[Row] with DataReader[Row] { + + private var current = start - 1 + + override def createReader(): DataReader[Row] = new AdvancedReadTask(start, end, requiredSchema) + + override def close(): Unit = {} + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): Row = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + Row.fromSeq(values) + } +} + + +class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class UnsafeRowReadTask(start: Int, end: Int) + extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] { + + private val row = new UnsafeRow(2) + row.pointTo(new Array[Byte](8 * 3), 8 * 3) + + private var current = start - 1 + + override def createReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + + override def next(): Boolean = { + current += 1 + current < end + } + override def get(): UnsafeRow = { + row.setInt(0, current) + row.setInt(1, -current) + row + } + + override def close(): Unit = {} +} + +class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { + + class Reader(val readSchema: StructType) extends DataSourceV2Reader { + override def createReadTasks(): JList[ReadTask[Row]] = + java.util.Collections.emptyList() + } + + override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = + new Reader(schema) +} From 0bad10d3e36d3238c7ee7c0fc5465072734b3ae4 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 15 Sep 2017 21:10:07 -0700 Subject: [PATCH 08/37] [SPARK-22017] Take minimum of all watermark execs in StreamExecution. ## What changes were proposed in this pull request? Take the minimum of all watermark exec nodes as the "real" watermark in StreamExecution, rather than picking one arbitrarily. ## How was this patch tested? new unit test Author: Jose Torres Closes #19239 from joseph-torres/SPARK-22017. --- .../streaming/IncrementalExecution.scala | 2 +- .../execution/streaming/StreamExecution.scala | 39 ++++++++-- .../streaming/EventTimeWatermarkSuite.scala | 78 +++++++++++++++++++ 3 files changed, 113 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 258a64216136f..19d95980d57d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -39,7 +39,7 @@ class IncrementalExecution( val checkpointLocation: String, val runId: UUID, val currentBatchId: Long, - offsetSeqMetadata: OffsetSeqMetadata) + val offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { // Modified planner with stateful operations. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 952e431fb19d3..b27a59b8a34fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -130,6 +130,16 @@ class StreamExecution( protected var offsetSeqMetadata = OffsetSeqMetadata( batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf) + /** + * A map of current watermarks, keyed by the position of the watermark operator in the + * physical plan. + * + * This state is 'soft state', which does not affect the correctness and semantics of watermarks + * and is not persisted across query restarts. + * The fault-tolerant watermark state is in offsetSeqMetadata. + */ + protected val watermarkMsMap: MutableMap[Int, Long] = MutableMap() + override val id: UUID = UUID.fromString(streamMetadata.id) override val runId: UUID = UUID.randomUUID @@ -560,13 +570,32 @@ class StreamExecution( } if (hasNewData) { var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermark if we find one in the plan. + // Update the eventTime watermarks if we find any in the plan. if (lastExecution != null) { lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats: ${e.eventTimeStats.value}") - e.eventTimeStats.value.max - e.delayMs - }.headOption.foreach { newWatermarkMs => + case e: EventTimeWatermarkExec => e + }.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = watermarkMsMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + watermarkMsMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!watermarkMsMap.isDefinedAt(index)) { + watermarkMsMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. + if(!watermarkMsMap.isEmpty) { + val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 if (newWatermarkMs > batchWatermarkMs) { logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") batchWatermarkMs = newWatermarkMs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 4f19fa0bb4a97..f3e8cf950a5a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -300,6 +300,84 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("watermark with 2 streams") { + import org.apache.spark.sql.functions.sum + val first = MemoryStream[Int] + + val firstDf = first.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .select('value) + + val second = MemoryStream[Int] + + val secondDf = second.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "5 seconds") + .select('value) + + withTempDir { checkpointDir => + val unionWriter = firstDf.union(secondDf).agg(sum('value)) + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("memory") + .outputMode("complete") + .queryName("test") + + val union = unionWriter.start() + + def getWatermarkAfterData( + firstData: Seq[Int] = Seq.empty, + secondData: Seq[Int] = Seq.empty, + query: StreamingQuery = union): Long = { + if (firstData.nonEmpty) first.addData(firstData) + if (secondData.nonEmpty) second.addData(secondData) + query.processAllAvailable() + // add a dummy batch so lastExecution has the new watermark + first.addData(0) + query.processAllAvailable() + // get last watermark + val lastExecution = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution + lastExecution.offsetSeqMetadata.batchWatermarkMs + } + + // Global watermark starts at 0 until we get data from both sides + assert(getWatermarkAfterData(firstData = Seq(11)) == 0) + assert(getWatermarkAfterData(secondData = Seq(6)) == 1000) + // Global watermark stays at left watermark 1 when right watermark moves to 2 + assert(getWatermarkAfterData(secondData = Seq(8)) == 1000) + // Global watermark switches to right side value 2 when left watermark goes higher + assert(getWatermarkAfterData(firstData = Seq(21)) == 3000) + // Global watermark goes back to left + assert(getWatermarkAfterData(secondData = Seq(17, 28, 39)) == 11000) + // Global watermark stays on left as long as it's below right + assert(getWatermarkAfterData(firstData = Seq(31)) == 21000) + assert(getWatermarkAfterData(firstData = Seq(41)) == 31000) + // Global watermark switches back to right again + assert(getWatermarkAfterData(firstData = Seq(51)) == 34000) + + // Global watermark is updated correctly with simultaneous data from both sides + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(100)) == 90000) + assert(getWatermarkAfterData(firstData = Seq(120), secondData = Seq(110)) == 105000) + assert(getWatermarkAfterData(firstData = Seq(130), secondData = Seq(125)) == 120000) + + // Global watermark doesn't decrement with simultaneous data + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(100)) == 120000) + assert(getWatermarkAfterData(firstData = Seq(140), secondData = Seq(100)) == 120000) + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(135)) == 130000) + + // Global watermark recovers after restart, but left side watermark ahead of it does not. + assert(getWatermarkAfterData(firstData = Seq(200), secondData = Seq(190)) == 185000) + union.stop() + val union2 = unionWriter.start() + assert(getWatermarkAfterData(query = union2) == 185000) + // Even though the left side was ahead of 185000 in the last execution, the watermark won't + // increment until it gets past it in this execution. + assert(getWatermarkAfterData(secondData = Seq(200), query = union2) == 185000) + assert(getWatermarkAfterData(firstData = Seq(200), query = union2) == 190000) + } + } + test("complete mode") { val inputData = MemoryStream[Int] From 73d9067226671adb6410ccfb4d5ca2f00283c82b Mon Sep 17 00:00:00 2001 From: Armin Date: Sat, 16 Sep 2017 09:18:13 +0100 Subject: [PATCH 09/37] [SPARK-21967][CORE] org.apache.spark.unsafe.types.UTF8String#compareTo Should Compare 8 Bytes at a Time for Better Performance ## What changes were proposed in this pull request? * Using 64 bit unsigned long comparison instead of unsigned int comparison in `org.apache.spark.unsafe.types.UTF8String#compareTo` for better performance. * Making `IS_LITTLE_ENDIAN` a constant for correctness reasons (shouldn't use a non-constant in `compareTo` implementations and it def. is a constant per JVM) ## How was this patch tested? Build passes and the functionality is widely covered by existing tests as far as I can see. Author: Armin Closes #19180 from original-brownbear/SPARK-21967. --- .../apache/spark/unsafe/types/UTF8String.java | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 43f57672d9544..dd67f15749add 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -64,7 +64,8 @@ public final class UTF8String implements Comparable, Externalizable, 5, 5, 5, 5, 6, 6}; - private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + private static final boolean IS_LITTLE_ENDIAN = + ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); @@ -220,7 +221,7 @@ public long getPrefix() { // After getting the data, we use a mask to mask out data that is not part of the string. long p; long mask = 0; - if (isLittleEndian) { + if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { p = Platform.getLong(base, offset); } else if (numBytes > 4) { @@ -1097,10 +1098,23 @@ public UTF8String copy() { @Override public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); - // TODO: compare 8 bytes as unsigned long - for (int i = 0; i < len; i ++) { + int wordMax = (len / 8) * 8; + long roffset = other.offset; + Object rbase = other.base; + for (int i = 0; i < wordMax; i += 8) { + long left = getLong(base, offset + i); + long right = getLong(rbase, roffset + i); + if (left != right) { + if (IS_LITTLE_ENDIAN) { + return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); + } else { + return Long.compareUnsigned(left, right); + } + } + } + for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (other.getByte(i) & 0xFF); + int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); if (res != 0) { return res; } From f4073020adf9752c7d7b39631ec3fa36d6345902 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Bry=C5=84ski?= Date: Mon, 18 Sep 2017 02:34:44 +0900 Subject: [PATCH 10/37] [SPARK-22032][PYSPARK] Speed up StructType conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? StructType.fromInternal is calling f.fromInternal(v) for every field. We can use precalculated information about type to limit the number of function calls. (its calculated once per StructType and used in per record calculations) Benchmarks (Python profiler) ``` df = spark.range(10000000).selectExpr("id as id0", "id as id1", "id as id2", "id as id3", "id as id4", "id as id5", "id as id6", "id as id7", "id as id8", "id as id9", "struct(id) as s").cache() df.count() df.rdd.map(lambda x: x).count() ``` Before ``` 310274584 function calls (300272456 primitive calls) in 1320.684 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 10000000 253.417 0.000 486.991 0.000 types.py:619() 30000000 192.272 0.000 1009.986 0.000 types.py:612(fromInternal) 100000000 176.140 0.000 176.140 0.000 types.py:88(fromInternal) 20000000 156.832 0.000 328.093 0.000 types.py:1471(_create_row) 14000 107.206 0.008 1237.917 0.088 {built-in method loads} 20000000 80.176 0.000 1090.162 0.000 types.py:1468() ``` After ``` 210274584 function calls (200272456 primitive calls) in 1035.974 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 30000000 215.845 0.000 698.748 0.000 types.py:612(fromInternal) 20000000 165.042 0.000 351.572 0.000 types.py:1471(_create_row) 14000 116.834 0.008 946.791 0.068 {built-in method loads} 20000000 87.326 0.000 786.073 0.000 types.py:1468() 20000000 85.477 0.000 134.607 0.000 types.py:1519(__new__) 10000000 65.777 0.000 126.712 0.000 types.py:619() ``` Main difference is types.py:619() and types.py:88(fromInternal) (which is removed in After) The number of function calls is 100 million less. And performance is 20% better. Benchmark (worst case scenario.) Test ``` df = spark.range(1000000).selectExpr("current_timestamp as id0", "current_timestamp as id1", "current_timestamp as id2", "current_timestamp as id3", "current_timestamp as id4", "current_timestamp as id5", "current_timestamp as id6", "current_timestamp as id7", "current_timestamp as id8", "current_timestamp as id9").cache() df.count() df.rdd.map(lambda x: x).count() ``` Before ``` 31166064 function calls (31163984 primitive calls) in 150.882 seconds ``` After ``` 31166064 function calls (31163984 primitive calls) in 153.220 seconds ``` IMPORTANT: The benchmark was done on top of https://github.com/apache/spark/pull/19246. Without https://github.com/apache/spark/pull/19246 the performance improvement will be even greater. ## How was this patch tested? Existing tests. Performance benchmark. Author: Maciej Bryński Closes #19249 from maver1ck/spark_22032. --- python/pyspark/sql/types.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 920cf009f599d..aaf520fa8019f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -483,7 +483,9 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeAnyField = any(f.needConversion() for f in self) + # Precalculated list of fields that need conversion with fromInternal/toInternal functions + self._needConversion = [f.needConversion() for f in self] + self._needSerializeAnyField = any(self._needConversion) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -528,7 +530,9 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) - self._needSerializeAnyField = any(f.needConversion() for f in self) + # Precalculated list of fields that need conversion with fromInternal/toInternal functions + self._needConversion = [f.needConversion() for f in self] + self._needSerializeAnyField = any(self._needConversion) return self def __iter__(self): @@ -590,13 +594,17 @@ def toInternal(self, obj): return if self._needSerializeAnyField: + # Only calling toInternal function for fields that need conversion if isinstance(obj, dict): - return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) + return tuple(f.toInternal(obj.get(n)) if c else obj.get(n) + for n, f, c in zip(self.names, self.fields, self._needConversion)) elif isinstance(obj, (tuple, list)): - return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + return tuple(f.toInternal(v) if c else v + for f, v, c in zip(self.fields, obj, self._needConversion)) elif hasattr(obj, "__dict__"): d = obj.__dict__ - return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields)) + return tuple(f.toInternal(d.get(n)) if c else d.get(n) + for n, f, c in zip(self.names, self.fields, self._needConversion)) else: raise ValueError("Unexpected tuple %r with StructType" % obj) else: @@ -619,7 +627,9 @@ def fromInternal(self, obj): # it's already converted by pickler return obj if self._needSerializeAnyField: - values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + # Only calling fromInternal function for fields that need conversion + values = [f.fromInternal(v) if c else v + for f, v, c in zip(self.fields, obj, self._needConversion)] else: values = obj return _create_row(self.names, values) From 6adf67dd14b0ece342bb91adf800df0a7101e038 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 18 Sep 2017 02:46:27 +0900 Subject: [PATCH 11/37] [SPARK-21985][PYSPARK] PairDeserializer is broken for double-zipped RDDs ## What changes were proposed in this pull request? (edited) Fixes a bug introduced in #16121 In PairDeserializer convert each batch of keys and values to lists (if they do not have `__len__` already) so that we can check that they are the same size. Normally they already are lists so this should not have a performance impact, but this is needed when repeated `zip`'s are done. ## How was this patch tested? Additional unit test Author: Andrew Ray Closes #19226 from aray/SPARK-21985. --- python/pyspark/serializers.py | 6 +++++- python/pyspark/tests.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d5c2a7518b18f..660b19ad2a7c4 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -97,7 +97,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ - Return an iterator of deserialized batches (lists) of objects from the input stream. + Return an iterator of deserialized batches (iterable) of objects from the input stream. if the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ @@ -343,6 +343,10 @@ def _load_stream_without_unbatching(self, stream): key_batch_stream = self.key_ser._load_stream_without_unbatching(stream) val_batch_stream = self.val_ser._load_stream_without_unbatching(stream) for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream): + # For double-zipped RDDs, the batches can be iterators from other PairDeserializer, + # instead of lists. We need to convert them to lists if needed. + key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch) + val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch) if len(key_batch) != len(val_batch): raise ValueError("Can not deserialize PairRDD with different number of items" " in batches: (%d, %d)" % (len(key_batch), len(val_batch))) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 000dd1eb8e481..3c108ec92ccc9 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -644,6 +644,18 @@ def test_cartesian_chaining(self): set([(x, (y, y)) for x in range(10) for y in range(10)]) ) + def test_zip_chaining(self): + # Tests for SPARK-21985 + rdd = self.sc.parallelize('abc', 2) + self.assertSetEqual( + set(rdd.zip(rdd).zip(rdd).collect()), + set([((x, x), x) for x in 'abc']) + ) + self.assertSetEqual( + set(rdd.zip(rdd.zip(rdd)).collect()), + set([(x, (x, x)) for x in 'abc']) + ) + def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) From 6308c65f08b507408033da1f1658144ea8c1491f Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Mon, 18 Sep 2017 10:42:24 +0800 Subject: [PATCH 12/37] [SPARK-21953] Show both memory and disk bytes spilled if either is present As written now, there must be both memory and disk bytes spilled to show either of them. If there is only one of those types of spill recorded, it will be hidden. Author: Andrew Ash Closes #19164 from ash211/patch-3. --- core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index d9c87f69d8a54..5acec0d0f54c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -110,7 +110,7 @@ private[spark] object UIData { def hasOutput: Boolean = outputBytes > 0 def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 - def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0 + def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 || diskBytesSpilled > 0 } /** From 7c7266208a3be984ac1ce53747dc0c3640f4ecac Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 18 Sep 2017 13:20:11 +0900 Subject: [PATCH 13/37] [SPARK-22043][PYTHON] Improves error message for show_profiles and dump_profiles ## What changes were proposed in this pull request? This PR proposes to improve error message from: ``` >>> sc.show_profiles() Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/context.py", line 1000, in show_profiles self.profiler_collector.show_profiles() AttributeError: 'NoneType' object has no attribute 'show_profiles' >>> sc.dump_profiles("/tmp/abc") Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/context.py", line 1005, in dump_profiles self.profiler_collector.dump_profiles(path) AttributeError: 'NoneType' object has no attribute 'dump_profiles' ``` to ``` >>> sc.show_profiles() Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/context.py", line 1003, in show_profiles raise RuntimeError("'spark.python.profile' configuration must be set " RuntimeError: 'spark.python.profile' configuration must be set to 'true' to enable Python profile. >>> sc.dump_profiles("/tmp/abc") Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/context.py", line 1012, in dump_profiles raise RuntimeError("'spark.python.profile' configuration must be set " RuntimeError: 'spark.python.profile' configuration must be set to 'true' to enable Python profile. ``` ## How was this patch tested? Unit tests added in `python/pyspark/tests.py` and manual tests. Author: hyukjinkwon Closes #19260 from HyukjinKwon/profile-errors. --- python/pyspark/context.py | 12 ++++++++++-- python/pyspark/tests.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7046043e0376..a33f6dcf31fc0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -997,12 +997,20 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): def show_profiles(self): """ Print the profile stats to stdout """ - self.profiler_collector.show_profiles() + if self.profiler_collector is not None: + self.profiler_collector.show_profiles() + else: + raise RuntimeError("'spark.python.profile' configuration must be set " + "to 'true' to enable Python profile.") def dump_profiles(self, path): """ Dump the profile stats into directory `path` """ - self.profiler_collector.dump_profiles(path) + if self.profiler_collector is not None: + self.profiler_collector.dump_profiles(path) + else: + raise RuntimeError("'spark.python.profile' configuration must be set " + "to 'true' to enable Python profile.") def getConf(self): conf = SparkConf() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3c108ec92ccc9..da99872da2f0e 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1296,6 +1296,22 @@ def heavy_foo(x): rdd.foreach(heavy_foo) +class ProfilerTests2(unittest.TestCase): + def test_profiler_disabled(self): + sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) + try: + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.show_profiles()) + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.dump_profiles("/tmp/abc")) + finally: + sc.stop() + + class InputFormatTests(ReusedPySparkTestCase): @classmethod From 1e978b17d63d7ba20368057aa4e65f5ef6e87369 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Sun, 17 Sep 2017 23:15:08 -0700 Subject: [PATCH 14/37] =?UTF-8?q?[SPARK-21113][CORE]=20Read=20ahead=20inpu?= =?UTF-8?q?t=20stream=20to=20amortize=20disk=20IO=20cost=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Profiling some of our big jobs, we see that around 30% of the time is being spent in reading the spill files from disk. In order to amortize the disk IO cost, the idea is to implement a read ahead input stream which asynchronously reads ahead from the underlying input stream when specified amount of data has been read from the current buffer. It does it by maintaining two buffer - active buffer and read ahead buffer. The active buffer contains data which should be returned when a read() call is issued. The read-ahead buffer is used to asynchronously read from the underlying input stream and once the active buffer is exhausted, we flip the two buffers so that we can start reading from the read ahead buffer without being blocked in disk I/O. ## How was this patch tested? Tested by running a job on the cluster and could see up to 8% CPU improvement. Author: Sital Kedia Author: Shixiong Zhu Author: Sital Kedia Closes #18317 from sitalkedia/read_ahead_buffer. --- .../apache/spark/io/ReadAheadInputStream.java | 408 ++++++++++++++++++ .../unsafe/sort/UnsafeSorterSpillReader.java | 21 +- ....java => GenericFileInputStreamSuite.java} | 13 +- .../spark/io/NioBufferedInputStreamSuite.java | 33 ++ .../spark/io/ReadAheadInputStreamSuite.java | 33 ++ 5 files changed, 495 insertions(+), 13 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java rename core/src/test/java/org/apache/spark/io/{NioBufferedFileInputStreamSuite.java => GenericFileInputStreamSuite.java} (87%) create mode 100644 core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java create mode 100644 core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java new file mode 100644 index 0000000000000..618bd42d0e65d --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -0,0 +1,408 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.apache.spark.util.ThreadUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.GuardedBy; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +/** + * {@link InputStream} implementation which asynchronously reads ahead from the underlying input + * stream when specified amount of data has been read from the current buffer. It does it by maintaining + * two buffer - active buffer and read ahead buffer. Active buffer contains data which should be returned + * when a read() call is issued. The read ahead buffer is used to asynchronously read from the underlying + * input stream and once the current active buffer is exhausted, we flip the two buffers so that we can + * start reading from the read ahead buffer without being blocked in disk I/O. + */ +public class ReadAheadInputStream extends InputStream { + + private static final Logger logger = LoggerFactory.getLogger(ReadAheadInputStream.class); + + private ReentrantLock stateChangeLock = new ReentrantLock(); + + @GuardedBy("stateChangeLock") + private ByteBuffer activeBuffer; + + @GuardedBy("stateChangeLock") + private ByteBuffer readAheadBuffer; + + @GuardedBy("stateChangeLock") + private boolean endOfStream; + + @GuardedBy("stateChangeLock") + // true if async read is in progress + private boolean readInProgress; + + @GuardedBy("stateChangeLock") + // true if read is aborted due to an exception in reading from underlying input stream. + private boolean readAborted; + + @GuardedBy("stateChangeLock") + private Throwable readException; + + @GuardedBy("stateChangeLock") + // whether the close method is called. + private boolean isClosed; + + @GuardedBy("stateChangeLock") + // true when the close method will close the underlying input stream. This is valid only if + // `isClosed` is true. + private boolean isUnderlyingInputStreamBeingClosed; + + @GuardedBy("stateChangeLock") + // whether there is a read ahead task running, + private boolean isReading; + + // If the remaining data size in the current buffer is below this threshold, + // we issue an async read from the underlying input stream. + private final int readAheadThresholdInBytes; + + private final InputStream underlyingInputStream; + + private final ExecutorService executorService = ThreadUtils.newDaemonSingleThreadExecutor("read-ahead"); + + private final Condition asyncReadComplete = stateChangeLock.newCondition(); + + private static final ThreadLocal oneByte = ThreadLocal.withInitial(() -> new byte[1]); + + /** + * Creates a ReadAheadInputStream with the specified buffer size and read-ahead + * threshold + * + * @param inputStream The underlying input stream. + * @param bufferSizeInBytes The buffer size. + * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead + * threshold, an async read is triggered. + */ + public ReadAheadInputStream(InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + Preconditions.checkArgument(bufferSizeInBytes > 0, + "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); + Preconditions.checkArgument(readAheadThresholdInBytes > 0 && + readAheadThresholdInBytes < bufferSizeInBytes, + "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, but the" + + "value is " + readAheadThresholdInBytes); + activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); + readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); + this.readAheadThresholdInBytes = readAheadThresholdInBytes; + this.underlyingInputStream = inputStream; + activeBuffer.flip(); + readAheadBuffer.flip(); + } + + private boolean isEndOfStream() { + return (!activeBuffer.hasRemaining() && !readAheadBuffer.hasRemaining() && endOfStream); + } + + private void checkReadException() throws IOException { + if (readAborted) { + Throwables.propagateIfPossible(readException, IOException.class); + throw new IOException(readException); + } + } + + /** Read data from underlyingInputStream to readAheadBuffer asynchronously. */ + private void readAsync() throws IOException { + stateChangeLock.lock(); + final byte[] arr = readAheadBuffer.array(); + try { + if (endOfStream || readInProgress) { + return; + } + checkReadException(); + readAheadBuffer.position(0); + readAheadBuffer.flip(); + readInProgress = true; + } finally { + stateChangeLock.unlock(); + } + executorService.execute(new Runnable() { + + @Override + public void run() { + stateChangeLock.lock(); + try { + if (isClosed) { + readInProgress = false; + return; + } + // Flip this so that the close method will not close the underlying input stream when we + // are reading. + isReading = true; + } finally { + stateChangeLock.unlock(); + } + + // Please note that it is safe to release the lock and read into the read ahead buffer + // because either of following two conditions will hold - 1. The active buffer has + // data available to read so the reader will not read from the read ahead buffer. + // 2. This is the first time read is called or the active buffer is exhausted, + // in that case the reader waits for this async read to complete. + // So there is no race condition in both the situations. + int read = 0; + Throwable exception = null; + try { + while (true) { + read = underlyingInputStream.read(arr); + if (0 != read) break; + } + } catch (Throwable ex) { + exception = ex; + if (ex instanceof Error) { + // `readException` may not be reported to the user. Rethrow Error to make sure at least + // The user can see Error in UncaughtExceptionHandler. + throw (Error) ex; + } + } finally { + stateChangeLock.lock(); + if (read < 0 || (exception instanceof EOFException)) { + endOfStream = true; + } else if (exception != null) { + readAborted = true; + readException = exception; + } else { + readAheadBuffer.limit(read); + } + readInProgress = false; + signalAsyncReadComplete(); + stateChangeLock.unlock(); + closeUnderlyingInputStreamIfNecessary(); + } + } + }); + } + + private void closeUnderlyingInputStreamIfNecessary() { + boolean needToCloseUnderlyingInputStream = false; + stateChangeLock.lock(); + try { + isReading = false; + if (isClosed && !isUnderlyingInputStreamBeingClosed) { + // close method cannot close underlyingInputStream because we were reading. + needToCloseUnderlyingInputStream = true; + } + } finally { + stateChangeLock.unlock(); + } + if (needToCloseUnderlyingInputStream) { + try { + underlyingInputStream.close(); + } catch (IOException e) { + logger.warn(e.getMessage(), e); + } + } + } + + private void signalAsyncReadComplete() { + stateChangeLock.lock(); + try { + asyncReadComplete.signalAll(); + } finally { + stateChangeLock.unlock(); + } + } + + private void waitForAsyncReadComplete() throws IOException { + stateChangeLock.lock(); + try { + while (readInProgress) { + asyncReadComplete.await(); + } + } catch (InterruptedException e) { + InterruptedIOException iio = new InterruptedIOException(e.getMessage()); + iio.initCause(e); + throw iio; + } finally { + stateChangeLock.unlock(); + } + checkReadException(); + } + + @Override + public int read() throws IOException { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + } + + @Override + public int read(byte[] b, int offset, int len) throws IOException { + if (offset < 0 || len < 0 || len > b.length - offset) { + throw new IndexOutOfBoundsException(); + } + if (len == 0) { + return 0; + } + stateChangeLock.lock(); + try { + return readInternal(b, offset, len); + } finally { + stateChangeLock.unlock(); + } + } + + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + + /** + * Internal read function which should be called only from read() api. The assumption is that + * the stateChangeLock is already acquired in the caller before calling this function. + */ + private int readInternal(byte[] b, int offset, int len) throws IOException { + assert (stateChangeLock.isLocked()); + if (!activeBuffer.hasRemaining()) { + waitForAsyncReadComplete(); + if (readAheadBuffer.hasRemaining()) { + swapBuffers(); + } else { + // The first read or activeBuffer is skipped. + readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } + swapBuffers(); + } + } else { + checkReadException(); + } + len = Math.min(len, activeBuffer.remaining()); + activeBuffer.get(b, offset, len); + + if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { + readAsync(); + } + return len; + } + + @Override + public int available() throws IOException { + stateChangeLock.lock(); + // Make sure we have no integer overflow. + try { + return (int) Math.min((long) Integer.MAX_VALUE, + (long) activeBuffer.remaining() + readAheadBuffer.remaining()); + } finally { + stateChangeLock.unlock(); + } + } + + @Override + public long skip(long n) throws IOException { + if (n <= 0L) { + return 0L; + } + stateChangeLock.lock(); + long skipped; + try { + skipped = skipInternal(n); + } finally { + stateChangeLock.unlock(); + } + return skipped; + } + + /** + * Internal skip function which should be called only from skip() api. The assumption is that + * the stateChangeLock is already acquired in the caller before calling this function. + */ + private long skipInternal(long n) throws IOException { + assert (stateChangeLock.isLocked()); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return 0; + } + if (available() >= n) { + // we can skip from the internal buffers + int toSkip = (int) n; + if (toSkip <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position(toSkip + activeBuffer.position()); + if (activeBuffer.remaining() <= readAheadThresholdInBytes + && !readAheadBuffer.hasRemaining()) { + readAsync(); + } + return n; + } + // We need to skip from both active buffer and read ahead buffer + toSkip -= activeBuffer.remaining(); + activeBuffer.position(0); + activeBuffer.flip(); + readAheadBuffer.position(toSkip + readAheadBuffer.position()); + swapBuffers(); + readAsync(); + return n; + } else { + int skippedBytes = available(); + long toSkip = n - skippedBytes; + activeBuffer.position(0); + activeBuffer.flip(); + readAheadBuffer.position(0); + readAheadBuffer.flip(); + long skippedFromInputStream = underlyingInputStream.skip(toSkip); + readAsync(); + return skippedBytes + skippedFromInputStream; + } + } + + @Override + public void close() throws IOException { + boolean isSafeToCloseUnderlyingInputStream = false; + stateChangeLock.lock(); + try { + if (isClosed) { + return; + } + isClosed = true; + if (!isReading) { + // Nobody is reading, so we can close the underlying input stream in this method. + isSafeToCloseUnderlyingInputStream = true; + // Flip this to make sure the read ahead task will not close the underlying input stream. + isUnderlyingInputStreamBeingClosed = true; + } + } finally { + stateChangeLock.unlock(); + } + + try { + executorService.shutdownNow(); + executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS); + } catch (InterruptedException e) { + InterruptedIOException iio = new InterruptedIOException(e.getMessage()); + iio.initCause(e); + throw iio; + } finally { + if (isSafeToCloseUnderlyingInputStream) { + underlyingInputStream.close(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 9521ab86a12d5..1e760b0b51988 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -17,20 +17,20 @@ package org.apache.spark.util.collection.unsafe.sort; -import java.io.*; - import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; - import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.io.NioBufferedFileInputStream; +import org.apache.spark.io.ReadAheadInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.unsafe.Platform; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.*; + /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). @@ -72,10 +72,23 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } + final double readAheadFraction = + SparkEnv.get() == null ? 0.5 : + SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); + + final boolean readAheadEnabled = + SparkEnv.get() == null ? false : + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); + final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { - this.in = serializerManager.wrapStream(blockId, bs); + if (readAheadEnabled) { + this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), + (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + } else { + this.in = serializerManager.wrapStream(blockId, bs); + } this.din = new DataInputStream(this.in); numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java similarity index 87% rename from core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java rename to core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 2c1a34a607592..3440e1aea2f46 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -31,11 +31,13 @@ /** * Tests functionality of {@link NioBufferedFileInputStream} */ -public class NioBufferedFileInputStreamSuite { +public abstract class GenericFileInputStreamSuite { private byte[] randomBytes; - private File inputFile; + protected File inputFile; + + protected InputStream inputStream; @Before public void setUp() throws IOException { @@ -52,7 +54,6 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < randomBytes.length; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -60,7 +61,6 @@ public void testReadOneByte() throws IOException { @Test public void testReadMultipleBytes() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); byte[] readBytes = new byte[8 * 1024]; int i = 0; while (i < randomBytes.length) { @@ -74,7 +74,6 @@ public void testReadMultipleBytes() throws IOException { @Test public void testBytesSkipped() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); assertEquals(1024, inputStream.skip(1024)); for (int i = 1024; i < randomBytes.length; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); @@ -83,7 +82,6 @@ public void testBytesSkipped() throws IOException { @Test public void testBytesSkippedAfterRead() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < 1024; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -95,7 +93,6 @@ public void testBytesSkippedAfterRead() throws IOException { @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < 1024; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -111,7 +108,6 @@ public void testNegativeBytesSkippedAfterRead() throws IOException { @Test public void testSkipFromFileChannel() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10); // Since the buffer is smaller than the skipped bytes, this will guarantee // we skip from underlying file channel. assertEquals(1024, inputStream.skip(1024)); @@ -128,7 +124,6 @@ public void testSkipFromFileChannel() throws IOException { @Test public void testBytesSkippedAfterEOF() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); assertEquals(-1, inputStream.read()); } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java new file mode 100644 index 0000000000000..211b33a1a9fb0 --- /dev/null +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import org.junit.Before; + +import java.io.IOException; + +/** + * Tests functionality of {@link NioBufferedFileInputStream} + */ +public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { + + @Before + public void setUp() throws IOException { + super.setUp(); + inputStream = new NioBufferedFileInputStream(inputFile); + } +} diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java new file mode 100644 index 0000000000000..5008f93b7e409 --- /dev/null +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import org.junit.Before; + +import java.io.IOException; + +/** + * Tests functionality of {@link NioBufferedFileInputStream} + */ +public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { + + @Before + public void setUp() throws IOException { + super.setUp(); + inputStream = new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + } +} From 894a7561de2c2ff01fe7fcc5268378161e9e5643 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Sep 2017 16:42:08 +0800 Subject: [PATCH 15/37] [SPARK-22047][TEST] ignore HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? As reported in https://issues.apache.org/jira/browse/SPARK-22047 , HiveExternalCatalogVersionsSuite is failing frequently, let's disable this test suite to unblock other PRs, I'm looking into the root cause. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19264 from cloud-fan/test. --- .../apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 2928a734a7e36..01db9eb6f04f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.Utils * expected version under this local directory, e.g. `/tmp/spark-test/spark-2.0.3`, we will skip the * downloading for this spark version. */ +@org.scalatest.Ignore class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") From 3b049abf102908ca72674139367e3b8d9ffcc283 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 18 Sep 2017 08:49:32 -0700 Subject: [PATCH 16/37] [SPARK-22003][SQL] support array column in vectorized reader with UDF ## What changes were proposed in this pull request? The UDF needs to deserialize the `UnsafeRow`. When the column type is Array, the `get` method from the `ColumnVector`, which is used by the vectorized reader, is called, but this method is not implemented. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #19230 from liufengdb/fix_array_open. --- .../execution/vectorized/ColumnVector.java | 103 ++++----- .../vectorized/ColumnVectorSuite.scala | 201 ++++++++++++++++++ 2 files changed, 242 insertions(+), 62 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a69dd9718fe33..c4b519f0b153f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -100,72 +100,16 @@ public ArrayData copy() { public Object[] array() { DataType dt = data.dataType(); Object[] list = new Object[length]; - - if (dt instanceof BooleanType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getBoolean(offset + i); - } - } - } else if (dt instanceof ByteType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getByte(offset + i); - } - } - } else if (dt instanceof ShortType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getShort(offset + i); - } - } - } else if (dt instanceof IntegerType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getInt(offset + i); - } - } - } else if (dt instanceof FloatType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getFloat(offset + i); - } - } - } else if (dt instanceof DoubleType) { + try { for (int i = 0; i < length; i++) { if (!data.isNullAt(offset + i)) { - list[i] = data.getDouble(offset + i); + list[i] = get(i, dt); } } - } else if (dt instanceof LongType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getLong(offset + i); - } - } - } else if (dt instanceof DecimalType) { - DecimalType decType = (DecimalType)dt; - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = getDecimal(i, decType.precision(), decType.scale()); - } - } - } else if (dt instanceof StringType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = getUTF8String(i).toString(); - } - } - } else if (dt instanceof CalendarIntervalType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = getInterval(i); - } - } - } else { - throw new UnsupportedOperationException("Type " + dt); + return list; + } catch(Exception e) { + throw new RuntimeException("Could not get the array", e); } - return list; } @Override @@ -237,7 +181,42 @@ public MapData getMap(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - throw new UnsupportedOperationException(); + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala new file mode 100644 index 0000000000000..998067a47033b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { + + var testVector: WritableColumnVector = _ + + private def allocate(capacity: Int, dt: DataType): WritableColumnVector = { + new OnHeapColumnVector(capacity, dt) + } + + override def afterEach(): Unit = { + testVector.close() + } + + test("boolean") { + testVector = allocate(10, BooleanType) + (0 until 10).foreach { i => + testVector.appendBoolean(i % 2 == 0) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, BooleanType) === (i % 2 == 0)) + } + } + + test("byte") { + testVector = allocate(10, ByteType) + (0 until 10).foreach { i => + testVector.appendByte(i.toByte) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, ByteType) === (i.toByte)) + } + } + + test("short") { + testVector = allocate(10, ShortType) + (0 until 10).foreach { i => + testVector.appendShort(i.toShort) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, ShortType) === (i.toShort)) + } + } + + test("int") { + testVector = allocate(10, IntegerType) + (0 until 10).foreach { i => + testVector.appendInt(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, IntegerType) === i) + } + } + + test("long") { + testVector = allocate(10, LongType) + (0 until 10).foreach { i => + testVector.appendLong(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, LongType) === i) + } + } + + test("float") { + testVector = allocate(10, FloatType) + (0 until 10).foreach { i => + testVector.appendFloat(i.toFloat) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, FloatType) === i.toFloat) + } + } + + test("double") { + testVector = allocate(10, DoubleType) + (0 until 10).foreach { i => + testVector.appendDouble(i.toDouble) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, DoubleType) === i.toDouble) + } + } + + test("string") { + testVector = allocate(10, StringType) + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, StringType) === UTF8String.fromString(s"str$i")) + } + } + + test("binary") { + testVector = allocate(10, BinaryType) + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + assert(array.get(i, BinaryType) === utf8) + } + } + + test("array") { + val arrayType = ArrayType(IntegerType, true) + testVector = allocate(10, arrayType) + + val data = testVector.arrayData() + var i = 0 + while (i < 6) { + data.putInt(i, i) + i += 1 + } + + // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + testVector.putArray(0, 0, 1) + testVector.putArray(1, 1, 2) + testVector.putArray(2, 3, 0) + testVector.putArray(3, 3, 3) + + val array = new ColumnVector.Array(testVector) + + assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0)) + assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2)) + assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int]) + assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) + } + + test("struct") { + val schema = new StructType().add("int", IntegerType).add("double", DoubleType) + testVector = allocate(10, schema) + val c1 = testVector.getChildColumn(0) + val c2 = testVector.getChildColumn(1) + c1.putInt(0, 123) + c2.putDouble(0, 3.45) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + + val array = new ColumnVector.Array(testVector) + + assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) + assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) + assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) + assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + } +} From c66d64b3df9d9ffba0b16a62015680f6f876fc68 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 18 Sep 2017 12:12:35 -0700 Subject: [PATCH 17/37] [SPARK-14878][SQL] Trim characters string function support #### What changes were proposed in this pull request? This PR enhances the TRIM function support in Spark SQL by allowing the specification of trim characters set. Below is the SQL syntax : ``` SQL ::= TRIM ::= [ [ ] [ ] FROM ] ::= ::= LEADING | TRAILING | BOTH ::= ``` or ``` SQL LTRIM (source-exp [, trim-exp]) RTRIM (source-exp [, trim-exp]) ``` Here are the documentation link of support of this feature by other mainstream databases. - **Oracle:** [TRIM function](http://docs.oracle.com/cd/B28359_01/olap.111/b28126/dml_functions_2126.htm#OLADM704) - **DB2:** [TRIM scalar function](https://www.ibm.com/support/knowledgecenter/en/SSMKHH_10.0.0/com.ibm.etools.mft.doc/ak05270_.htm) - **MySQL:** [Trim function](http://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim) - **Oracle:** [ltrim](https://docs.oracle.com/cd/B28359_01/olap.111/b28126/dml_functions_2018.htm#OLADM594) - **DB2:** [ltrim](https://www.ibm.com/support/knowledgecenter/en/SSEPEK_11.0.0/sqlref/src/tpc/db2z_bif_ltrim.html) This PR is to implement the above enhancement. In the implementation, the design principle is to keep the changes to the minimum. Also, the exiting trim functions (which handles a special case, i.e., trimming space characters) are kept unchanged for performane reasons. #### How was this patch tested? The unit test cases are added in the following files: - UTF8StringSuite.java - StringExpressionsSuite.scala - sql/SQLQuerySuite.scala - StringFunctionsSuite.scala Author: Kevin Yu Closes #12646 from kevinyu98/spark-14878. --- .../apache/spark/unsafe/types/UTF8String.java | 93 ++++++ .../spark/unsafe/types/UTF8StringSuite.java | 57 ++++ .../spark/sql/catalyst/parser/SqlBase.g4 | 6 + .../expressions/stringExpressions.scala | 280 ++++++++++++++++-- .../sql/catalyst/parser/AstBuilder.scala | 24 +- .../expressions/StringExpressionsSuite.scala | 65 +++- .../sql/catalyst/parser/PlanParserSuite.scala | 18 ++ .../parser/TableIdentifierParserSuite.scala | 2 +- .../org/apache/spark/sql/functions.scala | 27 ++ .../spark/sql/StringFunctionsSuite.scala | 14 +- 10 files changed, 554 insertions(+), 32 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index dd67f15749add..76db0fb91e48a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -511,6 +511,21 @@ public UTF8String trim() { } } + /** + * Based on the given trim string, trim this string starting from both ends + * This method searches for each character in the source string, removes the character if it is found + * in the trim string, stops at the first not found. It calls the trimLeft first, then trimRight. + * It returns a new string in which both ends trim characters have been removed. + * @param trimString the trim character string + */ + public UTF8String trim(UTF8String trimString) { + if (trimString != null) { + return trimLeft(trimString).trimRight(trimString); + } else { + return null; + } + } + public UTF8String trimLeft() { int s = 0; // skip all of the space (0x20) in the left side @@ -523,6 +538,40 @@ public UTF8String trimLeft() { } } + /** + * Based on the given trim string, trim this string starting from left end + * This method searches each character in the source string starting from the left end, removes the character if it + * is in the trim string, stops at the first character which is not in the trim string, returns the new string. + * @param trimString the trim character string + */ + public UTF8String trimLeft(UTF8String trimString) { + if (trimString == null) return null; + // the searching byte position in the source string + int srchIdx = 0; + // the first beginning byte position of a non-matching character + int trimIdx = 0; + + while (srchIdx < numBytes) { + UTF8String searchChar = copyUTF8String(srchIdx, srchIdx + numBytesForFirstByte(this.getByte(srchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; + // try to find the matching for the searchChar in the trimString set + if (trimString.find(searchChar, 0) >= 0) { + trimIdx += searchCharBytes; + } else { + // no matching, exit the search + break; + } + srchIdx += searchCharBytes; + } + + if (trimIdx >= numBytes) { + // empty string + return EMPTY_UTF8; + } else { + return copyUTF8String(trimIdx, numBytes - 1); + } + } + public UTF8String trimRight() { int e = numBytes - 1; // skip all of the space (0x20) in the right side @@ -536,6 +585,50 @@ public UTF8String trimRight() { } } + /** + * Based on the given trim string, trim this string starting from right end + * This method searches each character in the source string starting from the right end, removes the character if it + * is in the trim string, stops at the first character which is not in the trim string, returns the new string. + * @param trimString the trim character string + */ + public UTF8String trimRight(UTF8String trimString) { + if (trimString == null) return null; + int charIdx = 0; + // number of characters from the source string + int numChars = 0; + // array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + // build the position and length array + while (charIdx < numBytes) { + stringCharPos[numChars] = charIdx; + stringCharLen[numChars] = numBytesForFirstByte(getByte(charIdx)); + charIdx += stringCharLen[numChars]; + numChars ++; + } + + // index trimEnd points to the first no matching byte position from the right side of the source string. + int trimEnd = numBytes - 1; + while (numChars > 0) { + UTF8String searchChar = + copyUTF8String(stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + if (trimString.find(searchChar, 0) >= 0) { + trimEnd -= stringCharLen[numChars - 1]; + } else { + break; + } + numChars --; + } + + if (trimEnd < 0) { + // empty string + return EMPTY_UTF8; + } else { + return copyUTF8String(0, trimEnd); + } + } + public UTF8String reverse() { byte[] result = new byte[this.numBytes]; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index c376371abdf90..f0860018d5642 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -730,4 +730,61 @@ public void testToLong() throws IOException { assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); } } + + @Test + public void trimBothWithTrimString() { + assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); + assertEquals(fromString("o"), fromString(" hello ").trim(fromString(" hle"))); + assertEquals(fromString("h e"), fromString("ooh e ooo").trim(fromString("o "))); + assertEquals(fromString(""), fromString("ooo...oooo").trim(fromString("o."))); + assertEquals(fromString("b"), fromString("%^b[]@").trim(fromString("][@^%"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab"))); + assertEquals(fromString(""), fromString("a").trim(fromString("a数b"))); + assertEquals(fromString(""), fromString("数数 数数数").trim(fromString("数 "))); + assertEquals(fromString("据砖头"), fromString("数]数[数据砖头#数数").trim(fromString("[数]#"))); + assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数"))); + } + + @Test + public void trimLeftWithTrimString() { + assertEquals(fromString(" hello "), fromString(" hello ").trimLeft(fromString(""))); + assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); + assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); + assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a"))); + assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); + assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao"))); + assertEquals(fromString("rim "), fromString("ooootrim ").trimLeft(fromString("otm"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(" "))); + + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(" "))); + assertEquals(fromString("数"), fromString("数").trimLeft(fromString("a"))); + assertEquals(fromString("a"), fromString("a").trimLeft(fromString("数"))); + assertEquals(fromString("砖头数数"), fromString("数数数据砖头数数").trimLeft(fromString("据数"))); + assertEquals(fromString("据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数 "))); + assertEquals(fromString("据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a数砖"))); + assertEquals(fromString("$S,.$BR"), fromString(",,,,%$S,.$BR").trimLeft(fromString("%,"))); + } + + @Test + public void trimRightWithTrimString() { + assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); + assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); + assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba"))); + assertEquals(fromString(""), fromString("aabbbbaaa").trimRight(fromString("ab"))); + assertEquals(fromString(" he"), fromString(" hello ").trimRight(fromString(" ol"))); + assertEquals(fromString("oohell"), fromString("oohellooo../*&").trimRight(fromString("./,&%*o"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(" "))); + + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(" "))); + assertEquals(fromString("数数砖头"), fromString("数数砖头数aa数").trimRight(fromString("a数"))); + assertEquals(fromString(""), fromString("数数数据砖ab").trimRight(fromString("数据砖ab"))); + assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); + assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); + } } diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 33bc79a92b9e7..d0a54288780ea 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -580,6 +580,8 @@ primaryExpression | '(' query ')' #subqueryExpression | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall + | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) argument+=expression + FROM argument+=expression ')' #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference @@ -748,6 +750,7 @@ nonReserved | UNBOUNDED | WHEN | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP | DIRECTORY + | BOTH | LEADING | TRAILING ; SELECT: 'SELECT'; @@ -861,6 +864,9 @@ COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; IGNORE: 'IGNORE'; +BOTH: 'BOTH'; +LEADING: 'LEADING'; +TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7ab45a6ee8737..83de515079eea 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -24,6 +24,7 @@ import java.util.regex.Pattern import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -503,69 +504,304 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } +trait String2TrimExpression extends Expression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) +} + +object StringTrim { + def apply(str: Expression, trimStr: Expression) : StringTrim = StringTrim(str, Some(trimStr)) + def apply(str: Expression) : StringTrim = StringTrim(str, None) +} + /** - * A function that trim the spaces from both ends for the specified string. + * A function that takes a character string, removes the leading and trailing characters matching with any character + * in the trim string, returns the new string. + * If BOTH and trimStr keywords are not specified, it defaults to remove space character from both ends. The trim + * function will have one argument, which contains the source string. + * If BOTH and trimStr keywords are specified, it trims the characters from both ends, and the trim function will have + * two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function + * searches for each character in the source string, removes the characters from the source string until it + * encounters the first non-match character. + * BOTH: removes any character from both ends of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`. + _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); + parkSQ """) -case class StringTrim(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrim( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { + + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) - def convert(v: UTF8String): UTF8String = v.trim() + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "trim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trim() + } + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trim()") + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(); + }""") + } else { + val trimString = evals(1) + val getTrimFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimFunction + }""") + } } } +object StringTrimLeft { + def apply(str: Expression, trimStr: Expression) : StringTrimLeft = StringTrimLeft(str, Some(trimStr)) + def apply(str: Expression) : StringTrimLeft = StringTrimLeft(str, None) +} + /** - * A function that trim the spaces from left end for given string. + * A function that trims the characters from left end for a given string. + * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. The ltrim + * function will have one argument, which contains the source string. + * If LEADING and trimStr keywords are not specified, it trims the characters from left end. The ltrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any character from the left end of the source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * LEADING: removes any character from the left end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading space characters from `str`. + _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: - > SELECT _FUNC_(' SparkSQL'); + > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_('Sp', 'SSparkSQLS'); + arkSQLS """) -case class StringTrimLeft(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimLeft( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { - def convert(v: UTF8String): UTF8String = v.trimLeft() + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "ltrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trimLeft() + } + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(); + }""") + } else { + val trimString = evals(1) + val getTrimLeftFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimLeftFunction + }""") + } + } +} + +object StringTrimRight { + def apply(str: Expression, trimStr: Expression) : StringTrimRight = StringTrimRight(str, Some(trimStr)) + def apply(str: Expression) : StringTrimRight = StringTrimRight(str, None) } /** - * A function that trim the spaces from right end for given string. + * A function that trims the characters from right end for a given string. + * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. The + * rtrim function will have one argument, which contains the source string. + * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any character from the right end of source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * TRAILING: removes any character from the right end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the trailing space characters from `str`. + _FUNC_(trimStr, str) - Removes the trailing string which contains the characters from the trim string from the `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); - SparkSQL + SparkSQL + > SELECT _FUNC_('LQSa', 'SSparkSQLS'); + SSpark """) -case class StringTrimRight(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimRight( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { - def convert(v: UTF8String): UTF8String = v.trimRight() + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "rtrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimRight()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trimRight() + } + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimRight(); + }""") + } else { + val trimString = evals(1) + val getTrimRightFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimRightFunction + }""") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 891f61698f177..85b492e83446e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1179,6 +1179,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a (windowed) Function expression. */ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + def replaceFunctions( + funcID: FunctionIdentifier, + ctx: FunctionCallContext): FunctionIdentifier = { + val opt = ctx.trimOption + if (opt != null) { + if (ctx.qualifiedName.getText.toLowerCase(Locale.ROOT) != "trim") { + throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + + s"doesn't support with option ${opt.getText}.", ctx) + } + opt.getType match { + case SqlBaseParser.BOTH => funcID + case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") + case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") + case _ => throw new ParseException("Function trim doesn't support with " + + s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx) + } + } else { + funcID + } + } // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) @@ -1190,7 +1210,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case expressions => expressions } - val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName), arguments, isDistinct) + val funcId = replaceFunctions(visitFunctionName(ctx.qualifiedName), ctx) + val function = UnresolvedFunction(funcId, arguments, isDistinct) + // Check if the function is evaluated in a windowed context. ctx.windowSpec match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4f08031153ab0..18ef4bc37c2b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ - class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("concat") { @@ -406,26 +405,78 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } - test("TRIM/LTRIM/RTRIM") { + test("TRIM") { val s = 'a.string.at(0) checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim("aa", "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrim(Literal(" aabbtrimccc"), "ab cd"), "trim", create_row("bdef")) + checkEvaluation(StringTrim(Literal("a@>.,>"), "a.,@<>"), " ", create_row(" abdef ")) checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) + checkEvaluation(StringTrim(s, "abd"), "ef", create_row("abdefa")) + checkEvaluation(StringTrim(s, "a"), "bdef", create_row("aaabdefaaaa")) + checkEvaluation(StringTrim(s, "SLSQ"), "park", create_row("SSparkSQLS")) + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(s, "花世界"), "", create_row("花花世界花花")) + checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花花世界花花")) + checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花 花 世界 花 花 ")) + checkEvaluation(StringTrim(s, "a花世"), "界", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrim(s, "a@#( )"), "花花世界花花", create_row("aa()花花世界花花@ #")) + checkEvaluation(StringTrim(Literal("花trim"), "花 "), "trim", create_row(" abdef ")) + // scalastyle:on + checkEvaluation(StringTrim(Literal("a"), Literal.create(null, StringType)), null) + checkEvaluation(StringTrim(Literal.create(null, StringType), Literal("a")), null) + } + + test("LTRIM") { + val s = 'a.string.at(0) checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aa"), "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aa "), "a "), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aabbcaaaa"), "ab"), "caaaa", create_row(" abdef ")) checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s, "a"), "bdefa", create_row("abdefa")) + checkEvaluation(StringTrimLeft(s, "a "), "bdefaaaa", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimLeft(s, "Spk"), "arkSQLS", create_row("SSparkSQLS")) + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(s, "花"), "世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrimLeft(s, "花 世"), "界花花", create_row(" 花花世界花花")) + checkEvaluation(StringTrimLeft(s, "花"), "a花花世界花花 ", create_row("a花花世界花花 ")) + checkEvaluation(StringTrimLeft(s, "a花界"), "世界花花aa", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrimLeft(s, "a世界"), "花花世界花花", create_row("花花世界花花")) + // scalastyle:on + checkEvaluation(StringTrimLeft(Literal.create(null, StringType), Literal("a")), null) + checkEvaluation(StringTrimLeft(Literal("a"), Literal.create(null, StringType)), null) + } + + test("RTRIM") { + val s = 'a.string.at(0) checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("a"), "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("ab"), "ab"), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("aabbaaaa %"), "a %"), "aabb", create_row("def")) checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s, "a"), "abdef", create_row("abdefa")) + checkEvaluation(StringTrimRight(s, "abf de"), "", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimRight(s, "S*&"), "SSparkSQL", create_row("SSparkSQLS*")) // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrimRight(Literal("a"), "花"), "a", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("花"), "a"), "花", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("花花世界"), "界花世"), "", create_row(" abdef ")) checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) - checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimRight(s, "花a#"), "花花世界", create_row("花花世界花花###aa花")) + checkEvaluation(StringTrimRight(s, "花"), "", create_row("花花花花")) + checkEvaluation(StringTrimRight(s, "花 界b@"), " 花花世", create_row(" 花花世 b界@花花 ")) // scalastyle:on - checkEvaluation(StringTrim(Literal.create(null, StringType)), null) - checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null) - checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal("a"), Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal.create(null, StringType), Literal("a")), null) } test("FORMAT") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index b0d2fb26a6006..306e6f2cfbd37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -651,4 +651,22 @@ class PlanParserSuite extends AnalysisTest { ) ) } + + test("TRIM function") { + intercept("select ltrim(both 'S' from 'SS abc S'", "missing ')' at ''") + intercept("select rtrim(trailing 'S' from 'SS abc S'", "missing ')' at ''") + + assertEqual( + "SELECT TRIM(BOTH '@$%&( )abc' FROM '@ $ % & ()abc ' )", + OneRowRelation().select('TRIM.function("@$%&( )abc", "@ $ % & ()abc ")) + ) + assertEqual( + "SELECT TRIM(LEADING 'c []' FROM '[ ccccbcc ')", + OneRowRelation().select('ltrim.function("c []", "[ ccccbcc ")) + ) + assertEqual( + "SELECT TRIM(TRAILING 'c&^,.' FROM 'bc...,,,&&&ccc')", + OneRowRelation().select('rtrim.function("c&^,.", "bc...,,,&&&ccc")) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 76be6ee3f50bc..cc80a41df998d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47324ed9f2fb8..c6d0d86384b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2333,6 +2333,15 @@ object functions { */ def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } + /** + * Trim the specified character string from left end for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def ltrim(e: Column, trimString: String): Column = withExpr { + StringTrimLeft(e.expr, Literal(trimString)) + } + /** * Extract a specific group matched by a Java regex, from the specified string column. * If the regex did not match, or the specified group did not match, an empty string is returned. @@ -2410,6 +2419,15 @@ object functions { */ def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } + /** + * Trim the specified character string from right end for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def rtrim(e: Column, trimString: String): Column = withExpr { + StringTrimRight(e.expr, Literal(trimString)) + } + /** * Returns the soundex code for the specified expression. * @@ -2477,6 +2495,15 @@ object functions { */ def trim(e: Column): Column = withExpr { StringTrim(e.expr) } + /** + * Trim the specified character from both ends for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def trim(e: Column, trimString: String): Column = withExpr { + StringTrim(e.expr, Literal(trimString)) + } + /** * Converts a string column to upper case. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index a12efc835691b..3d76b9ac33e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -161,12 +161,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string trim functions") { - val df = Seq((" example ", "")).toDF("a", "b") + val df = Seq((" example ", "", "example")).toDF("a", "b", "c") checkAnswer( df.select(ltrim($"a"), rtrim($"a"), trim($"a")), Row("example ", " example", "example")) + checkAnswer( + df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim($"c", "e")), + Row("xample", "exampl", "xampl")) + + checkAnswer( + df.select(ltrim($"c", "xe"), rtrim($"c", "emlp"), trim($"c", "elxp")), + Row("ample", "exa", "am")) + + checkAnswer( + df.select(trim($"c", "xyz")), + Row("example")) + checkAnswer( df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), Row("example ", " example", "example")) From 94f7e046a20cb6f802c7f2841ed2d2814cae49fa Mon Sep 17 00:00:00 2001 From: alexmnyc Date: Tue, 19 Sep 2017 10:05:59 +0800 Subject: [PATCH 18/37] [SPARK-22030][CORE] GraphiteSink fails to re-connect to Graphite instances behind an ELB or any other auto-scaled LB ## What changes were proposed in this pull request? Upgrade codahale metrics library so that Graphite constructor can re-resolve hosts behind a CNAME with re-tried DNS lookups. When Graphite is deployed behind an ELB, ELB may change IP addresses based on auto-scaling needs. Using current approach yields Graphite usage impossible, fixing for that use case - Upgrade to codahale 3.1.5 - Use new Graphite(host, port) constructor instead of new Graphite(new InetSocketAddress(host, port)) constructor ## How was this patch tested? The same logic is used for another project that is using the same configuration and code path, and graphite re-connect's behind ELB's are no longer an issue This are proposed changes for codahale lib - https://github.com/dropwizard/metrics/compare/v3.1.2...v3.1.5#diff-6916c85d2dd08d89fe771c952e3b8512R120. Specifically, https://github.com/dropwizard/metrics/blob/b4d246d34e8a059b047567848b3522567cbe6108/metrics-graphite/src/main/java/com/codahale/metrics/graphite/Graphite.java#L120 Please review http://spark.apache.org/contributing.html before opening a pull request. Author: alexmnyc Closes #19210 from alexmnyc/patch-1. --- .../org/apache/spark/metrics/sink/GraphiteSink.scala | 4 ++-- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- pom.xml | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 23e31823f4930..ac33e68abb490 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -68,8 +68,8 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { - case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) - case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) + case Some("udp") => new GraphiteUDP(host, port) + case Some("tcp") | None => new Graphite(host, port) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") } diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9ac753861dd84..e534e38213fb1 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -139,10 +139,10 @@ machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.3.0-shaded-protobuf.jar -metrics-core-3.1.2.jar -metrics-graphite-3.1.2.jar -metrics-json-3.1.2.jar -metrics-jvm-3.1.2.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index d39747e9ee058..02c5a19d173be 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -140,10 +140,10 @@ machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.3.0-shaded-protobuf.jar -metrics-core-3.1.2.jar -metrics-graphite-3.1.2.jar -metrics-json-3.1.2.jar -metrics-jvm-3.1.2.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar diff --git a/pom.xml b/pom.xml index af511c3e2e5df..0bbbf20a76d68 100644 --- a/pom.xml +++ b/pom.xml @@ -138,7 +138,7 @@ 0.8.4 2.4.0 2.0.8 - 3.1.2 + 3.1.5 1.7.7 hadoop2 0.9.3 From 10f45b3c84ff7b3f1765dc6384a563c33d26548b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Sep 2017 11:53:50 +0800 Subject: [PATCH 19/37] [SPARK-22047][FLAKY TEST] HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? This PR tries to download Spark for each test run, to make sure each test run is absolutely isolated. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19265 from cloud-fan/test. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 01db9eb6f04f2..305f5b533d592 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -35,16 +35,18 @@ import org.apache.spark.util.Utils * expected version under this local directory, e.g. `/tmp/spark-test/spark-2.0.3`, we will skip the * downloading for this spark version. */ -@org.scalatest.Ignore class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") - private val sparkTestingDir = "/tmp/spark-test" + // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to + // avoid downloading Spark of different versions in each run. + private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { Utils.deleteRecursively(wareHousePath) Utils.deleteRecursively(tmpDataDir) + Utils.deleteRecursively(sparkTestingDir) super.afterAll() } @@ -53,7 +55,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" - Seq("wget", url, "-q", "-P", sparkTestingDir).! + Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath From a11db942aaf4c470a85f8a1b180f034f7a584254 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Tue, 19 Sep 2017 14:51:27 +0800 Subject: [PATCH 20/37] [SPARK-21923][CORE] Avoid calling reserveUnrollMemoryForThisTask for every record ## What changes were proposed in this pull request? When Spark persist data to Unsafe memory, we call the method `MemoryStore.putIteratorAsBytes`, which need synchronize the `memoryManager` for every record write. This implementation is not necessary, we can apply for more memory at a time to reduce unnecessary synchronization. ## How was this patch tested? Test case (with 1 executor 20 core): ```scala val start = System.currentTimeMillis() val data = sc.parallelize(0 until Integer.MAX_VALUE, 100) .persist(StorageLevel.OFF_HEAP) .count() println(System.currentTimeMillis() - start) ``` Test result: before | 27647 | 29108 | 28591 | 28264 | 27232 | after | 26868 | 26358 | 27767 | 26653 | 26693 | Author: Xianyang Liu Closes #19135 from ConeyLiu/memorystore. --- .../apache/spark/internal/config/package.scala | 15 +++++++++++++++ .../spark/storage/memory/MemoryStore.scala | 18 ++++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0d3769a735869..e0f696080e566 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -385,4 +385,19 @@ package object config { .checkValue(v => v > 0 && v <= Int.MaxValue, s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") .createWithDefault(1024 * 1024) + + private[spark] val UNROLL_MEMORY_CHECK_PERIOD = + ConfigBuilder("spark.storage.unrollMemoryCheckPeriod") + .internal() + .doc("The memory check period is used to determine how often we should check whether " + + "there is a need to request more memory when we try to unroll the given block in memory.") + .longConf + .createWithDefault(16) + + private[spark] val UNROLL_MEMORY_GROWTH_FACTOR = + ConfigBuilder("spark.storage.unrollMemoryGrowthFactor") + .internal() + .doc("Memory to request as a multiple of the size that used to unroll the block.") + .doubleConf + .createWithDefault(1.5) } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 90e3af2d0ec74..eb2201d142ffb 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -29,6 +29,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{UNROLL_MEMORY_CHECK_PERIOD, UNROLL_MEMORY_GROWTH_FACTOR} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId} @@ -190,11 +191,11 @@ private[spark] class MemoryStore( // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory - val memoryCheckPeriod = 16 + val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size - val memoryGrowthFactor = 1.5 + val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Keep track of unroll memory used by this particular block / putIterator() operation var unrollMemoryUsedByThisBlock = 0L // Underlying vector for unrolling the block @@ -325,6 +326,12 @@ private[spark] class MemoryStore( // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true + // Number of elements unrolled so far + var elementsUnrolled = 0L + // How often to check whether we need to request more memory + val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) + // Memory to request as a multiple of current bbos size + val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold // Keep track of unroll memory used by this particular block / putIterator() operation @@ -359,7 +366,7 @@ private[spark] class MemoryStore( def reserveAdditionalMemoryIfNecessary(): Unit = { if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + val amountToRequest = (bbos.size * memoryGrowthFactor - unrollMemoryUsedByThisBlock).toLong keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest @@ -370,7 +377,10 @@ private[spark] class MemoryStore( // Unroll this block safely, checking whether we have exceeded our threshold while (values.hasNext && keepUnrolling) { serializationStream.writeObject(values.next())(classTag) - reserveAdditionalMemoryIfNecessary() + elementsUnrolled += 1 + if (elementsUnrolled % memoryCheckPeriod == 0) { + reserveAdditionalMemoryIfNecessary() + } } // Make sure that we have enough memory to store the block. By this point, it is possible that From 7c92351f43ac4b1710e3c80c78f7978dad491ed2 Mon Sep 17 00:00:00 2001 From: Armin Date: Tue, 19 Sep 2017 10:06:32 +0100 Subject: [PATCH 21/37] [MINOR][CORE] Cleanup dead code and duplication in Mem. Management ## What changes were proposed in this pull request? * Removed the method `org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter#alignToWords`. It became unused as a result of 85b0a157543201895557d66306b38b3ca52f2151 (SPARK-15962) introducing word alignment for unsafe arrays. * Cleaned up duplicate code in memory management and unsafe sorters * The change extracting the exception paths is more than just cosmetics since it def. reduces the size the affected methods compile to ## How was this patch tested? * Build still passes after removing the method, grepping the codebase for `alignToWords` shows no reference to it anywhere either. * Dried up code is covered by existing tests. Author: Armin Closes #19254 from original-brownbear/cleanup-mem-consumer. --- .../apache/spark/memory/MemoryConsumer.java | 26 +++++++-------- .../spark/unsafe/map/BytesToBytesMap.java | 24 ++++++-------- .../unsafe/sort/UnsafeExternalSorter.java | 32 +++++++++---------- .../expressions/codegen/UnsafeRowWriter.java | 16 ---------- 4 files changed, 37 insertions(+), 61 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 4099fb01f2f95..0efae16e9838c 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -89,13 +89,7 @@ public LongArray allocateArray(long size) { long required = size * 8L; MemoryBlock page = taskMemoryManager.allocatePage(required, this); if (page == null || page.size() < required) { - long got = 0; - if (page != null) { - got = page.size(); - taskMemoryManager.freePage(page, this); - } - taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throwOom(page, required); } used += required; return new LongArray(page); @@ -116,13 +110,7 @@ public void freeArray(LongArray array) { protected MemoryBlock allocatePage(long required) { MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); if (page == null || page.size() < required) { - long got = 0; - if (page != null) { - got = page.size(); - taskMemoryManager.freePage(page, this); - } - taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throwOom(page, required); } used += page.size(); return page; @@ -152,4 +140,14 @@ public void freeMemory(long size) { taskMemoryManager.releaseExecutionMemory(size, this); used -= size; } + + private void throwOom(final MemoryBlock page, final long required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 610ace30f8a62..4fadfe36cd716 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -283,13 +283,7 @@ private void advanceToNextPage() { } else { currentPage = null; if (reader != null) { - // remove the spill file from disk - File file = spillWriters.removeFirst().getFile(); - if (file != null && file.exists()) { - if (!file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } - } + handleFailedDelete(); } try { Closeables.close(reader, /* swallowIOException = */ false); @@ -307,13 +301,7 @@ private void advanceToNextPage() { public boolean hasNext() { if (numRecords == 0) { if (reader != null) { - // remove the spill file from disk - File file = spillWriters.removeFirst().getFile(); - if (file != null && file.exists()) { - if (!file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } - } + handleFailedDelete(); } } return numRecords > 0; @@ -403,6 +391,14 @@ public long spill(long numBytes) throws IOException { public void remove() { throw new UnsupportedOperationException(); } + + private void handleFailedDelete() { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists() && !file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } } /** diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index de4464080ef55..39eda00dd7efb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -219,15 +219,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, inMemSorter.numRecords()); spillWriters.add(spillWriter); - final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final Object baseObject = sortedRecords.getBaseObject(); - final long baseOffset = sortedRecords.getBaseOffset(); - final int recordLength = sortedRecords.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); - } - spillWriter.close(); + spillIterator(inMemSorter.getSortedIterator(), spillWriter); } final long spillSize = freeMemory(); @@ -488,6 +480,18 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } } + private static void spillIterator(UnsafeSorterIterator inMemIterator, + UnsafeSorterSpillWriter spillWriter) throws IOException { + while (inMemIterator.hasNext()) { + inMemIterator.loadNext(); + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + spillWriter.close(); + } + /** * An UnsafeSorterIterator that support spilling. */ @@ -503,6 +507,7 @@ class SpillableIterator extends UnsafeSorterIterator { this.numRecords = inMemIterator.getNumRecords(); } + @Override public int getNumRecords() { return numRecords; } @@ -521,14 +526,7 @@ public long spill() throws IOException { // Iterate over the records that have not been returned and spill them. final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); - while (inMemIterator.hasNext()) { - inMemIterator.loadNext(); - final Object baseObject = inMemIterator.getBaseObject(); - final long baseOffset = inMemIterator.getBaseOffset(); - final int recordLength = inMemIterator.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); - } - spillWriter.close(); + spillIterator(inMemIterator, spillWriter); spillWriters.add(spillWriter); nextUpstream = spillWriter.getReader(serializerManager); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 4776617043878..5d9515c0725da 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -109,22 +109,6 @@ public void setOffsetAndSize(int ordinal, long currentCursor, long size) { Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); } - // Do word alignment for this row and grow the row buffer if needed. - // todo: remove this after we make unsafe array data word align. - public void alignToWords(int numBytes) { - final int remainder = numBytes & 0x07; - - if (remainder > 0) { - final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); - - for (int i = 0; i < paddingBytes; i++) { - Platform.putByte(holder.buffer, holder.cursor, (byte) 0); - holder.cursor++; - } - } - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); Platform.putLong(holder.buffer, offset, 0L); From 1bc17a6b8add02772a8a0a1048ac6a01d045baf4 Mon Sep 17 00:00:00 2001 From: Taaffy <32072374+Taaffy@users.noreply.github.com> Date: Tue, 19 Sep 2017 10:20:04 +0100 Subject: [PATCH 22/37] [SPARK-22052] Incorrect Metric assigned in MetricsReporter.scala Current implementation for processingRate-total uses wrong metric: mistakenly uses inputRowsPerSecond instead of processedRowsPerSecond ## What changes were proposed in this pull request? Adjust processingRate-total from using inputRowsPerSecond to processedRowsPerSecond ## How was this patch tested? Built spark from source with proposed change and tested output with correct parameter. Before change the csv metrics file for inputRate-total and processingRate-total displayed the same values due to the error. After changing MetricsReporter.scala the processingRate-total csv file displayed the correct metric. processed rows per second Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Taaffy <32072374+Taaffy@users.noreply.github.com> Closes #19268 from Taaffy/patch-1. --- .../apache/spark/sql/execution/streaming/MetricsReporter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index 5551d12fa8ad2..b84e6ce64c611 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -40,7 +40,7 @@ class MetricsReporter( // Metric names should not have . in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.inputRowsPerSecond) + registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { From 581200af717bcefd11c9930ac063fe53c6fd2fde Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 19 Sep 2017 19:35:36 +0800 Subject: [PATCH 23/37] [SPARK-21428][SQL][FOLLOWUP] CliSessionState should point to the actual metastore not a dummy one ## What changes were proposed in this pull request? While running bin/spark-sql, we will reuse cliSessionState, but the Hive configurations generated here just points to a dummy meta store which actually should be the real one. And the warehouse is determined later in SharedState, HiveClient should respect this config changing in this case too. ## How was this patch tested? existing ut cc cloud-fan jiangxb1987 Author: Kent Yao Closes #19068 from yaooqinn/SPARK-21428-FOLLOWUP. --- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 14 +++++++++++--- .../org/apache/spark/sql/hive/HiveUtils.scala | 6 +++--- .../spark/sql/hive/client/HiveClientImpl.scala | 15 +++++++++++++-- .../spark/sql/hive/client/HiveVersionSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 2 +- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 761e832ed14b8..832a15d09599f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,6 +37,8 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -81,11 +83,17 @@ private[hive] object SparkSQLCLIDriver extends Logging { System.exit(1) } + val sparkConf = new SparkConf(loadDefaults = true) + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val extraConfigs = HiveUtils.formatTimeVarsForHiveClient(hadoopConf) + val cliConf = new HiveConf(classOf[SessionState]) - // Override the location of the metastore since this is only used for local execution. - HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false).foreach { - case (key, value) => cliConf.set(key, value) + (hadoopConf.iterator().asScala.map(kv => kv.getKey -> kv.getValue) + ++ sparkConf.getAll.toMap ++ extraConfigs).foreach { + case (k, v) => + cliConf.set(k, v) } + val sessionState = new CliSessionState(cliConf) sessionState.in = System.in diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 561c127a40bb6..80b9a3dc9605d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -176,9 +176,9 @@ private[spark] object HiveUtils extends Logging { } /** - * Configurations needed to create a [[HiveClient]]. + * Change time configurations needed to create a [[HiveClient]] into unified [[Long]] format. */ - private[hive] def hiveClientConfigurations(hadoopConf: Configuration): Map[String, String] = { + private[hive] def formatTimeVarsForHiveClient(hadoopConf: Configuration): Map[String, String] = { // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- // compatibility when users are trying to connecting to a Hive metastore of lower version, @@ -280,7 +280,7 @@ private[spark] object HiveUtils extends Logging { protected[hive] def newClientForMetadata( conf: SparkConf, hadoopConf: Configuration): HiveClient = { - val configurations = hiveClientConfigurations(hadoopConf) + val configurations = formatTimeVarsForHiveClient(hadoopConf) newClientForMetadata(conf, hadoopConf, configurations) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 426db6a4e1c12..c4e48c9360db7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Order} import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor} @@ -132,14 +133,24 @@ private[hive] class HiveClientImpl( // in hive jars, which will turn off isolation, if SessionSate.detachSession is // called to remove the current state after that, hive client created later will initialize // its own state by newState() - Option(SessionState.get).getOrElse(newState()) + val ret = SessionState.get + if (ret != null) { + // hive.metastore.warehouse.dir is determined in SharedState after the CliSessionState + // instance constructed, we need to follow that change here. + Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)).foreach { dir => + ret.getConf.setVar(ConfVars.METASTOREWAREHOUSE, dir) + } + ret + } else { + newState() + } } } // Log the default warehouse location. logInfo( s"Warehouse location for Hive client " + - s"(version ${version.fullVersion}) is ${conf.get("hive.metastore.warehouse.dir")}") + s"(version ${version.fullVersion}) is ${conf.getVar(ConfVars.METASTOREWAREHOUSE)}") private def newState(): SessionState = { val hiveConf = new HiveConf(classOf[SessionState]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index ed475a0261b0b..951ebfad4590e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -36,7 +36,7 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu hadoopConf.set("hive.metastore.schema.verification", "false") } HiveClientBuilder - .buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + .buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) } override def suiteName: String = s"${super.suiteName}($version)" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 1d9c8da996fea..edb9a9ffbaaf6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -127,7 +127,7 @@ class VersionsSuite extends SparkFunSuite with Logging { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } - client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client From 8319432af60b8e1dc00f08d794f7d80591e24d0c Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 19 Sep 2017 22:20:05 +0800 Subject: [PATCH 24/37] [SPARK-21917][CORE][YARN] Supporting adding http(s) resources in yarn mode ## What changes were proposed in this pull request? In the current Spark, when submitting application on YARN with remote resources `./bin/spark-shell --jars http://central.maven.org/maven2/com/github/swagger-akka-http/swagger-akka-http_2.11/0.10.1/swagger-akka-http_2.11-0.10.1.jar --master yarn-client -v`, Spark will be failed with: ``` java.io.IOException: No FileSystem for scheme: http at org.apache.hadoop.fs.FileSystem.getFileSystemClass(FileSystem.java:2586) at org.apache.hadoop.fs.FileSystem.createFileSystem(FileSystem.java:2593) at org.apache.hadoop.fs.FileSystem.access$200(FileSystem.java:91) at org.apache.hadoop.fs.FileSystem$Cache.getInternal(FileSystem.java:2632) at org.apache.hadoop.fs.FileSystem$Cache.get(FileSystem.java:2614) at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:370) at org.apache.hadoop.fs.Path.getFileSystem(Path.java:296) at org.apache.spark.deploy.yarn.Client.copyFileToRemote(Client.scala:354) at org.apache.spark.deploy.yarn.Client.org$apache$spark$deploy$yarn$Client$$distribute$1(Client.scala:478) at org.apache.spark.deploy.yarn.Client$$anonfun$prepareLocalResources$11$$anonfun$apply$6.apply(Client.scala:600) at org.apache.spark.deploy.yarn.Client$$anonfun$prepareLocalResources$11$$anonfun$apply$6.apply(Client.scala:599) at scala.collection.mutable.ArraySeq.foreach(ArraySeq.scala:74) at org.apache.spark.deploy.yarn.Client$$anonfun$prepareLocalResources$11.apply(Client.scala:599) at org.apache.spark.deploy.yarn.Client$$anonfun$prepareLocalResources$11.apply(Client.scala:598) at scala.collection.immutable.List.foreach(List.scala:381) at org.apache.spark.deploy.yarn.Client.prepareLocalResources(Client.scala:598) at org.apache.spark.deploy.yarn.Client.createContainerLaunchContext(Client.scala:848) at org.apache.spark.deploy.yarn.Client.submitApplication(Client.scala:173) ``` This is because `YARN#client` assumes resources are on the Hadoop compatible FS. To fix this problem, here propose to download remote http(s) resources to local and add this local downloaded resources to dist cache. This solution has one downside: remote resources are downloaded and uploaded again, but it only restricted to only remote http(s) resources, also the overhead is not so big. The advantages of this solution is that it is simple and the code changes restricts to only `SparkSubmit`. ## How was this patch tested? Unit test added, also verified in local cluster. Author: jerryshao Closes #19130 from jerryshao/SPARK-21917. --- .../apache/spark/deploy/DependencyUtils.scala | 9 ++- .../org/apache/spark/deploy/SparkSubmit.scala | 51 ++++++++++++++- .../spark/internal/config/package.scala | 10 +++ .../scala/org/apache/spark/util/Utils.scala | 3 + .../spark/deploy/SparkSubmitSuite.scala | 65 +++++++++++++++++++ docs/running-on-yarn.md | 9 +++ 6 files changed, 143 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index 51c3d9b158cbe..ecc82d7ac8001 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -94,7 +94,7 @@ private[deploy] object DependencyUtils { hadoopConf: Configuration, secMgr: SecurityManager): String = { require(fileList != null, "fileList cannot be null.") - fileList.split(",") + Utils.stringToSeq(fileList) .map(downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr)) .mkString(",") } @@ -121,6 +121,11 @@ private[deploy] object DependencyUtils { uri.getScheme match { case "file" | "local" => path + case "http" | "https" | "ftp" if Utils.isTesting => + // This is only used for SparkSubmitSuite unit test. Instead of downloading file remotely, + // return a dummy local path instead. + val file = new File(uri.getPath) + new File(targetDir, file.getName).toURI.toString case _ => val fname = new Path(uri).getName() val localFile = Utils.doFetchFile(uri.toString(), targetDir, fname, sparkConf, secMgr, @@ -131,7 +136,7 @@ private[deploy] object DependencyUtils { def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { require(paths != null, "paths cannot be null.") - paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path => + Utils.stringToSeq(paths).flatMap { path => val uri = Utils.resolveURI(path) uri.getScheme match { case "local" | "http" | "https" | "ftp" => Array(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index ea9c9bdaede76..286a4379d2040 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -25,11 +25,11 @@ import java.text.ParseException import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import scala.util.Properties +import scala.util.{Properties, Try} import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.ivy.Ivy @@ -48,6 +48,7 @@ import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util._ @@ -367,6 +368,52 @@ object SparkSubmit extends CommandLineUtils with Logging { }.orNull } + // When running in YARN, for some remote resources with scheme: + // 1. Hadoop FileSystem doesn't support them. + // 2. We explicitly bypass Hadoop FileSystem with "spark.yarn.dist.forceDownloadSchemes". + // We will download them to local disk prior to add to YARN's distributed cache. + // For yarn client mode, since we already download them with above code, so we only need to + // figure out the local path and replace the remote one. + if (clusterManager == YARN) { + sparkConf.setIfMissing(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + val secMgr = new SecurityManager(sparkConf) + val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) + + def shouldDownload(scheme: String): Boolean = { + forceDownloadSchemes.contains(scheme) || + Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure + } + + def downloadResource(resource: String): String = { + val uri = Utils.resolveURI(resource) + uri.getScheme match { + case "local" | "file" => resource + case e if shouldDownload(e) => + val file = new File(targetDir, new Path(uri).getName) + if (file.exists()) { + file.toURI.toString + } else { + downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + } + case _ => uri.toString + } + } + + args.primaryResource = Option(args.primaryResource).map { downloadResource }.orNull + args.files = Option(args.files).map { files => + Utils.stringToSeq(files).map(downloadResource).mkString(",") + }.orNull + args.pyFiles = Option(args.pyFiles).map { pyFiles => + Utils.stringToSeq(pyFiles).map(downloadResource).mkString(",") + }.orNull + args.jars = Option(args.jars).map { jars => + Utils.stringToSeq(jars).map(downloadResource).mkString(",") + }.orNull + args.archives = Option(args.archives).map { archives => + Utils.stringToSeq(archives).map(downloadResource).mkString(",") + }.orNull + } + // If we're running a python app, set the main class to our specific python runner if (args.isPython && deployMode == CLIENT) { if (args.primaryResource == PYSPARK_SHELL) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e0f696080e566..44a2815b81a73 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -400,4 +400,14 @@ package object config { .doc("Memory to request as a multiple of the size that used to unroll the block.") .doubleConf .createWithDefault(1.5) + + private[spark] val FORCE_DOWNLOAD_SCHEMES = + ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") + .doc("Comma-separated list of schemes for which files will be downloaded to the " + + "local disk prior to being added to YARN's distributed cache. For use in cases " + + "where the YARN service does not support schemes that are supported by Spark, like http, " + + "https and ftp.") + .stringConf + .toSequence + .createWithDefault(Nil) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bc08808a4d292..836e33c36d9a1 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2684,6 +2684,9 @@ private[spark] object Utils extends Logging { redact(redactionPattern, kvs.toArray) } + def stringToSeq(str: String): Seq[String] = { + str.split(",").map(_.trim()).filter(_.nonEmpty) + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4d69ce844d2ea..ad801bf8519a6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -897,6 +897,71 @@ class SparkSubmitSuite sysProps("spark.submit.pyFiles") should (startWith("/")) } + test("download remote resource if it is not supported by yarn service") { + testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + } + + test("avoid downloading remote resource if it is supported by yarn service") { + testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + } + + test("force download from blacklisted schemes") { + testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + } + + private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, + supportMockHttpFs: Boolean): Unit = { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + if (supportMockHttpFs) { + hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) + hadoopConf.set("fs.http.impl.disable.cache", "true") + } + + val tmpDir = Utils.createTempDir() + val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) + val tmpS3Jar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpS3JarPath = s"s3a://${new File(tmpS3Jar.toURI).getAbsolutePath}" + val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", + s"s3a://$mainResource" + ) ++ ( + if (isHttpSchemeBlacklisted) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") + } else { + Nil + } + ) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + + val jars = sysProps("spark.yarn.dist.jars").split(",").toSet + + // The URI of remote S3 resource should still be remote. + assert(jars.contains(tmpS3JarPath)) + + if (supportMockHttpFs) { + // If Http FS is supported by yarn service, the URI of remote http resource should + // still be remote. + assert(jars.contains(tmpHttpJarPath)) + } else { + // If Http FS is not supported by yarn service, or http scheme is configured to be force + // downloading, the URI of remote http resource should be changed to a local one. + val jarName = new File(tmpHttpJar.toURI).getName + val localHttpJar = jars.filter(_.contains(jarName)) + localHttpJar.size should be(1) + localHttpJar.head should startWith("file:") + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e4a74556d4f26..432639588cc2b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,6 +211,15 @@ To use a custom metrics.properties for the application master and executors, upd Comma-separated list of jars to be placed in the working directory of each executor. + + spark.yarn.dist.forceDownloadSchemes + (none) + + Comma-separated list of schemes for which files will be downloaded to the local disk prior to + being added to YARN's distributed cache. For use in cases where the YARN service does not + support schemes that are supported by Spark, like http, https and ftp. + + spark.executor.instances 2 From 2f962422a25020582c915e15819f91f43c0b9d68 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 19 Sep 2017 22:22:35 +0800 Subject: [PATCH 25/37] [MINOR][ML] Remove unnecessary default value setting for evaluators. ## What changes were proposed in this pull request? Remove unnecessary default value setting for all evaluators, as we have set them in corresponding _HasXXX_ base classes. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #19262 from yanboliang/evaluation. --- python/pyspark/ml/evaluation.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 7cb8d62f212cb..09cdf9b6a629a 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -146,8 +146,7 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", super(BinaryClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) - self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", - metricName="areaUnderROC") + self._setDefault(metricName="areaUnderROC") kwargs = self._input_kwargs self._set(**kwargs) @@ -224,8 +223,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(RegressionEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) - self._setDefault(predictionCol="prediction", labelCol="label", - metricName="rmse") + self._setDefault(metricName="rmse") kwargs = self._input_kwargs self._set(**kwargs) @@ -297,8 +295,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(MulticlassClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) - self._setDefault(predictionCol="prediction", labelCol="label", - metricName="f1") + self._setDefault(metricName="f1") kwargs = self._input_kwargs self._set(**kwargs) From d5aefa83ad8608fbea7c08e8d9164f8bee00863d Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 19 Sep 2017 09:27:05 -0700 Subject: [PATCH 26/37] [SPARK-21338][SQL] implement isCascadingTruncateTable() method in AggregatedDialect ## What changes were proposed in this pull request? org.apache.spark.sql.jdbc.JdbcDialect's method: def isCascadingTruncateTable(): Option[Boolean] = None is not overriden in org.apache.spark.sql.jdbc.AggregatedDialect class. Because of this issue, when you add more than one dialect Spark doesn't truncate table because isCascadingTruncateTable always returns default None for Aggregated Dialect. Will implement isCascadingTruncateTable in AggregatedDialect class in this PR. ## How was this patch tested? In JDBCSuite, inside test("Aggregated dialects"), will add one line to test AggregatedDialect.isCascadingTruncateTable Author: Huaxin Gao Closes #19256 from huaxingao/spark-21338. --- .../scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala | 4 ++++ .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 2 files changed, 6 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 467d8d62d1b7f..7432a1538ce97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -41,4 +41,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect override def getJDBCType(dt: DataType): Option[JdbcType] = { dialects.flatMap(_.getJDBCType(dt)).headOption } + + override def isCascadingTruncateTable(): Option[Boolean] = { + dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 689f4106824aa..fd12bb9e530b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -740,11 +740,13 @@ class JDBCSuite extends SparkFunSuite } else { None } + override def isCascadingTruncateTable(): Option[Boolean] = Some(true) }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) + assert(agg.isCascadingTruncateTable() === Some(true)) } test("DB2Dialect type mapping") { From ee13f3e3dc3faa5152cefa91c22f8aaa8e425bb4 Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Tue, 19 Sep 2017 14:19:13 -0700 Subject: [PATCH 27/37] [SPARK-21969][SQL] CommandUtils.updateTableStats should call refreshTable ## What changes were proposed in this pull request? Tables in the catalog cache are not invalidated once their statistics are updated. As a consequence, existing sessions will use the cached information even though it is not valid anymore. Consider and an example below. ``` // step 1 spark.range(100).write.saveAsTable("tab1") // step 2 spark.sql("analyze table tab1 compute statistics") // step 3 spark.sql("explain cost select distinct * from tab1").show(false) // step 4 spark.range(100).write.mode("append").saveAsTable("tab1") // step 5 spark.sql("explain cost select distinct * from tab1").show(false) ``` After step 3, the table will be present in the catalog relation cache. Step 4 will correctly update the metadata inside the catalog but will NOT invalidate the cache. By the way, ``spark.sql("analyze table tab1 compute statistics")`` between step 3 and step 4 would also solve the problem. ## How was this patch tested? Current and additional unit tests. Author: aokolnychyi Closes #19252 from aokolnychyi/spark-21969. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 + .../command/AnalyzeColumnCommand.scala | 3 - .../command/AnalyzeTableCommand.scala | 2 - .../spark/sql/StatisticsCollectionSuite.scala | 73 +++++++++++++++++++ .../sql/StatisticsCollectionTestBase.scala | 14 +++- 5 files changed, 87 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 0908d68d25649..9407b727bca4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -377,6 +377,8 @@ class SessionCatalog( requireDbExists(db) requireTableExists(tableIdentifier) externalCatalog.alterTableStats(db, table, newStats) + // Invalidate the table relation cache + refreshTable(identifier) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6588993ef9ad9..caf12ad745bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -56,9 +56,6 @@ case class AnalyzeColumnCommand( sessionState.catalog.alterTableStats(tableIdentWithDB, Some(statistics)) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 04715bd314d4d..58b53e8b1c551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -48,8 +48,6 @@ case class AnalyzeTableCommand( val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount) if (newStats.isDefined) { sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) } Seq.empty[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 9e459ed00c8d5..2fc92f4aff92e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -261,6 +261,10 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetched1.get.sizeInBytes == 0) assert(fetched1.get.colStats.size == 2) + // table lookup will make the table cached + spark.table(table) + assert(isTableInCatalogCache(table)) + // insert into command sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") if (autoUpdate) { @@ -270,9 +274,78 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } else { checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) } + + // check that tableRelationCache inside the catalog was invalidated after insert + assert(!isTableInCatalogCache(table)) + } + } + } + } + + test("invalidation of tableRelationCache after inserts") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).write.saveAsTable(table) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + val initialSizeInBytes = getTableFromCatalogCache(table).stats.sizeInBytes + spark.range(100).write.mode(SaveMode.Append).saveAsTable(table) + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 2 * initialSizeInBytes) + } + } + } + } + + test("invalidation of tableRelationCache after table truncation") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).write.saveAsTable(table) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + sql(s"TRUNCATE TABLE $table") + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 0) } } } } + test("invalidation of tableRelationCache after alter table add partition") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTempDir { dir => + withTable(table) { + val path = dir.getCanonicalPath + sql(s""" + |CREATE TABLE $table (col1 int, col2 int) + |USING PARQUET + |PARTITIONED BY (col2) + |LOCATION '${dir.toURI}'""".stripMargin) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 0) + spark.catalog.recoverPartitions(table) + val df = Seq((1, 2), (1, 2)).toDF("col2", "col1") + df.write.parquet(s"$path/col2=1") + sql(s"ALTER TABLE $table ADD PARTITION (col2=1) LOCATION '${dir.toURI}'") + spark.table(table) + val cachedTable = getTableFromCatalogCache(table) + val cachedTableSizeInBytes = cachedTable.stats.sizeInBytes + val defaultSizeInBytes = conf.defaultSizeInBytes + if (autoUpdate) { + assert(cachedTableSizeInBytes != defaultSizeInBytes && cachedTableSizeInBytes > 0) + } else { + assert(cachedTableSizeInBytes == defaultSizeInBytes) + } + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 5916cd76b8789..a2f63edd786bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -23,9 +23,9 @@ import java.sql.{Date, Timestamp} import scala.collection.mutable import scala.util.Random -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf @@ -85,6 +85,16 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) } + def getTableFromCatalogCache(tableName: String): LogicalPlan = { + val catalog = spark.sessionState.catalog + val qualifiedTableName = QualifiedTableName(catalog.getCurrentDatabase, tableName) + catalog.getCachedTable(qualifiedTableName) + } + + def isTableInCatalogCache(tableName: String): Boolean = { + getTableFromCatalogCache(tableName) != null + } + def getCatalogStatistics(tableName: String): CatalogStatistics = { getCatalogTable(tableName).stats.get } From 718bbc939037929ef5b8f4b4fe10aadfbab4408e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 20 Sep 2017 10:51:00 +0900 Subject: [PATCH 28/37] [SPARK-22067][SQL] ArrowWriter should use position when setting UTF8String ByteBuffer ## What changes were proposed in this pull request? The ArrowWriter StringWriter was setting Arrow data using a position of 0 instead of the actual position in the ByteBuffer. This was currently working because of a bug ARROW-1443, and has been fixed as of Arrow 0.7.0. Testing with this version revealed the error in ArrowConvertersSuite test string conversion. ## How was this patch tested? Existing tests, manually verified working with Arrow 0.7.0 Author: Bryan Cutler Closes #19284 from BryanCutler/arrow-ArrowWriter-StringWriter-position-SPARK-22067. --- .../org/apache/spark/sql/execution/arrow/ArrowWriter.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 11ba04d2ce9a7..0b740735ffe19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -234,8 +234,9 @@ private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extend override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val utf8 = input.getUTF8String(ordinal) + val utf8ByteBuffer = utf8.getByteBuffer // todo: for off-heap UTF8String, how to pass in to arrow without copy? - valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes()) + valueMutator.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) } } From c6ff59a230758b409fa9cc548b7d283eeb7ebe5d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 20 Sep 2017 13:41:29 +0800 Subject: [PATCH 29/37] [SPARK-18838][CORE] Add separate listener queues to LiveListenerBus. This change modifies the live listener bus so that all listeners are added to queues; each queue has its own thread to dispatch events, making it possible to separate slow listeners from other more performance-sensitive ones. The public API has not changed - all listeners added with the existing "addListener" method, which after this change mostly means all user-defined listeners, end up in a default queue. Internally, there's an API allowing listeners to be added to specific queues, and that API is used to separate the internal Spark listeners into 3 categories: application status listeners (e.g. UI), executor management (e.g. dynamic allocation), and the event log. The queueing logic, while abstracted away in a separate class, is kept as much as possible hidden away from consumers. Aside from choosing their queue, there's no code change needed to take advantage of queues. Test coverage relies on existing tests; a few tests had to be tweaked because they relied on `LiveListenerBus.postToAll` being synchronous, and the change makes that method asynchronous. Other tests were simplified not to use the asynchronous LiveListenerBus. Author: Marcelo Vanzin Closes #19211 from vanzin/SPARK-18838. --- .../spark/ExecutorAllocationManager.scala | 2 +- .../org/apache/spark/HeartbeatReceiver.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 13 +- .../spark/scheduler/AsyncEventQueue.scala | 196 +++++++++++++ .../spark/scheduler/LiveListenerBus.scala | 277 +++++++----------- .../scala/org/apache/spark/ui/SparkUI.scala | 24 +- .../ExecutorAllocationManagerSuite.scala | 128 ++++---- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../scheduler/EventLoggingListenerSuite.scala | 6 +- .../spark/scheduler/SparkListenerSuite.scala | 95 ++++-- .../spark/ui/storage/StorageTabSuite.scala | 4 +- .../streaming/StreamingQueryListenerBus.scala | 2 +- .../spark/sql/internal/SharedState.scala | 3 +- .../spark/streaming/StreamingContext.scala | 3 +- .../scheduler/StreamingListenerBus.scala | 2 +- .../streaming/StreamingContextSuite.scala | 4 +- 16 files changed, 473 insertions(+), 290 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 7a5fb9a802354..119b426a9af34 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -217,7 +217,7 @@ private[spark] class ExecutorAllocationManager( * the scheduling task. */ def start(): Unit = { - listenerBus.addListener(listener) + listenerBus.addToManagementQueue(listener) val scheduleTask = new Runnable() { override def run(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 5242ab6f55235..ff960b396dbf1 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -63,7 +63,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) this(sc, new SystemClock) } - sc.addSparkListener(this) + sc.listenerBus.addToManagementQueue(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 136f0af7b2c9e..1821bc87bf626 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -419,7 +419,7 @@ class SparkContext(config: SparkConf) extends Logging { // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. _jobProgressListener = new JobProgressListener(_conf) - listenerBus.addListener(jobProgressListener) + listenerBus.addToStatusQueue(jobProgressListener) // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) @@ -442,7 +442,7 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + Some(SparkUI.createLiveUI(this, _conf, _jobProgressListener, _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI @@ -522,7 +522,7 @@ class SparkContext(config: SparkConf) extends Logging { new EventLoggingListener(_applicationId, _applicationAttemptId, _eventLogDir.get, _conf, _hadoopConfiguration) logger.start() - listenerBus.addListener(logger) + listenerBus.addToEventLogQueue(logger) Some(logger) } else { None @@ -1563,7 +1563,7 @@ class SparkContext(config: SparkConf) extends Logging { */ @DeveloperApi def addSparkListener(listener: SparkListenerInterface) { - listenerBus.addListener(listener) + listenerBus.addToSharedQueue(listener) } /** @@ -1879,8 +1879,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def stop(): Unit = { if (LiveListenerBus.withinListenerThread.value) { - throw new SparkException( - s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") + throw new SparkException(s"Cannot stop SparkContext within listener bus thread.") } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. @@ -2378,7 +2377,7 @@ class SparkContext(config: SparkConf) extends Logging { " parameter from breaking Spark's ability to find a valid constructor.") } } - listenerBus.addListener(listener) + listenerBus.addToSharedQueue(listener) logInfo(s"Registered listener $className") } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala new file mode 100644 index 0000000000000..8605e1da161c7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +import com.codahale.metrics.{Gauge, Timer} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +/** + * An asynchronous queue for events. All events posted to this queue will be delivered to the child + * listeners in a separate thread. + * + * Delivery will only begin when the `start()` method is called. The `stop()` method should be + * called when no more events need to be delivered. + */ +private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) + extends SparkListenerBus + with Logging { + + import AsyncEventQueue._ + + // Cap the capacity of the queue so we get an explicit error (rather than an OOM exception) if + // it's perpetually being added to more quickly than it's being drained. + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]( + conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + + // Keep the event count separately, so that waitUntilEmpty() can be implemented properly; + // this allows that method to return only when the events in the queue have been fully + // processed (instead of just dequeued). + private val eventCount = new AtomicLong() + + /** A counter for dropped events. It will be reset every time we log it. */ + private val droppedEventsCounter = new AtomicLong(0L) + + /** When `droppedEventsCounter` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + + private val logDroppedEvent = new AtomicBoolean(false) + + private var sc: SparkContext = null + + private val started = new AtomicBoolean(false) + private val stopped = new AtomicBoolean(false) + + private val droppedEvents = metrics.metricRegistry.counter(s"queue.$name.numDroppedEvents") + private val processingTime = metrics.metricRegistry.timer(s"queue.$name.listenerProcessingTime") + + // Remove the queue size gauge first, in case it was created by a previous incarnation of + // this queue that was removed from the listener bus. + metrics.metricRegistry.remove(s"queue.$name.size") + metrics.metricRegistry.register(s"queue.$name.size", new Gauge[Int] { + override def getValue: Int = eventQueue.size() + }) + + private val dispatchThread = new Thread(s"spark-listener-group-$name") { + setDaemon(true) + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + dispatch() + } + } + + private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { + try { + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() + } + eventCount.decrementAndGet() + next = eventQueue.take() + } + eventCount.decrementAndGet() + } catch { + case ie: InterruptedException => + logInfo(s"Stopping listener queue $name.", ie) + } + } + + override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { + metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface])) + } + + /** + * Start an asynchronous thread to dispatch events to the underlying listeners. + * + * @param sc Used to stop the SparkContext in case the async dispatcher fails. + */ + private[scheduler] def start(sc: SparkContext): Unit = { + if (started.compareAndSet(false, true)) { + this.sc = sc + dispatchThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but new + * events will be dropped. + */ + private[scheduler] def stop(): Unit = { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + eventQueue.put(POISON_PILL) + eventCount.incrementAndGet() + } + dispatchThread.join() + } + + def post(event: SparkListenerEvent): Unit = { + if (stopped.get()) { + return + } + + eventCount.incrementAndGet() + if (eventQueue.offer(event)) { + return + } + + eventCount.decrementAndGet() + droppedEvents.inc() + droppedEventsCounter.incrementAndGet() + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. + logError(s"Dropping event from queue $name. " + + "This likely means one of the listeners is too slow and cannot keep up with " + + "the rate at which tasks are being started by the scheduler.") + } + logTrace(s"Dropping event $event") + + val droppedCount = droppedEventsCounter.get + if (droppedCount > 0) { + // Don't log too frequently + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + // There may be multiple threads trying to decrease droppedEventsCounter. + // Use "compareAndSet" to make sure only one thread can win. + // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and + // then that thread will update it. + if (droppedEventsCounter.compareAndSet(droppedCount, 0)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + val previous = new java.util.Date(prevLastReportTimestamp) + logWarning(s"Dropped $droppedEvents events from $name since $previous.") + } + } + } + } + + /** + * For testing only. Wait until there are no more events in the queue. + * + * @return true if the queue is empty. + */ + def waitUntilEmpty(deadline: Long): Boolean = { + while (eventCount.get() != 0) { + if (System.currentTimeMillis > deadline) { + return false + } + Thread.sleep(10) + } + true + } + +} + +private object AsyncEventQueue { + + val POISON_PILL = new SparkListenerEvent() { } + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 7d5e9809dd7b2..2f93c497c5771 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,20 +17,22 @@ package org.apache.spark.scheduler +import java.util.{List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.reflect.ClassTag import scala.util.DynamicVariable -import com.codahale.metrics.{Counter, Gauge, MetricRegistry, Timer} +import com.codahale.metrics.{Counter, MetricRegistry, Timer} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.util.Utils /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. @@ -39,20 +41,13 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { - - self => +private[spark] class LiveListenerBus(conf: SparkConf) { import LiveListenerBus._ private var sparkContext: SparkContext = _ - // Cap the capacity of the event queue so we get an explicit error (rather than - // an OOM exception) if it's perpetually being added to more quickly than it's being drained. - private val eventQueue = - new LinkedBlockingQueue[SparkListenerEvent](conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) - - private[spark] val metrics = new LiveListenerBusMetrics(conf, eventQueue) + private[spark] val metrics = new LiveListenerBusMetrics(conf) // Indicate if `start()` is called private val started = new AtomicBoolean(false) @@ -65,53 +60,76 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { /** When `droppedEventsCounter` was logged last time in milliseconds. */ @volatile private var lastReportTimestamp = 0L - // Indicate if we are processing some event - // Guarded by `self` - private var processingEvent = false - - private val logDroppedEvent = new AtomicBoolean(false) - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread(name) { - setDaemon(true) - override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - LiveListenerBus.withinListenerThread.withValue(true) { - val timer = metrics.eventProcessingTime - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - val timerContext = timer.time() - try { - postToAll(event) - } finally { - timerContext.stop() - } - } finally { - self.synchronized { - processingEvent = false - } - } + private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + + /** Add a listener to queue shared by all non-internal listeners. */ + def addToSharedQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, SHARED_QUEUE) + } + + /** Add a listener to the executor management queue. */ + def addToManagementQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, EXECUTOR_MANAGEMENT_QUEUE) + } + + /** Add a listener to the application status queue. */ + def addToStatusQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, APP_STATUS_QUEUE) + } + + /** Add a listener to the event log queue. */ + def addToEventLogQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, EVENT_LOG_QUEUE) + } + + /** + * Add a listener to a specific queue, creating a new queue if needed. Queues are independent + * of each other (each one uses a separate thread for delivering events), allowing slower + * listeners to be somewhat isolated from others. + */ + private def addToQueue(listener: SparkListenerInterface, queue: String): Unit = synchronized { + if (stopped.get()) { + throw new IllegalStateException("LiveListenerBus is stopped.") + } + + queues.asScala.find(_.name == queue) match { + case Some(queue) => + queue.addListener(listener) + + case None => + val newQueue = new AsyncEventQueue(queue, conf, metrics) + newQueue.addListener(listener) + if (started.get()) { + newQueue.start(sparkContext) } - } + queues.add(newQueue) } } - override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { - metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface])) + def removeListener(listener: SparkListenerInterface): Unit = synchronized { + // Remove listener from all queues it was added to, and stop queues that have become empty. + queues.asScala + .filter { queue => + queue.removeListener(listener) + queue.listeners.isEmpty() + } + .foreach { toRemove => + if (started.get() && !stopped.get()) { + toRemove.stop() + } + queues.remove(toRemove) + } + } + + /** Post an event to all queues. */ + def post(event: SparkListenerEvent): Unit = { + if (!stopped.get()) { + metrics.numEventsPosted.inc() + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } + } } /** @@ -123,46 +141,14 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { * * @param sc Used to stop the SparkContext in case the listener thread dies. */ - def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = { - if (started.compareAndSet(false, true)) { - sparkContext = sc - metricsSystem.registerSource(metrics) - listenerThread.start() - } else { - throw new IllegalStateException(s"$name already started!") - } - } - - def post(event: SparkListenerEvent): Unit = { - if (stopped.get) { - // Drop further events to make `listenerThread` exit ASAP - logDebug(s"$name has already stopped! Dropping event $event") - return - } - metrics.numEventsPosted.inc() - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - onDropEvent(event) + def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = synchronized { + if (!started.compareAndSet(false, true)) { + throw new IllegalStateException("LiveListenerBus already started.") } - val droppedEvents = droppedEventsCounter.get - if (droppedEvents > 0) { - // Don't log too frequently - if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { - // There may be multiple threads trying to decrease droppedEventsCounter. - // Use "compareAndSet" to make sure only one thread can win. - // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and - // then that thread will update it. - if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { - val prevLastReportTimestamp = lastReportTimestamp - lastReportTimestamp = System.currentTimeMillis() - logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + - new java.util.Date(prevLastReportTimestamp)) - } - } - } + this.sparkContext = sc + queues.asScala.foreach(_.start(sc)) + metricsSystem.registerSource(metrics) } /** @@ -173,80 +159,64 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { */ @throws(classOf[TimeoutException]) def waitUntilEmpty(timeoutMillis: Long): Unit = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - throw new TimeoutException( - s"The event queue is not empty after $timeoutMillis milliseconds") + val deadline = System.currentTimeMillis + timeoutMillis + queues.asScala.foreach { queue => + if (!queue.waitUntilEmpty(deadline)) { + throw new TimeoutException(s"The event queue is not empty after $timeoutMillis ms.") } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) } } - /** - * For testing only. Return whether the listener daemon thread is still alive. - * Exposed for testing. - */ - def listenerThreadIsAlive: Boolean = listenerThread.isAlive - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. - */ - private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } - /** * Stop the listener bus. It will wait until the queued events have been processed, but drop the * new events after stopping. */ def stop(): Unit = { if (!started.get()) { - throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + throw new IllegalStateException(s"Attempted to stop bus that has not yet started!") } - if (stopped.compareAndSet(false, true)) { - // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know - // `stop` is called. - eventLock.release() - listenerThread.join() - } else { - // Keep quiet + + if (!stopped.compareAndSet(false, true)) { + return } - } - /** - * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be - * notified with the dropped events. - * - * Note: `onDropEvent` can be called in any thread. - */ - def onDropEvent(event: SparkListenerEvent): Unit = { - metrics.numDroppedEvents.inc() - droppedEventsCounter.incrementAndGet() - if (logDroppedEvent.compareAndSet(false, true)) { - // Only log the following message once to avoid duplicated annoying logs. - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with " + - "the rate at which tasks are being started by the scheduler.") + synchronized { + queues.asScala.foreach(_.stop()) + queues.clear() } - logTrace(s"Dropping event $event") } + + // For testing only. + private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag](): Seq[T] = { + queues.asScala.flatMap { queue => queue.findListenersByClass[T]() } + } + + // For testing only. + private[spark] def listeners: JList[SparkListenerInterface] = { + queues.asScala.flatMap(_.listeners.asScala).asJava + } + + // For testing only. + private[scheduler] def activeQueues(): Set[String] = { + queues.asScala.map(_.name).toSet + } + } private[spark] object LiveListenerBus { // Allows for Context to check whether stop() call is made within listener thread val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) - /** The thread name of Spark listener bus */ - val name = "SparkListenerBus" + private[scheduler] val SHARED_QUEUE = "shared" + + private[scheduler] val APP_STATUS_QUEUE = "appStatus" + + private[scheduler] val EXECUTOR_MANAGEMENT_QUEUE = "executorManagement" + + private[scheduler] val EVENT_LOG_QUEUE = "eventLog" } -private[spark] class LiveListenerBusMetrics( - conf: SparkConf, - queue: LinkedBlockingQueue[_]) +private[spark] class LiveListenerBusMetrics(conf: SparkConf) extends Source with Logging { override val sourceName: String = "LiveListenerBus" @@ -260,25 +230,6 @@ private[spark] class LiveListenerBusMetrics( */ val numEventsPosted: Counter = metricRegistry.counter(MetricRegistry.name("numEventsPosted")) - /** - * The total number of events that were dropped without being delivered to listeners. - */ - val numDroppedEvents: Counter = metricRegistry.counter(MetricRegistry.name("numEventsDropped")) - - /** - * The amount of time taken to post a single event to all listeners. - */ - val eventProcessingTime: Timer = metricRegistry.timer(MetricRegistry.name("eventProcessingTime")) - - /** - * The number of messages waiting in the queue. - */ - val queueSize: Gauge[Int] = { - metricRegistry.register(MetricRegistry.name("queueSize"), new Gauge[Int]{ - override def getValue: Int = queue.size() - }) - } - // Guarded by synchronization. private val perListenerClassTimers = mutable.Map[String, Timer]() @@ -303,5 +254,5 @@ private[spark] class LiveListenerBusMetrics( } } } -} +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f3fcf2778d39e..6e94073238a56 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -162,13 +162,14 @@ private[spark] object SparkUI { def createLiveUI( sc: SparkContext, conf: SparkConf, - listenerBus: SparkListenerBus, jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, startTime: Long): SparkUI = { - create(Some(sc), conf, listenerBus, securityManager, appName, - jobProgressListener = Some(jobProgressListener), startTime = startTime) + create(Some(sc), conf, + sc.listenerBus.addToStatusQueue, + securityManager, appName, jobProgressListener = Some(jobProgressListener), + startTime = startTime) } def createHistoryUI( @@ -179,8 +180,7 @@ private[spark] object SparkUI { basePath: String, lastUpdateTime: Option[Long], startTime: Long): SparkUI = { - val sparkUI = create( - None, conf, listenerBus, securityManager, appName, basePath, + val sparkUI = create(None, conf, listenerBus.addListener, securityManager, appName, basePath, lastUpdateTime = lastUpdateTime, startTime = startTime) val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], @@ -202,7 +202,7 @@ private[spark] object SparkUI { private def create( sc: Option[SparkContext], conf: SparkConf, - listenerBus: SparkListenerBus, + addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, basePath: String = "", @@ -212,7 +212,7 @@ private[spark] object SparkUI { val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { val listener = new JobProgressListener(conf) - listenerBus.addListener(listener) + addListenerFn(listener) listener } @@ -222,11 +222,11 @@ private[spark] object SparkUI { val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) - listenerBus.addListener(environmentListener) - listenerBus.addListener(storageStatusListener) - listenerBus.addListener(executorsListener) - listenerBus.addListener(storageListener) - listenerBus.addListener(operationGraphListener) + addListenerFn(environmentListener) + addListenerFn(storageStatusListener) + addListenerFn(executorsListener) + addListenerFn(storageListener) + addListenerFn(operationGraphListener) new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, executorsListener, _jobProgressListener, storageListener, operationGraphListener, diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 7da4bae0ab7eb..a91e09b7cb69f 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -49,6 +49,11 @@ class ExecutorAllocationManagerSuite contexts.foreach(_.stop()) } + private def post(bus: LiveListenerBus, event: SparkListenerEvent): Unit = { + bus.post(event) + bus.waitUntilEmpty(1000) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("myDummyLocalExternalClusterManager") @@ -95,7 +100,7 @@ class ExecutorAllocationManagerSuite test("add executors") { sc = createSparkContext(1, 10, 1) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached assert(numExecutorsTarget(manager) === 1) @@ -140,7 +145,7 @@ class ExecutorAllocationManagerSuite test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 5))) // Verify that we're capped at number of tasks in the stage assert(numExecutorsTarget(manager) === 0) @@ -156,10 +161,10 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task doesn't affect the target - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 3))) + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) - sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 6) @@ -172,9 +177,9 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that re-running a task doesn't blow things up - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, 3))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 9) assert(numExecutorsToAdd(manager) === 2) @@ -183,7 +188,7 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task once we're at our limit doesn't blow things up - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) assert(addExecutors(manager) === 0) assert(numExecutorsTarget(manager) === 10) } @@ -193,13 +198,13 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get // Verify that we're capped at number of tasks including the speculative ones in the stage - sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) - sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 2))) assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) @@ -210,13 +215,13 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task doesn't affect the target - sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 0) assert(numExecutorsToAdd(manager) === 1) // Verify that running a speculative task doesn't affect the target - sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-2", true))) + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-2", true))) assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 0) assert(numExecutorsToAdd(manager) === 1) @@ -225,7 +230,7 @@ class ExecutorAllocationManagerSuite test("cancel pending executors when no longer needed") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, 5))) assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) @@ -236,15 +241,15 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 3) val task1Info = createTaskInfo(0, 0, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, task1Info)) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 2) val task2Info = createTaskInfo(1, 0, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, task2Info)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -352,21 +357,22 @@ class ExecutorAllocationManagerSuite sc = createSparkContext(5, 12, 5) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 8))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 8))) // Remove when numExecutorsTarget is the same as the current number of executors assert(addExecutors(manager) === 1) assert(addExecutors(manager) === 2) (1 to 8).map { i => createTaskInfo(i, i, s"$i") }.foreach { - info => sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, info)) } + info => post(sc.listenerBus, SparkListenerTaskStart(0, 0, info)) } assert(executorIds(manager).size === 8) assert(numExecutorsTarget(manager) === 8) assert(maxNumExecutorsNeeded(manager) == 8) assert(!removeExecutor(manager, "1")) // won't work since numExecutorsTarget == numExecutors // Remove executors when numExecutorsTarget is lower than current number of executors - (1 to 3).map { i => createTaskInfo(i, i, s"$i") }.foreach { - info => sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, Success, info, null)) } + (1 to 3).map { i => createTaskInfo(i, i, s"$i") }.foreach { info => + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, Success, info, null)) + } adjustRequestedExecutors(manager) assert(executorIds(manager).size === 8) assert(numExecutorsTarget(manager) === 5) @@ -378,7 +384,7 @@ class ExecutorAllocationManagerSuite onExecutorRemoved(manager, "3") // numExecutorsTarget is lower than minNumExecutors - sc.listenerBus.postToAll( + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, Success, createTaskInfo(4, 4, "4"), null)) assert(executorIds(manager).size === 5) assert(numExecutorsTarget(manager) === 5) @@ -390,7 +396,7 @@ class ExecutorAllocationManagerSuite test ("interleaving add and remove") { sc = createSparkContext(5, 12, 5) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Add a few executors assert(addExecutors(manager) === 1) @@ -569,7 +575,7 @@ class ExecutorAllocationManagerSuite val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) @@ -682,26 +688,26 @@ class ExecutorAllocationManagerSuite // Starting a stage should start the add timer val numTasks = 10 - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, numTasks))) assert(addTime(manager) !== NOT_SET) // Starting a subset of the tasks should not cancel the add timer val taskInfos = (0 to numTasks - 1).map { i => createTaskInfo(i, i, "executor-1") } - taskInfos.tail.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, info)) } + taskInfos.tail.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(0, 0, info)) } assert(addTime(manager) !== NOT_SET) // Starting all remaining tasks should cancel the add timer - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfos.head)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfos.head)) assert(addTime(manager) === NOT_SET) // Start two different stages // The add timer should be canceled only if all tasks in both stages start running - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, numTasks))) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, numTasks))) assert(addTime(manager) !== NOT_SET) - taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, info)) } + taskInfos.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(1, 0, info)) } assert(addTime(manager) !== NOT_SET) - taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, info)) } + taskInfos.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(2, 0, info)) } assert(addTime(manager) === NOT_SET) } @@ -715,22 +721,22 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).size === 5) // Starting a task cancel the remove timer for that executor - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2"))) assert(removeTimes(manager).size === 3) assert(!removeTimes(manager).contains("executor-1")) assert(!removeTimes(manager).contains("executor-2")) // Finishing all tasks running on an executor should start the remove timer for that executor - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(0, 0, "executor-1"), new TaskMetrics)) - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(2, 2, "executor-2"), new TaskMetrics)) assert(removeTimes(manager).size === 4) assert(!removeTimes(manager).contains("executor-1")) // executor-1 has not finished yet assert(removeTimes(manager).contains("executor-2")) - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(1, 1, "executor-1"), new TaskMetrics)) assert(removeTimes(manager).size === 5) assert(removeTimes(manager).contains("executor-1")) // executor-1 has now finished @@ -743,13 +749,13 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).isEmpty) // New executors have registered - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-1")) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-2", new ExecutorInfo("host2", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) @@ -757,14 +763,14 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).contains("executor-2")) // Existing executors have disconnected - sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-1", "")) + post(sc.listenerBus, SparkListenerExecutorRemoved(0L, "executor-1", "")) assert(executorIds(manager).size === 1) assert(!executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(!removeTimes(manager).contains("executor-1")) // Unknown executor has disconnected - sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-3", "")) + post(sc.listenerBus, SparkListenerExecutorRemoved(0L, "executor-3", "")) assert(executorIds(manager).size === 1) assert(removeTimes(manager).size === 1) } @@ -775,8 +781,8 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) @@ -788,15 +794,15 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-2", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) @@ -809,7 +815,7 @@ class ExecutorAllocationManagerSuite sc = createSparkContext(0, 100000, 0) val manager = sc.executorAllocationManager.get val stage1 = createStageInfo(0, 1000) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) + post(sc.listenerBus, SparkListenerStageSubmitted(stage1)) assert(addExecutors(manager) === 1) assert(addExecutors(manager) === 2) @@ -820,12 +826,12 @@ class ExecutorAllocationManagerSuite onExecutorAdded(manager, s"executor-$i") } assert(executorIds(manager).size === 15) - sc.listenerBus.postToAll(SparkListenerStageCompleted(stage1)) + post(sc.listenerBus, SparkListenerStageCompleted(stage1)) adjustRequestedExecutors(manager) assert(numExecutorsTarget(manager) === 0) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 1000))) addExecutors(manager) assert(numExecutorsTarget(manager) === 16) } @@ -842,7 +848,7 @@ class ExecutorAllocationManagerSuite // Verify whether the initial number of executors is kept with no pending tasks assert(numExecutorsTarget(manager) === 3) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 2))) clock.advance(100L) assert(maxNumExecutorsNeeded(manager) === 2) @@ -892,7 +898,7 @@ class ExecutorAllocationManagerSuite Seq.empty ) val stageInfo1 = createStageInfo(1, 5, localityPreferences1) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + post(sc.listenerBus, SparkListenerStageSubmitted(stageInfo1)) assert(localityAwareTasks(manager) === 3) assert(hostToLocalTaskCount(manager) === @@ -904,13 +910,13 @@ class ExecutorAllocationManagerSuite Seq.empty ) val stageInfo2 = createStageInfo(2, 3, localityPreferences2) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2)) + post(sc.listenerBus, SparkListenerStageSubmitted(stageInfo2)) assert(localityAwareTasks(manager) === 5) assert(hostToLocalTaskCount(manager) === Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) - sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1)) + post(sc.listenerBus, SparkListenerStageCompleted(stageInfo1)) assert(localityAwareTasks(manager) === 2) assert(hostToLocalTaskCount(manager) === Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) @@ -921,16 +927,16 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get assert(maxNumExecutorsNeeded(manager) === 0) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1))) assert(maxNumExecutorsNeeded(manager) === 1) val taskInfo = createTaskInfo(1, 1, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfo)) assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. val taskEndReason = ExceptionFailure(null, null, null, null, None) - sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } @@ -942,7 +948,7 @@ class ExecutorAllocationManagerSuite // Allocation manager is reset when adding executor requests are sent without reporting back // executor added. - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 10))) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 2) @@ -957,7 +963,7 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager) === Set.empty) // Allocation manager is reset when executors are added. - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 10))) addExecutors(manager) addExecutors(manager) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 703fc1b34c387..6222e576d1ce9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -751,7 +751,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Helper functions to extract commonly used code in Fetch Failure test cases private def setupStageAbortTest(sc: SparkContext) { - sc.listenerBus.addListener(new EndListener()) + sc.listenerBus.addToSharedQueue(new EndListener()) ended = false jobResult = null } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 0afd07b851cf9..6b42775ccb0f6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -164,9 +164,9 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem])) - listenerBus.addListener(eventLogger) - listenerBus.postToAll(applicationStart) - listenerBus.postToAll(applicationEnd) + listenerBus.addToEventLogQueue(eventLogger) + listenerBus.post(applicationStart) + listenerBus.post(applicationEnd) listenerBus.stop() eventLogger.stop() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 995df1dd52010..d061c7845f4a6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -34,6 +34,8 @@ import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { + import LiveListenerBus._ + /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 @@ -42,18 +44,28 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match private val mockSparkContext: SparkContext = Mockito.mock(classOf[SparkContext]) private val mockMetricsSystem: MetricsSystem = Mockito.mock(classOf[MetricsSystem]) + private def numDroppedEvents(bus: LiveListenerBus): Long = { + bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount + } + + private def queueSize(bus: LiveListenerBus): Int = { + bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() + .asInstanceOf[Int] + } + + private def eventProcessingTimeCount(bus: LiveListenerBus): Long = { + bus.metrics.metricRegistry.timer(s"queue.$SHARED_QUEUE.listenerProcessingTime").getCount() + } + test("don't call sc.stop in listener") { sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus(sc.conf) - bus.addListener(listener) - // Starting listener bus should flush all buffered events - bus.start(sc, sc.env.metricsSystem) - bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + sc.listenerBus.addToSharedQueue(listener) + sc.listenerBus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + sc.stop() - bus.stop() assert(listener.sparkExSeen) } @@ -61,13 +73,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addListener(counter) + bus.addToSharedQueue(counter) // Metrics are initially empty. assert(bus.metrics.numEventsPosted.getCount === 0) - assert(bus.metrics.numDroppedEvents.getCount === 0) - assert(bus.metrics.queueSize.getValue === 0) - assert(bus.metrics.eventProcessingTime.getCount === 0) + assert(numDroppedEvents(bus) === 0) + assert(queueSize(bus) === 0) + assert(eventProcessingTimeCount(bus) === 0) // Post five events: (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -75,7 +87,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. assert(bus.metrics.numEventsPosted.getCount === 5) - assert(bus.metrics.queueSize.getValue === 5) + assert(queueSize(bus) === 5) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -83,18 +95,14 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(bus.metrics.queueSize.getValue === 0) - assert(bus.metrics.eventProcessingTime.getCount === 5) + assert(queueSize(bus) === 0) + assert(eventProcessingTimeCount(bus) === 5) // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) - assert(bus.metrics.numEventsPosted.getCount === 5) - - // Make sure per-listener-class timers were created: - assert(bus.metrics.getTimerForListenerClass( - classOf[BasicJobCounter].asSubclass(classOf[SparkListenerInterface])).get.getCount == 5) + assert(eventProcessingTimeCount(bus) === 5) // Listener bus must not be started twice intercept[IllegalStateException] { @@ -135,7 +143,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val bus = new LiveListenerBus(new SparkConf()) val blockingListener = new BlockingListener - bus.addListener(blockingListener) + bus.addToSharedQueue(blockingListener) bus.start(mockSparkContext, mockMetricsSystem) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) @@ -168,7 +176,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listenerStarted = new Semaphore(0) val listenerWait = new Semaphore(0) - bus.addListener(new SparkListener { + bus.addToSharedQueue(new SparkListener { override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { listenerStarted.release() listenerWait.acquire() @@ -180,20 +188,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(bus.metrics.queueSize.getValue === 0) - assert(bus.metrics.numDroppedEvents.getCount === 0) + assert(queueSize(bus) === 0) + assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(bus.metrics.queueSize.getValue === 1) - assert(bus.metrics.numDroppedEvents.getCount === 0) + assert(queueSize(bus) === 1) + assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(bus.metrics.queueSize.getValue === 1) - assert(bus.metrics.numDroppedEvents.getCount === 1) - + assert(queueSize(bus) === 1) + assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: listenerWait.release(2) @@ -419,9 +426,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val bus = new LiveListenerBus(new SparkConf()) // Propagate events to bad listener first - bus.addListener(badListener) - bus.addListener(jobCounter1) - bus.addListener(jobCounter2) + bus.addToSharedQueue(badListener) + bus.addToSharedQueue(jobCounter1) + bus.addToSharedQueue(jobCounter2) bus.start(mockSparkContext, mockMetricsSystem) // Post events to all listeners, and wait until the queue is drained @@ -429,7 +436,6 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) // The exception should be caught, and the event should be propagated to other listeners - assert(bus.listenerThreadIsAlive) assert(jobCounter1.count === 5) assert(jobCounter2.count === 5) } @@ -449,6 +455,31 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } + test("add and remove listeners to/from LiveListenerBus queues") { + val bus = new LiveListenerBus(new SparkConf(false)) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val counter3 = new BasicJobCounter() + + bus.addToSharedQueue(counter1) + bus.addToStatusQueue(counter2) + bus.addToStatusQueue(counter3) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 3) + + bus.removeListener(counter1) + assert(bus.activeQueues() === Set(APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + + bus.removeListener(counter2) + assert(bus.activeQueues() === Set(APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 1) + + bus.removeListener(counter3) + assert(bus.activeQueues().isEmpty) + assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 1cb52593e7060..79f02f2e50bbd 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.storage._ * Test various functionality in the StorageListener that supports the StorageTab. */ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { - private var bus: LiveListenerBus = _ + private var bus: SparkListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ private val memAndDisk = StorageLevel.MEMORY_AND_DISK @@ -43,7 +43,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { before { val conf = new SparkConf() - bus = new LiveListenerBus(conf) + bus = new ReplayListenerBus() storageStatusListener = new StorageStatusListener(conf) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index 4207013c3f75d..07e39023c8366 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -40,7 +40,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) import StreamingQueryListener._ - sparkListenerBus.addListener(this) + sparkListenerBus.addToSharedQueue(this) /** * RunIds of active queries whose events are supposed to be forwarded by this ListenerBus diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 7202f1222d10f..ad9db308b2627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.CacheManager @@ -148,7 +149,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { if (SparkSession.sqlListener.get() == null) { val listener = new SQLListener(sc.conf) if (SparkSession.sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) + sc.listenerBus.addToStatusQueue(listener) sc.ui.foreach(new SQLTab(listener, _)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f3b4ff2d1d80c..8c7418ec7ac10 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -659,8 +659,7 @@ class StreamingContext private[streaming] ( def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null if (LiveListenerBus.withinListenerThread.value) { - throw new SparkException( - s"Cannot stop StreamingContext within listener thread of ${LiveListenerBus.name}") + throw new SparkException(s"Cannot stop StreamingContext within listener bus thread.") } synchronized { // The state should always be Stopped after calling `stop()`, even if we haven't started yet diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 5fb0bd057d0f1..6a70bf7406b3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -76,7 +76,7 @@ private[streaming] class StreamingListenerBus(sparkListenerBus: LiveListenerBus) * forward them to StreamingListeners. */ def start(): Unit = { - sparkListenerBus.addListener(this) // for getting callbacks on spark events + sparkListenerBus.addToStatusQueue(this) } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 96ab5a2080b8e..5810e73f4098b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -575,8 +575,6 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL test("getActive and getActiveOrCreate") { require(StreamingContext.getActive().isEmpty, "context exists from before") - sc = new SparkContext(conf) - var newContextCreated = false def creatingFunc(): StreamingContext = { @@ -603,6 +601,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // getActiveOrCreate should create new context and getActive should return it only // after starting the context testGetActiveOrCreate { + sc = new SparkContext(conf) ssc = StreamingContext.getActiveOrCreate(creatingFunc _) assert(ssc != null, "no context created") assert(newContextCreated === true, "new context not created") @@ -622,6 +621,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // getActiveOrCreate and getActive should return independently created context after activating testGetActiveOrCreate { + sc = new SparkContext(conf) ssc = creatingFunc() // Create assert(StreamingContext.getActive().isEmpty, "new initialized context returned before starting") From 280ff523f4079dd9541efc95e6efcb69f9374d22 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 20 Sep 2017 00:01:21 -0700 Subject: [PATCH 30/37] [SPARK-21977] SinglePartition optimizations break certain Streaming Stateful Aggregation requirements ## What changes were proposed in this pull request? This is a bit hard to explain as there are several issues here, I'll try my best. Here are the requirements: 1. A StructuredStreaming Source that can generate empty RDDs with 0 partitions 2. A StructuredStreaming query that uses the above source, performs a stateful aggregation (mapGroupsWithState, groupBy.count, ...), and coalesce's by 1 The crux of the problem is that when a dataset has a `coalesce(1)` call, it receives a `SinglePartition` partitioning scheme. This scheme satisfies most required distributions used for aggregations such as HashAggregateExec. This causes a world of problems: Symptom 1. If the input RDD has 0 partitions, the whole lineage will receive 0 partitions, nothing will be executed, the state store will not create any delta files. When this happens, the next trigger fails, because the StateStore fails to load the delta file for the previous trigger Symptom 2. Let's say that there was data. Then in this case, if you stop your stream, and change `coalesce(1)` with `coalesce(2)`, then restart your stream, your stream will fail, because `spark.sql.shuffle.partitions - 1` number of StateStores will fail to find its delta files. To fix the issues above, we must check that the partitioning of the child of a `StatefulOperator` satisfies: If the grouping expressions are empty: a) AllTuple distribution b) Single physical partition If the grouping expressions are non empty: a) Clustered distribution b) spark.sql.shuffle.partition # of partitions whether or not `coalesce(1)` exists in the plan, and whether or not the input RDD for the trigger has any data. Once you fix the above problem by adding an Exchange to the plan, you come across the following bug: If you call `coalesce(1).groupBy().count()` on a Streaming DataFrame, and if you have a trigger with no data, `StateStoreRestoreExec` doesn't return the prior state. However, for this specific aggregation, `HashAggregateExec` after the restore returns a (0, 0) row, since we're performing a count, and there is no data. Then this data gets stored in `StateStoreSaveExec` causing the previous counts to be overwritten and lost. ## How was this patch tested? Regression tests Author: Burak Yavuz Closes #19196 from brkyvz/sa-0. --- .../streaming/IncrementalExecution.scala | 34 ++- .../execution/streaming/StreamExecution.scala | 1 + .../streaming/statefulOperators.scala | 37 +++- .../EnsureStatefulOpPartitioningSuite.scala | 132 ++++++++++++ .../spark/sql/streaming/StreamTest.scala | 16 +- .../streaming/StreamingAggregationSuite.scala | 196 +++++++++++++++++- 6 files changed, 395 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 19d95980d57d3..027222e1119c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -21,11 +21,13 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SparkSession, Strategy} +import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.streaming.OutputMode /** @@ -89,7 +91,7 @@ class IncrementalExecution( override def apply(plan: SparkPlan): SparkPlan = plan transform { case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, - StateStoreRestoreExec(keys2, None, child))) => + StateStoreRestoreExec(_, None, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, @@ -117,8 +119,34 @@ class IncrementalExecution( } } - override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations + override def preparations: Seq[Rule[SparkPlan]] = + Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } } + +object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { + // Needs to be transformUp to avoid extra shuffles + override def apply(plan: SparkPlan): SparkPlan = plan transformUp { + case so: StatefulOperator => + val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions + val distributions = so.requiredChildDistribution + val children = so.children.zip(distributions).map { case (child, reqDistribution) => + val expectedPartitioning = reqDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions) + case _ => throw new AnalysisException("Unexpected distribution expected for " + + s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " + + s"$reqDistribution.") + } + if (child.outputPartitioning.guarantees(expectedPartitioning) && + child.execute().getNumPartitions == expectedPartitioning.numPartitions) { + child + } else { + ShuffleExchange(expectedPartitioning, child) + } + } + so.withNewChildren(children) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index b27a59b8a34fb..18385f5fc1975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -829,6 +829,7 @@ class StreamExecution( if (streamDeathCause != null) { throw streamDeathCause } + if (!isActive) return awaitBatchLock.lock() try { noNewData = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index e46356392c51b..d6566b8e6b54f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -200,11 +200,20 @@ case class StateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) - numOutputRows += 1 - row +: Option(savedState).toSeq + val hasInput = iter.hasNext + if (!hasInput && keyExpressions.isEmpty) { + // If our `keyExpressions` are empty, we're getting a global aggregation. In that case + // the `HashAggregateExec` will output a 0 value for the partial merge. We need to + // restore the value, so that we don't overwrite our state with a 0 value, but rather + // merge the 0 with existing state. + store.iterator().map(_.value) + } else { + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + numOutputRows += 1 + row +: Option(savedState).toSeq + } } } } @@ -212,6 +221,14 @@ case class StateStoreRestoreExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions) :: Nil + } + } } /** @@ -351,6 +368,14 @@ case class StateStoreSaveExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions) :: Nil + } + } } /** Physical operator for executing streaming Deduplicate. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala new file mode 100644 index 0000000000000..66c0263e872b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.util.UUID + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} +import org.apache.spark.sql.test.SharedSQLContext + +class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { + + import testImplicits._ + super.beforeAll() + + private val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") + + testEnsureStatefulOpPartitioning( + "ClusteredDistribution generates Exchange with HashPartitioning", + baseDf.queryExecution.sparkPlan, + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning", + baseDf.coalesce(1).queryExecution.sparkPlan, + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "AllTuples generates Exchange with SinglePartition", + baseDf.queryExecution.sparkPlan, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "AllTuples with coalesce(1) doesn't need Exchange", + baseDf.coalesce(1).queryExecution.sparkPlan, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, + expectShuffle = false) + + /** + * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan + * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to + * ensure the expected partitioning. + */ + private def testEnsureStatefulOpPartitioning( + testName: String, + inputPlan: SparkPlan, + requiredDistribution: Seq[Attribute] => Distribution, + expectedPartitioning: Seq[Attribute] => Partitioning, + expectShuffle: Boolean): Unit = { + test(testName) { + val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) + val executed = executePlan(operator, OutputMode.Complete()) + if (expectShuffle) { + val exchange = executed.children.find(_.isInstanceOf[Exchange]) + if (exchange.isEmpty) { + fail(s"Was expecting an exchange but didn't get one in:\n$executed") + } + assert(exchange.get === + ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan), + s"Exchange didn't have expected properties:\n${exchange.get}") + } else { + assert(!executed.children.exists(_.isInstanceOf[Exchange]), + s"Unexpected exchange found in:\n$executed") + } + } + } + + /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */ + private def executePlan( + p: SparkPlan, + outputMode: OutputMode = OutputMode.Append()): SparkPlan = { + val execution = new IncrementalExecution( + spark, + null, + OutputMode.Complete(), + "chk", + UUID.randomUUID(), + 0L, + OffsetSeqMetadata()) { + override lazy val sparkPlan: SparkPlan = p transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + execution.executedPlan + } +} + +/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ +case class TestStatefulOperator( + child: SparkPlan, + requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { + override def output: Seq[Attribute] = child.output + override def doExecute(): RDD[InternalRow] = child.execute() + override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil + override def stateInfo: Option[StatefulOperatorStateInfo] = None +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4f8764060d922..70b39b934071a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -167,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class StartStream( trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, - additionalConfs: Map[String, String] = Map.empty) + additionalConfs: Map[String, String] = Map.empty, + checkpointLocation: String = null) extends StreamAction /** Advance the trigger clock's time manually. */ @@ -349,13 +350,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath var manualClockExpectedTime = -1L + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath try { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock, additionalConfs) => + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -363,6 +365,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) additionalConfs.foreach(pair => { val value = @@ -479,7 +482,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be verify(currentStream != null || lastStream != null, "cannot assert when no stream has been started") val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } case a: Assert => val streamToAssert = Option(currentStream).getOrElse(lastStream) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index e0979ce296c3a..995cea3b37d4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -22,20 +22,24 @@ import java.util.{Locale, TimeZone} import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame} +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ -import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} -object FailureSinglton { +object FailureSingleton { var firstTime = true } @@ -226,12 +230,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest testQuietly("midbatch failure") { val inputData = MemoryStream[Int] - FailureSinglton.firstTime = true + FailureSingleton.firstTime = true val aggregated = inputData.toDS() .map { i => - if (i == 4 && FailureSinglton.firstTime) { - FailureSinglton.firstTime = false + if (i == 4 && FailureSingleton.firstTime) { + FailureSingleton.firstTime = false sys.error("injected failure") } @@ -381,4 +385,180 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(streamInput, 0, 1, 2, 3), CheckLastBatch((0, 0, 2), (1, 1, 3))) } + + /** + * This method verifies certain properties in the SparkPlan of a streaming aggregation. + * First of all, it checks that the child of a `StateStoreRestoreExec` creates the desired + * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which already + * provides the expected data distribution. + * + * The second thing it checks that the child provides the expected number of partitions. + * + * The third thing it checks that we don't add an unnecessary shuffle in-between + * `StateStoreRestoreExec` and `StateStoreSaveExec`. + */ + private def checkAggregationChain( + se: StreamExecution, + expectShuffling: Boolean, + expectedPartition: Int): Boolean = { + val executedPlan = se.lastExecution.executedPlan + val restore = executedPlan + .collect { case ss: StateStoreRestoreExec => ss } + .head + restore.child match { + case node: UnaryExecNode => + assert(node.outputPartitioning.numPartitions === expectedPartition, + "Didn't get the expected number of partitions.") + if (expectShuffling) { + assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}") + } else { + assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle") + } + + case _ => fail("Expected no shuffling") + } + var reachedRestore = false + // Check that there should be no exchanges after `StateStoreRestoreExec` + executedPlan.foreachUp { p => + if (reachedRestore) { + assert(!p.isInstanceOf[Exchange], "There should be no further exchanges") + } else { + reachedRestore = p.isInstanceOf[StateStoreRestoreExec] + } + } + true + } + + test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + val inputSource = new BlockRDDBackedSource(spark) + MockSourceProvider.withMockSources(inputSource) { + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. Therefore in our SparkPlan, we + // don't have any shuffling. However, `coalesce(1)` only guarantees that the RDD has at most 1 + // partition. Which means that if we have an input RDD with 0 partitions, nothing gets + // executed. Therefore the StateStore's don't save any delta files for a given trigger. This + // then leads to `FileNotFoundException`s in the subsequent batch. + // This isn't the only problem though. Once we introduce a shuffle before + // `StateStoreRestoreExec`, the input to the operator is an empty iterator. When performing + // `groupBy().agg(...)`, `HashAggregateExec` returns a `0` value for all aggregations. If + // we fail to restore the previous state in `StateStoreRestoreExec`, we save the 0 value in + // `StateStoreSaveExec` losing all previous state. + val aggregated: Dataset[Long] = + spark.readStream.format((new MockSourceProvider).getClass.getCanonicalName) + .load().coalesce(1).groupBy().count().as[Long] + + testStream(aggregated, Complete())( + AddBlockData(inputSource, Seq(1)), + CheckLastBatch(1), + AssertOnQuery("Verify no shuffling") { se => + checkAggregationChain(se, expectShuffling = false, 1) + }, + AddBlockData(inputSource), // create an empty trigger + CheckLastBatch(1), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain(se, expectShuffling = true, 1) + }, + AddBlockData(inputSource, Seq(2, 3)), + CheckLastBatch(3), + AddBlockData(inputSource), + CheckLastBatch(3), + StopStream + ) + } + } + + test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + + "has non-empty grouping keys") { + val inputSource = new BlockRDDBackedSource(spark) + MockSourceProvider.withMockSources(inputSource) { + withTempDir { tempDir => + + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. However, when we have + // non-empty grouping keys, in streaming, we must repartition to + // `spark.sql.shuffle.partitions`, otherwise only a single StateStore is used to process + // all keys. This may be fine, however, if the user removes the coalesce(1) or changes to + // a `coalesce(2)` for example, then the default behavior is to shuffle to + // `spark.sql.shuffle.partitions` many StateStores. When this happens, all StateStore's + // except 1 will be missing their previous delta files, which causes the stream to fail + // with FileNotFoundException. + def createDf(partitions: Int): Dataset[(Long, Long)] = { + spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)] + } + + testStream(createDf(1), Complete())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddBlockData(inputSource, Seq(1)), + CheckLastBatch((0L, 1L)), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain( + se, + expectShuffling = true, + spark.sessionState.conf.numShufflePartitions) + }, + StopStream + ) + + testStream(createDf(2), Complete())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + Execute(se => se.processAllAvailable()), + AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)), + CheckLastBatch((0L, 4L)), + AssertOnQuery("Verify no exchange added") { se => + checkAggregationChain( + se, + expectShuffling = false, + spark.sessionState.conf.numShufflePartitions) + }, + AddBlockData(inputSource), + CheckLastBatch((0L, 4L)), + StopStream + ) + } + } + } + + /** Add blocks of data to the `BlockRDDBackedSource`. */ + case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + source.addBlocks(data: _*) + (source, LongOffset(source.counter)) + } + } + + /** + * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will. + */ + class BlockRDDBackedSource(spark: SparkSession) extends Source { + var counter = 0L + private val blockMgr = SparkEnv.get.blockManager + private var blocks: Seq[BlockId] = Seq.empty + + def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized { + dataBlocks.foreach { data => + val id = TestBlockId(counter.toString) + blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) + blocks ++= id :: Nil + counter += 1 + } + counter += 1 + } + + override def getOffset: Option[Offset] = synchronized { + if (counter == 0) None else Some(LongOffset(counter)) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { + val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray) + .map(i => InternalRow(i)) // we don't really care about the values in this test + blocks = Seq.empty + spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() + } + override def schema: StructType = MockSourceProvider.fakeSchema + override def stop(): Unit = { + blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) + } + } } From 964aef5879fec64797d979daa953118e317e6e47 Mon Sep 17 00:00:00 2001 From: foxish Date: Thu, 14 Sep 2017 20:10:24 -0700 Subject: [PATCH 31/37] Spark on Kubernetes - basic scheduler backend --- resource-managers/kubernetes/core/pom.xml | 137 +++++ .../kubernetes/ConfigurationUtils.scala | 82 +++ .../kubernetes/OptionRequirements.scala | 40 ++ .../SparkKubernetesClientFactory.scala | 103 ++++ .../spark/deploy/kubernetes/config.scala | 551 ++++++++++++++++++ .../spark/deploy/kubernetes/constants.scala | 105 ++++ .../kubernetes/ExecutorPodFactory.scala | 230 ++++++++ .../kubernetes/KubernetesClusterManager.scala | 69 +++ .../KubernetesClusterSchedulerBackend.scala | 445 ++++++++++++++ .../core/src/test/resources/log4j.properties | 31 + .../kubernetes/ExecutorPodFactorySuite.scala | 136 +++++ ...bernetesClusterSchedulerBackendSuite.scala | 377 ++++++++++++ 12 files changed, 2306 insertions(+) create mode 100644 resource-managers/kubernetes/core/pom.xml create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/OptionRequirements.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/SparkKubernetesClientFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala create mode 100644 resource-managers/kubernetes/core/src/test/resources/log4j.properties create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactorySuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml new file mode 100644 index 0000000000000..a4b18c527c969 --- /dev/null +++ b/resource-managers/kubernetes/core/pom.xml @@ -0,0 +1,137 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes_2.11 + jar + Spark Project Kubernetes + + kubernetes + 2.2.13 + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + + io.fabric8 + kubernetes-client + ${kubernetes.client.version} + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + + + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${fasterxml.jackson.version} + + + org.glassfish.jersey.containers + jersey-container-servlet + + + org.glassfish.jersey.media + jersey-media-multipart + + + com.squareup.retrofit2 + retrofit + + + com.squareup.retrofit2 + converter-jackson + + + com.squareup.retrofit2 + converter-scalars + + + + com.fasterxml.jackson.jaxrs + jackson-jaxrs-json-provider + + + javax.ws.rs + javax.ws.rs-api + + + + com.google.guava + guava + + + + + org.bouncycastle + bcpkix-jdk15on + + + org.bouncycastle + bcprov-jdk15on + + + org.mockito + mockito-core + test + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala new file mode 100644 index 0000000000000..1a008c236d00f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.kubernetes + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.OptionalConfigEntry + +object ConfigurationUtils extends Logging { + def parseKeyValuePairs( + maybeKeyValues: Option[String], + configKey: String, + keyValueType: String): Map[String, String] = { + + maybeKeyValues.map(keyValues => { + keyValues.split(",").map(_.trim).filterNot(_.isEmpty).map(keyValue => { + keyValue.split("=", 2).toSeq match { + case Seq(k, v) => + (k, v) + case _ => + throw new SparkException(s"Custom $keyValueType set by $configKey must be a" + + s" comma-separated list of key-value pairs, with format =." + + s" Got value: $keyValue. All values: $keyValues") + } + }).toMap + }).getOrElse(Map.empty[String, String]) + } + + def combinePrefixedKeyValuePairsWithDeprecatedConf( + sparkConf: SparkConf, + prefix: String, + deprecatedConf: OptionalConfigEntry[String], + configType: String): Map[String, String] = { + val deprecatedKeyValuePairsString = sparkConf.get(deprecatedConf) + deprecatedKeyValuePairsString.foreach { _ => + logWarning(s"Configuration with key ${deprecatedConf.key} is deprecated. Use" + + s" configurations with prefix $prefix instead.") + } + val fromDeprecated = parseKeyValuePairs( + deprecatedKeyValuePairsString, + deprecatedConf.key, + configType) + val fromPrefix = sparkConf.getAllWithPrefix(prefix) + val combined = fromDeprecated.toSeq ++ fromPrefix + combined.groupBy(_._1).foreach { + case (key, values) => + require(values.size == 1, + s"Cannot have multiple values for a given $configType key, got key $key with" + + s" values $values") + } + combined.toMap + } + + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String, + configType: String): Map[String, String] = { + val fromPrefix = sparkConf.getAllWithPrefix(prefix) + fromPrefix.groupBy(_._1).foreach { + case (key, values) => + require(values.size == 1, + s"Cannot have multiple values for a given $configType key, got key $key with" + + s" values $values") + } + fromPrefix.toMap + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/OptionRequirements.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/OptionRequirements.scala new file mode 100644 index 0000000000000..eda43de0a9a5b --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/OptionRequirements.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.kubernetes + +private[spark] object OptionRequirements { + + def requireBothOrNeitherDefined( + opt1: Option[_], + opt2: Option[_], + errMessageWhenFirstIsMissing: String, + errMessageWhenSecondIsMissing: String): Unit = { + requireSecondIfFirstIsDefined(opt1, opt2, errMessageWhenSecondIsMissing) + requireSecondIfFirstIsDefined(opt2, opt1, errMessageWhenFirstIsMissing) + } + + def requireSecondIfFirstIsDefined( + opt1: Option[_], opt2: Option[_], errMessageWhenSecondIsMissing: String): Unit = { + opt1.foreach { _ => + require(opt2.isDefined, errMessageWhenSecondIsMissing) + } + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/SparkKubernetesClientFactory.scala new file mode 100644 index 0000000000000..d2729a2db2fa0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/SparkKubernetesClientFactory.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.kubernetes + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files +import io.fabric8.kubernetes.client.{Config, ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.utils.HttpClientUtils +import okhttp3.Dispatcher + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.util.ThreadUtils + +/** + * Spark-opinionated builder for Kubernetes clients. It uses a prefix plus common suffixes to + * parse configuration keys, similar to the manner in which Spark's SecurityManager parses SSL + * options for different components. + */ +private[spark] object SparkKubernetesClientFactory { + + def createKubernetesClient( + master: String, + namespace: Option[String], + kubernetesAuthConfPrefix: String, + sparkConf: SparkConf, + maybeServiceAccountToken: Option[File], + maybeServiceAccountCaCert: Option[File]): KubernetesClient = { + val oauthTokenFileConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX" + val oauthTokenConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX" + val oauthTokenFile = sparkConf.getOption(oauthTokenFileConf) + .map(new File(_)) + .orElse(maybeServiceAccountToken) + val oauthTokenValue = sparkConf.getOption(oauthTokenConf) + OptionRequirements.requireNandDefined( + oauthTokenFile, + oauthTokenValue, + s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a" + + s" value $oauthTokenConf.") + + val caCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CA_CERT_FILE_CONF_SUFFIX") + .orElse(maybeServiceAccountCaCert.map(_.getAbsolutePath)) + val clientKeyFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX") + val clientCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX") + val dispatcher = new Dispatcher( + ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher")) + val config = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(master) + .withWebsocketPingInterval(0) + .withOption(oauthTokenValue) { + (token, configBuilder) => configBuilder.withOauthToken(token) + }.withOption(oauthTokenFile) { + (file, configBuilder) => + configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8)) + }.withOption(caCertFile) { + (file, configBuilder) => configBuilder.withCaCertFile(file) + }.withOption(clientKeyFile) { + (file, configBuilder) => configBuilder.withClientKeyFile(file) + }.withOption(clientCertFile) { + (file, configBuilder) => configBuilder.withClientCertFile(file) + }.withOption(namespace) { + (ns, configBuilder) => configBuilder.withNamespace(ns) + }.build() + val baseHttpClient = HttpClientUtils.createHttpClient(config) + val httpClientWithCustomDispatcher = baseHttpClient.newBuilder() + .dispatcher(dispatcher) + .build() + new DefaultKubernetesClient(httpClientWithCustomDispatcher, config) + } + + private implicit class OptionConfigurableConfigBuilder(configBuilder: ConfigBuilder) { + + def withOption[T] + (option: Option[T]) + (configurator: ((T, ConfigBuilder) => ConfigBuilder)): OptionConfigurableConfigBuilder = { + new OptionConfigurableConfigBuilder(option.map { opt => + configurator(opt, configBuilder) + }.getOrElse(configBuilder)) + } + + def build(): Config = configBuilder.build() + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala new file mode 100644 index 0000000000000..9dfd13e1817f8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala @@ -0,0 +1,551 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.kubernetes + +import java.util.concurrent.TimeUnit + +import org.apache.spark.{SPARK_VERSION => sparkVersion} +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.network.util.ByteUnit + +package object config extends Logging { + + private[spark] val KUBERNETES_NAMESPACE = + ConfigBuilder("spark.kubernetes.namespace") + .doc("The namespace that will be used for running the driver and executor pods. When using" + + " spark-submit in cluster mode, this can also be passed to spark-submit via the" + + " --kubernetes-namespace command line argument.") + .stringConf + .createWithDefault("default") + + private[spark] val DRIVER_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.driver.docker.image") + .doc("Docker image to use for the driver. Specify this using the standard Docker tag format.") + .stringConf + .createWithDefault(s"spark-driver:$sparkVersion") + + private[spark] val EXECUTOR_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.executor.docker.image") + .doc("Docker image to use for the executors. Specify this using the standard Docker tag" + + " format.") + .stringConf + .createWithDefault(s"spark-executor:$sparkVersion") + + private[spark] val DOCKER_IMAGE_PULL_POLICY = + ConfigBuilder("spark.kubernetes.docker.image.pullPolicy") + .doc("Docker image pull policy when pulling any docker image in Kubernetes integration") + .stringConf + .createWithDefault("IfNotPresent") + + private[spark] val APISERVER_AUTH_SUBMISSION_CONF_PREFIX = + "spark.kubernetes.authenticate.submission" + private[spark] val APISERVER_AUTH_DRIVER_CONF_PREFIX = + "spark.kubernetes.authenticate.driver" + private[spark] val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX = + "spark.kubernetes.authenticate.driver.mounted" + private[spark] val APISERVER_AUTH_RESOURCE_STAGING_SERVER_CONF_PREFIX = + "spark.kubernetes.authenticate.resourceStagingServer" + private[spark] val APISERVER_AUTH_SHUFFLE_SERVICE_CONF_PREFIX = + "spark.kubernetes.authenticate.shuffleService" + private[spark] val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" + private[spark] val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" + private[spark] val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" + private[spark] val CLIENT_CERT_FILE_CONF_SUFFIX = "clientCertFile" + private[spark] val CA_CERT_FILE_CONF_SUFFIX = "caCertFile" + + private[spark] val RESOURCE_STAGING_SERVER_USE_SERVICE_ACCOUNT_CREDENTIALS = + ConfigBuilder( + s"$APISERVER_AUTH_RESOURCE_STAGING_SERVER_CONF_PREFIX.useServiceAccountCredentials") + .doc("Use a service account token and CA certificate in the resource staging server to" + + " watch the API server's objects.") + .booleanConf + .createWithDefault(true) + + private[spark] val KUBERNETES_SERVICE_ACCOUNT_NAME = + ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") + .doc("Service account that is used when running the driver pod. The driver pod uses" + + " this service account when requesting executor pods from the API server. If specific" + + " credentials are given for the driver pod to use, the driver will favor" + + " using those credentials instead.") + .stringConf + .createOptional + + private[spark] val SPARK_SHUFFLE_SERVICE_HOST = + ConfigBuilder("spark.shuffle.service.host") + .doc("Host for Spark Shuffle Service") + .internal() + .stringConf + .createOptional + + // Note that while we set a default for this when we start up the + // scheduler, the specific default value is dynamically determined + // based on the executor memory. + private[spark] val KUBERNETES_EXECUTOR_MEMORY_OVERHEAD = + ConfigBuilder("spark.kubernetes.executor.memoryOverhead") + .doc("The amount of off-heap memory (in megabytes) to be allocated per executor. This" + + " is memory that accounts for things like VM overheads, interned strings, other native" + + " overheads, etc. This tends to grow with the executor size. (typically 6-10%).") + .bytesConf(ByteUnit.MiB) + .createOptional + + private[spark] val KUBERNETES_DRIVER_MEMORY_OVERHEAD = + ConfigBuilder("spark.kubernetes.driver.memoryOverhead") + .doc("The amount of off-heap memory (in megabytes) to be allocated for the driver and the" + + " driver submission server. This is memory that accounts for things like VM overheads," + + " interned strings, other native overheads, etc. This tends to grow with the driver's" + + " memory size (typically 6-10%).") + .bytesConf(ByteUnit.MiB) + .createOptional + + private[spark] val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." + private[spark] val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." + private[spark] val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." + private[spark] val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + + private[spark] val KUBERNETES_DRIVER_LABELS = + ConfigBuilder("spark.kubernetes.driver.labels") + .doc("Custom labels that will be added to the driver pod. This should be a comma-separated" + + " list of label key-value pairs, where each label is in the format key=value. Note that" + + " Spark also adds its own labels to the driver pod for bookkeeping purposes.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + + private[spark] val KUBERNETES_DRIVER_ANNOTATIONS = + ConfigBuilder("spark.kubernetes.driver.annotations") + .doc("Custom annotations that will be added to the driver pod. This should be a" + + " comma-separated list of annotation key-value pairs, where each annotation is in the" + + " format key=value.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_EXECUTOR_LABELS = + ConfigBuilder("spark.kubernetes.executor.labels") + .doc("Custom labels that will be added to the executor pods. This should be a" + + " comma-separated list of label key-value pairs, where each label is in the format" + + " key=value.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_EXECUTOR_ANNOTATIONS = + ConfigBuilder("spark.kubernetes.executor.annotations") + .doc("Custom annotations that will be added to the executor pods. This should be a" + + " comma-separated list of annotation key-value pairs, where each annotation is in the" + + " format key=value.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + private[spark] val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + + private[spark] val KUBERNETES_DRIVER_POD_NAME = + ConfigBuilder("spark.kubernetes.driver.pod.name") + .doc("Name of the driver pod.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_EXECUTOR_POD_NAME_PREFIX = + ConfigBuilder("spark.kubernetes.executor.podNamePrefix") + .doc("Prefix to use in front of the executor pod names.") + .internal() + .stringConf + .createWithDefault("spark") + + private[spark] val KUBERNETES_SHUFFLE_NAMESPACE = + ConfigBuilder("spark.kubernetes.shuffle.namespace") + .doc("Namespace of the shuffle service") + .stringConf + .createWithDefault("default") + + private[spark] val KUBERNETES_SHUFFLE_SVC_IP = + ConfigBuilder("spark.kubernetes.shuffle.ip") + .doc("This setting is for debugging only. Setting this " + + "allows overriding the IP that the executor thinks its colocated " + + "shuffle service is on") + .stringConf + .createOptional + + private[spark] val KUBERNETES_SHUFFLE_LABELS = + ConfigBuilder("spark.kubernetes.shuffle.labels") + .doc("Labels to identify the shuffle service") + .stringConf + .createOptional + + private[spark] val KUBERNETES_SHUFFLE_DIR = + ConfigBuilder("spark.kubernetes.shuffle.dir") + .doc("Path to the shared shuffle directories.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_SHUFFLE_APISERVER_URI = + ConfigBuilder("spark.kubernetes.shuffle.apiServer.url") + .doc("URL to the Kubernetes API server that the shuffle service will monitor for Spark pods.") + .stringConf + .createWithDefault(KUBERNETES_MASTER_INTERNAL_URL) + + private[spark] val KUBERNETES_SHUFFLE_USE_SERVICE_ACCOUNT_CREDENTIALS = + ConfigBuilder(s"$APISERVER_AUTH_SHUFFLE_SERVICE_CONF_PREFIX.useServiceAccountCredentials") + .doc("Whether or not to use service account credentials when contacting the API server from" + + " the shuffle service.") + .booleanConf + .createWithDefault(true) + + private[spark] val KUBERNETES_ALLOCATION_BATCH_SIZE = + ConfigBuilder("spark.kubernetes.allocation.batch.size") + .doc("Number of pods to launch at once in each round of dynamic allocation. ") + .intConf + .createWithDefault(5) + + private[spark] val KUBERNETES_ALLOCATION_BATCH_DELAY = + ConfigBuilder("spark.kubernetes.allocation.batch.delay") + .doc("Number of seconds to wait between each round of executor allocation. ") + .longConf + .createWithDefault(1) + + private[spark] val WAIT_FOR_APP_COMPLETION = + ConfigBuilder("spark.kubernetes.submission.waitAppCompletion") + .doc("In cluster mode, whether to wait for the application to finish before exiting the" + + " launcher process.") + .booleanConf + .createWithDefault(true) + + private[spark] val REPORT_INTERVAL = + ConfigBuilder("spark.kubernetes.report.interval") + .doc("Interval between reports of the current app status in cluster mode.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("1s") + + // Spark resource staging server. + private[spark] val RESOURCE_STAGING_SERVER_API_SERVER_URL = + ConfigBuilder("spark.kubernetes.resourceStagingServer.apiServer.url") + .doc("URL for the Kubernetes API server. The resource staging server monitors the API" + + " server to check when pods no longer are using mounted resources. Note that this isn't" + + " to be used in Spark applications, as the API server URL should be set via spark.master.") + .stringConf + .createWithDefault(KUBERNETES_MASTER_INTERNAL_URL) + + private[spark] val RESOURCE_STAGING_SERVER_API_SERVER_CA_CERT_FILE = + ConfigBuilder("spark.kubernetes.resourceStagingServer.apiServer.caCertFile") + .doc("CA certificate for the resource staging server to use when contacting the Kubernetes" + + " API server over TLS.") + .stringConf + .createOptional + + private[spark] val RESOURCE_STAGING_SERVER_PORT = + ConfigBuilder("spark.kubernetes.resourceStagingServer.port") + .doc("Port for the Kubernetes resource staging server to listen on.") + .intConf + .createWithDefault(10000) + + private[spark] val RESOURCE_STAGING_SERVER_INITIAL_ACCESS_EXPIRATION_TIMEOUT = + ConfigBuilder("spark.kubernetes.resourceStagingServer.initialAccessExpirationTimeout") + .doc("The resource staging server will wait for any resource bundle to be accessed for a" + + " first time for this period. If this timeout expires before the resources are accessed" + + " the first time, the resources are cleaned up under the assumption that the dependents" + + " of the given resource bundle failed to launch at all.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("30m") + + private[spark] val RESOURCE_STAGING_SERVER_KEY_PEM = + ConfigBuilder("spark.ssl.kubernetes.resourceStagingServer.keyPem") + .doc("Key PEM file to use when having the Kubernetes dependency server listen on TLS.") + .stringConf + .createOptional + + private[spark] val RESOURCE_STAGING_SERVER_SSL_NAMESPACE = "kubernetes.resourceStagingServer" + private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE = + "kubernetes.resourceStagingServer.internal" + private[spark] val RESOURCE_STAGING_SERVER_CERT_PEM = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.serverCertPem") + .doc("Certificate PEM file to use when having the resource staging server" + + " listen on TLS.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_CLIENT_CERT_PEM = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.clientCertPem") + .doc("Certificate PEM file to use when the client contacts the resource staging server." + + " This must strictly be a path to a file on the submitting machine's disk.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_CLIENT_CERT_PEM = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.clientCertPem") + .doc("Certificate PEM file to use when the init-container contacts the resource staging" + + " server. If this is not provided, it defaults to the value of" + + " spark.ssl.kubernetes.resourceStagingServer.clientCertPem. This can be a URI with" + + " a scheme of local:// which denotes that the file is pre-mounted on the init-container's" + + " disk. A uri without a scheme or a scheme of file:// will result in this file being" + + " mounted from the submitting machine's disk as a secret into the pods.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_KEYSTORE_PASSWORD_FILE = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.keyStorePasswordFile") + .doc("File containing the keystore password for the Kubernetes resource staging server.") + .stringConf + .createOptional + + private[spark] val RESOURCE_STAGING_SERVER_KEYSTORE_KEY_PASSWORD_FILE = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.keyPasswordFile") + .doc("File containing the key password for the Kubernetes resource staging server.") + .stringConf + .createOptional + + private[spark] val RESOURCE_STAGING_SERVER_SSL_ENABLED = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.enabled") + .doc("Whether or not to use SSL when communicating with the resource staging server.") + .booleanConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_SSL_ENABLED = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.enabled") + .doc("Whether or not to use SSL when communicating with the resource staging server from" + + " the init-container. If this is not provided, defaults to" + + " the value of spark.ssl.kubernetes.resourceStagingServer.enabled") + .booleanConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_TRUSTSTORE_FILE = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.trustStore") + .doc("File containing the trustStore to communicate with the Kubernetes dependency server." + + " This must strictly be a path on the submitting machine's disk.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_TRUSTSTORE_FILE = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.trustStore") + .doc("File containing the trustStore to communicate with the Kubernetes dependency server" + + " from the init-container. If this is not provided, defaults to the value of" + + " spark.ssl.kubernetes.resourceStagingServer.trustStore. This can be a URI with a scheme" + + " of local:// indicating that the trustStore is pre-mounted on the init-container's" + + " disk. If no scheme, or a scheme of file:// is provided, this file is mounted from the" + + " submitting machine's disk as a Kubernetes secret into the pods.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_TRUSTSTORE_PASSWORD = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.trustStorePassword") + .doc("Password for the trustStore for communicating to the dependency server.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_TRUSTSTORE_PASSWORD = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.trustStorePassword") + .doc("Password for the trustStore for communicating to the dependency server from the" + + " init-container. If this is not provided, defaults to" + + " spark.ssl.kubernetes.resourceStagingServer.trustStorePassword.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_TRUSTSTORE_TYPE = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.trustStoreType") + .doc("Type of trustStore for communicating with the dependency server.") + .stringConf + .createOptional + private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_TRUSTSTORE_TYPE = + ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.trustStoreType") + .doc("Type of trustStore for communicating with the dependency server from the" + + " init-container. If this is not provided, defaults to" + + " spark.ssl.kubernetes.resourceStagingServer.trustStoreType") + .stringConf + .createOptional + + // Driver and Init-Container parameters + private[spark] val RESOURCE_STAGING_SERVER_URI = + ConfigBuilder("spark.kubernetes.resourceStagingServer.uri") + .doc("Base URI for the Spark resource staging server.") + .stringConf + .createOptional + + private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_URI = + ConfigBuilder("spark.kubernetes.resourceStagingServer.internal.uri") + .doc("Base URI for the Spark resource staging server when the init-containers access it for" + + " downloading resources. If this is not provided, it defaults to the value provided in" + + " spark.kubernetes.resourceStagingServer.uri, the URI that the submission client uses to" + + " upload the resources from outside the cluster.") + .stringConf + .createOptional + + private[spark] val INIT_CONTAINER_DOWNLOAD_JARS_RESOURCE_IDENTIFIER = + ConfigBuilder("spark.kubernetes.initcontainer.downloadJarsResourceIdentifier") + .doc("Identifier for the jars tarball that was uploaded to the staging service.") + .internal() + .stringConf + .createOptional + + private[spark] val INIT_CONTAINER_DOWNLOAD_JARS_SECRET_LOCATION = + ConfigBuilder("spark.kubernetes.initcontainer.downloadJarsSecretLocation") + .doc("Location of the application secret to use when the init-container contacts the" + + " resource staging server to download jars.") + .internal() + .stringConf + .createWithDefault(s"$INIT_CONTAINER_SECRET_VOLUME_MOUNT_PATH/" + + s"$INIT_CONTAINER_SUBMITTED_JARS_SECRET_KEY") + + private[spark] val INIT_CONTAINER_DOWNLOAD_FILES_RESOURCE_IDENTIFIER = + ConfigBuilder("spark.kubernetes.initcontainer.downloadFilesResourceIdentifier") + .doc("Identifier for the files tarball that was uploaded to the staging service.") + .internal() + .stringConf + .createOptional + + private[spark] val INIT_CONTAINER_DOWNLOAD_FILES_SECRET_LOCATION = + ConfigBuilder("spark.kubernetes.initcontainer.downloadFilesSecretLocation") + .doc("Location of the application secret to use when the init-container contacts the" + + " resource staging server to download files.") + .internal() + .stringConf + .createWithDefault( + s"$INIT_CONTAINER_SECRET_VOLUME_MOUNT_PATH/$INIT_CONTAINER_SUBMITTED_FILES_SECRET_KEY") + + private[spark] val INIT_CONTAINER_REMOTE_JARS = + ConfigBuilder("spark.kubernetes.initcontainer.remoteJars") + .doc("Comma-separated list of jar URIs to download in the init-container. This is" + + " calculated from spark.jars.") + .internal() + .stringConf + .createOptional + + private[spark] val INIT_CONTAINER_REMOTE_FILES = + ConfigBuilder("spark.kubernetes.initcontainer.remoteFiles") + .doc("Comma-separated list of file URIs to download in the init-container. This is" + + " calculated from spark.files.") + .internal() + .stringConf + .createOptional + + private[spark] val INIT_CONTAINER_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.initcontainer.docker.image") + .doc("Image for the driver and executor's init-container that downloads dependencies.") + .stringConf + .createWithDefault(s"spark-init:$sparkVersion") + + private[spark] val INIT_CONTAINER_JARS_DOWNLOAD_LOCATION = + ConfigBuilder("spark.kubernetes.mountdependencies.jarsDownloadDir") + .doc("Location to download jars to in the driver and executors. When using" + + " spark-submit, this directory must be empty and will be mounted as an empty directory" + + " volume on the driver and executor pod.") + .stringConf + .createWithDefault("/var/spark-data/spark-jars") + + private[spark] val INIT_CONTAINER_FILES_DOWNLOAD_LOCATION = + ConfigBuilder("spark.kubernetes.mountdependencies.filesDownloadDir") + .doc("Location to download files to in the driver and executors. When using" + + " spark-submit, this directory must be empty and will be mounted as an empty directory" + + " volume on the driver and executor pods.") + .stringConf + .createWithDefault("/var/spark-data/spark-files") + + private[spark] val INIT_CONTAINER_MOUNT_TIMEOUT = + ConfigBuilder("spark.kubernetes.mountdependencies.mountTimeout") + .doc("Timeout before aborting the attempt to download and unpack local dependencies from" + + " remote locations and the resource staging server when initializing the driver and" + + " executor pods.") + .timeConf(TimeUnit.MINUTES) + .createWithDefault(5) + + private[spark] val EXECUTOR_SUBMITTED_SMALL_FILES_SECRET = + ConfigBuilder("spark.kubernetes.mountdependencies.smallfiles.executor.secretName") + .doc("Name of the secret that should be mounted into the executor containers for" + + " distributing submitted small files without the resource staging server.") + .internal() + .stringConf + .createOptional + + private[spark] val EXECUTOR_SUBMITTED_SMALL_FILES_SECRET_MOUNT_PATH = + ConfigBuilder("spark.kubernetes.mountdependencies.smallfiles.executor.secretMountPath") + .doc(s"Mount path in the executors for the secret given by" + + s" ${EXECUTOR_SUBMITTED_SMALL_FILES_SECRET.key}") + .internal() + .stringConf + .createOptional + + private[spark] val EXECUTOR_INIT_CONTAINER_CONFIG_MAP = + ConfigBuilder("spark.kubernetes.initcontainer.executor.configmapname") + .doc("Name of the config map to use in the init-container that retrieves submitted files" + + " for the executor.") + .internal() + .stringConf + .createOptional + + private[spark] val EXECUTOR_INIT_CONTAINER_CONFIG_MAP_KEY = + ConfigBuilder("spark.kubernetes.initcontainer.executor.configmapkey") + .doc("Key for the entry in the init container config map for submitted files that" + + " corresponds to the properties for this init-container.") + .internal() + .stringConf + .createOptional + + private[spark] val EXECUTOR_INIT_CONTAINER_SECRET = + ConfigBuilder("spark.kubernetes.initcontainer.executor.stagingServerSecret.name") + .doc("Name of the secret to mount into the init-container that retrieves submitted files.") + .internal() + .stringConf + .createOptional + + private[spark] val EXECUTOR_INIT_CONTAINER_SECRET_MOUNT_DIR = + ConfigBuilder("spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir") + .doc("Directory to mount the resource staging server secrets into for the executor" + + " init-containers. This must be exactly the same as the directory that the submission" + + " client mounted the secret into because the config map's properties specify the" + + " secret location as to be the same between the driver init-container and the executor" + + " init-container. Thus the submission client will always set this and the driver will" + + " never rely on a constant or convention, in order to protect against cases where the" + + " submission client has a different version from the driver itself, and hence might" + + " have different constants loaded in constants.scala.") + .internal() + .stringConf + .createOptional + + private[spark] val KUBERNETES_DRIVER_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.driver.limit.cores") + .doc("Specify the hard cpu limit for the driver pod") + .stringConf + .createOptional + + private[spark] val KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED = + ConfigBuilder("spark.kubernetes.driver.hdfslocality.clusterNodeNameDNSLookup.enabled") + .doc("Whether or not HDFS locality support code should look up DNS for full hostnames of" + + " cluster nodes. In some K8s clusters, notably GKE, cluster node names are short" + + " hostnames, and so comparing them against HDFS datanode hostnames always fail. To fix," + + " enable this flag. This is disabled by default because DNS lookup can be expensive." + + " The driver can slow down and fail to respond to executor heartbeats in time." + + " If enabling this flag, make sure your DNS server has enough capacity" + + " for the workload.") + .internal() + .booleanConf + .createWithDefault(false) + + private[spark] val KUBERNETES_EXECUTOR_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.executor.limit.cores") + .doc("Specify the hard cpu limit for a single executor pod") + .stringConf + .createOptional + + private[spark] val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector." + + private[spark] def resolveK8sMaster(rawMasterString: String): String = { + if (!rawMasterString.startsWith("k8s://")) { + throw new IllegalArgumentException("Master URL should start with k8s:// in Kubernetes mode.") + } + val masterWithoutK8sPrefix = rawMasterString.replaceFirst("k8s://", "") + if (masterWithoutK8sPrefix.startsWith("http://") + || masterWithoutK8sPrefix.startsWith("https://")) { + masterWithoutK8sPrefix + } else { + val resolvedURL = s"https://$masterWithoutK8sPrefix" + logDebug(s"No scheme specified for kubernetes master URL, so defaulting to https. Resolved" + + s" URL is $resolvedURL") + resolvedURL + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala new file mode 100644 index 0000000000000..0a2bc46249f3a --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.kubernetes + +package object constants { + // Labels + private[spark] val SPARK_DRIVER_LABEL = "spark-driver" + private[spark] val SPARK_APP_ID_LABEL = "spark-app-selector" + private[spark] val SPARK_EXECUTOR_ID_LABEL = "spark-exec-id" + private[spark] val SPARK_ROLE_LABEL = "spark-role" + private[spark] val SPARK_POD_DRIVER_ROLE = "driver" + private[spark] val SPARK_POD_EXECUTOR_ROLE = "executor" + private[spark] val SPARK_APP_NAME_ANNOTATION = "spark-app-name" + + // Credentials secrets + private[spark] val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = + "/mnt/secrets/spark-kubernetes-credentials" + private[spark] val DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME = "ca-cert" + private[spark] val DRIVER_CREDENTIALS_CA_CERT_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME = "client-key" + private[spark] val DRIVER_CREDENTIALS_CLIENT_KEY_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME = "client-cert" + private[spark] val DRIVER_CREDENTIALS_CLIENT_CERT_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME = "oauth-token" + private[spark] val DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_SECRET_VOLUME_NAME = "kubernetes-credentials" + + // Default and fixed ports + private[spark] val DEFAULT_DRIVER_PORT = 7078 + private[spark] val DEFAULT_BLOCKMANAGER_PORT = 7079 + private[spark] val DEFAULT_UI_PORT = 4040 + private[spark] val BLOCK_MANAGER_PORT_NAME = "blockmanager" + private[spark] val DRIVER_PORT_NAME = "driver-rpc-port" + private[spark] val EXECUTOR_PORT_NAME = "executor" + + // Environment Variables + private[spark] val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" + private[spark] val ENV_DRIVER_URL = "SPARK_DRIVER_URL" + private[spark] val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" + private[spark] val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" + private[spark] val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" + private[spark] val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" + private[spark] val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" + private[spark] val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" + private[spark] val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" + private[spark] val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" + private[spark] val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" + private[spark] val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" + private[spark] val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" + private[spark] val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" + private[spark] val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" + private[spark] val ENV_PYSPARK_FILES = "PYSPARK_FILES" + private[spark] val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" + private[spark] val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" + private[spark] val ENV_MOUNTED_FILES_FROM_SECRET_DIR = "SPARK_MOUNTED_FILES_FROM_SECRET_DIR" + + // Bootstrapping dependencies with the init-container + private[spark] val INIT_CONTAINER_ANNOTATION = "pod.beta.kubernetes.io/init-containers" + private[spark] val INIT_CONTAINER_SECRET_VOLUME_MOUNT_PATH = + "/mnt/secrets/spark-init" + private[spark] val INIT_CONTAINER_SUBMITTED_JARS_SECRET_KEY = + "downloadSubmittedJarsSecret" + private[spark] val INIT_CONTAINER_SUBMITTED_FILES_SECRET_KEY = + "downloadSubmittedFilesSecret" + private[spark] val INIT_CONTAINER_STAGING_SERVER_TRUSTSTORE_SECRET_KEY = "trustStore" + private[spark] val INIT_CONTAINER_STAGING_SERVER_CLIENT_CERT_SECRET_KEY = "ssl-certificate" + private[spark] val INIT_CONTAINER_CONFIG_MAP_KEY = "download-submitted-files" + private[spark] val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" + private[spark] val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files" + private[spark] val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" + private[spark] val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" + private[spark] val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" + private[spark] val INIT_CONTAINER_PROPERTIES_FILE_PATH = + s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" + private[spark] val DEFAULT_SHUFFLE_MOUNT_NAME = "shuffle" + private[spark] val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" + + // Bootstrapping dependencies via a secret + private[spark] val MOUNTED_SMALL_FILES_SECRET_MOUNT_PATH = "/etc/spark-submitted-files" + + // Miscellaneous + private[spark] val ANNOTATION_EXECUTOR_NODE_AFFINITY = "scheduler.alpha.kubernetes.io/affinity" + private[spark] val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" + private[spark] val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" + private[spark] val MEMORY_OVERHEAD_FACTOR = 0.10 + private[spark] val MEMORY_OVERHEAD_MIN_MIB = 384L +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactory.scala new file mode 100644 index 0000000000000..8d493606d1f60 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactory.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, Pod, PodBuilder, QuantityBuilder} +import org.apache.commons.io.FilenameUtils + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.kubernetes.ConfigurationUtils +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.util.Utils + +// Configures executor pods. Construct one of these with a SparkConf to set up properties that are +// common across all executors. Then, pass in dynamic parameters into createExecutorPod. +private[spark] trait ExecutorPodFactory { + def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod +} + +private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) + extends ExecutorPodFactory { + + import ExecutorPodFactoryImpl._ + + private val executorExtraClasspath = sparkConf.get( + org.apache.spark.internal.config.EXECUTOR_CLASS_PATH) + private val executorJarsDownloadDir = sparkConf.get(INIT_CONTAINER_JARS_DOWNLOAD_LOCATION) + + private val executorLabels = ConfigurationUtils.combinePrefixedKeyValuePairsWithDeprecatedConf( + sparkConf, + KUBERNETES_EXECUTOR_LABEL_PREFIX, + KUBERNETES_EXECUTOR_LABELS, + "executor label") + require( + !executorLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + s" Spark.") + + private val executorAnnotations = + ConfigurationUtils.combinePrefixedKeyValuePairsWithDeprecatedConf( + sparkConf, + KUBERNETES_EXECUTOR_ANNOTATION_PREFIX, + KUBERNETES_EXECUTOR_ANNOTATIONS, + "executor annotation") + private val nodeSelector = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_NODE_SELECTOR_PREFIX, + "node selector") + + private val executorDockerImage = sparkConf.get(EXECUTOR_DOCKER_IMAGE) + private val dockerImagePullPolicy = sparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val executorPort = sparkConf.getInt("spark.executor.port", DEFAULT_STATIC_PORT) + private val blockmanagerPort = sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + private val kubernetesDriverPodName = sparkConf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) + + private val executorMemoryMiB = sparkConf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY) + private val executorMemoryString = sparkConf.get( + org.apache.spark.internal.config.EXECUTOR_MEMORY.key, + org.apache.spark.internal.config.EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = sparkConf + .get(KUBERNETES_EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = sparkConf.getDouble("spark.executor.cores", 1d) + private val executorLimitCores = sparkConf.getOption(KUBERNETES_EXECUTOR_LIMIT_CORES.key) + + override def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod = { + val name = s"$executorPodNamePrefix-exec-$executorId" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. This preserves uniqueness since the end of name contains + // executorId and applicationId + val hostname = name.substring(Math.max(0, name.length - 63)) + val resolvedExecutorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> applicationId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorLabels + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryMiB}Mi") + .build() + val executorMemoryLimitQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCores.toString) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = sparkConf + .get(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_EXECUTOR_PORT, executorPort.toString), + (ENV_DRIVER_URL, driverUrl), + // Executor backend expects integral value for executor cores, so round it up to an int. + (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, applicationId), + (ENV_EXECUTOR_ID, executorId), + (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (EXECUTOR_PORT_NAME, executorPort), + (BLOCK_MANAGER_PORT_NAME, blockmanagerPort)) + .map(port => { + new ContainerPortBuilder() + .withName(port._1) + .withContainerPort(port._2) + .build() + }) + + val executorContainer = new ContainerBuilder() + .withName(s"executor") + .withImage(executorDockerImage) + .withImagePullPolicy(dockerImagePullPolicy) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryLimitQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .build() + + val executorPod = new PodBuilder() + .withNewMetadata() + .withName(name) + .withLabels(resolvedExecutorLabels.asJava) + .withAnnotations(executorAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(nodeSelector.asJava) + .endSpec() + .build() + + val containerWithExecutorLimitCores = executorLimitCores.map { + limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + + new PodBuilder(executorPod) + .editSpec() + .addToContainers(containerWithExecutorLimitCores) + .endSpec() + .build() + } +} + +private object ExecutorPodFactoryImpl { + private val DEFAULT_STATIC_PORT = 10000 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala new file mode 100644 index 0000000000000..6666c0a091f87 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import java.io.File + +import io.fabric8.kubernetes.client.Config + +import org.apache.spark.SparkContext +import org.apache.spark.deploy.kubernetes.{ConfigurationUtils, SparkKubernetesClientFactory} +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.util.{ThreadUtils, Utils} + +private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { + + override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") + + override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { + new TaskSchedulerImpl(sc) + } + + override def createSchedulerBackend(sc: SparkContext, masterURL: String, scheduler: TaskScheduler) + : SchedulerBackend = { + val sparkConf = sc.getConf + + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( + KUBERNETES_MASTER_INTERNAL_URL, + Some(sparkConf.get(KUBERNETES_NAMESPACE)), + APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + sparkConf, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + + val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val allocatorExecutor = ThreadUtils + .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") + val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( + "kubernetes-executor-requests") + new KubernetesClusterSchedulerBackend( + scheduler.asInstanceOf[TaskSchedulerImpl], + sc.env.rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) + } + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala new file mode 100644 index 0000000000000..42a3e3fd50492 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import java.io.Closeable +import java.net.InetAddress +import java.util.Collections +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import scala.collection.{concurrent, mutable} +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpointAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RetrieveSparkAppConfig, SparkAppConfig} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class KubernetesClusterSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + executorPodFactory: ExecutorPodFactory, + kubernetesClient: KubernetesClient, + allocatorExecutor: ScheduledExecutorService, + requestExecutorsService: ExecutorService) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + import KubernetesClusterSchedulerBackend._ + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + private val RUNNING_EXECUTOR_PODS_LOCK = new Object + // Indexed by executor IDs and guarded by RUNNING_EXECUTOR_PODS_LOCK. + private val runningExecutorsToPods = new mutable.HashMap[String, Pod] + // Indexed by executor pod names and guarded by RUNNING_EXECUTOR_PODS_LOCK. + private val runningPodsToExecutors = new mutable.HashMap[String, String] + // TODO(varun): Get rid of this lock object by my making the underlying map a concurrent hash map. + private val EXECUTOR_PODS_BY_IPS_LOCK = new Object + // Indexed by executor IP addrs and guarded by EXECUTOR_PODS_BY_IPS_LOCK + private val executorPodsByIPs = new mutable.HashMap[String, Pod] + private val podsWithKnownExitReasons: concurrent.Map[String, ExecutorExited] = + new ConcurrentHashMap[String, ExecutorExited]().asScala + private val disconnectedPodsByExecutorIdPendingRemoval = + new ConcurrentHashMap[String, Pod]().asScala + + private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse( + throw new SparkException("Must specify the driver pod name")) + private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( + requestExecutorsService) + + private val driverPod = try { + kubernetesClient.pods().inNamespace(kubernetesNamespace). + withName(kubernetesDriverPodName).get() + } catch { + case throwable: Throwable => + logError(s"Executor cannot find driver pod.", throwable) + throw new SparkException(s"Executor cannot find driver pod", throwable) + } + + override val minRegisteredRatio = + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + 0.8 + } else { + super.minRegisteredRatio + } + + private val executorWatchResource = new AtomicReference[Closeable] + protected var totalExpectedExecutors = new AtomicInteger(0) + + private val driverUrl = RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + + private val initialExecutors = getInitialTargetExecutorNumber() + + private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + require(podAllocationInterval > 0, s"Allocation batch delay " + + s"${KUBERNETES_ALLOCATION_BATCH_DELAY} " + + s"is ${podAllocationInterval}, should be a positive integer") + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + require(podAllocationSize > 0, s"Allocation batch size " + + s"${KUBERNETES_ALLOCATION_BATCH_SIZE} " + + s"is ${podAllocationSize}, should be a positive integer") + + private val allocatorRunnable = new Runnable { + + // Maintains a map of executor id to count of checks performed to learn the loss reason + // for an executor. + private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] + + override def run(): Unit = { + handleDisconnectedExecutors() + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + if (totalRegisteredExecutors.get() < runningExecutorsToPods.size) { + logDebug("Waiting for pending executors before scaling") + } else if (totalExpectedExecutors.get() <= runningExecutorsToPods.size) { + logDebug("Maximum allowed executor limit reached. Not scaling up further.") + } else { + val nodeToLocalTaskCount = getNodesWithLocalTaskCounts + for (i <- 0 until math.min( + totalExpectedExecutors.get - runningExecutorsToPods.size, podAllocationSize)) { + val (executorId, pod) = allocateNewExecutorPod(nodeToLocalTaskCount) + runningExecutorsToPods.put(executorId, pod) + runningPodsToExecutors.put(pod.getMetadata.getName, executorId) + logInfo( + s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") + } + } + } + } + + def handleDisconnectedExecutors(): Unit = { + // For each disconnected executor, synchronize with the loss reasons that may have been found + // by the executor pod watcher. If the loss reason was discovered by the watcher, + // inform the parent class with removeExecutor. + val disconnectedPodsByExecutorIdPendingRemovalCopy = + Map.empty ++ disconnectedPodsByExecutorIdPendingRemoval + disconnectedPodsByExecutorIdPendingRemovalCopy.foreach { case (executorId, executorPod) => + val knownExitReason = podsWithKnownExitReasons.remove(executorPod.getMetadata.getName) + knownExitReason.fold { + removeExecutorOrIncrementLossReasonCheckCount(executorId) + } { executorExited => + logDebug(s"Removing executor $executorId with loss reason " + executorExited.message) + removeExecutor(executorId, executorExited) + // We keep around executors that have exit conditions caused by the application. This + // allows them to be debugged later on. Otherwise, mark them as to be deleted from the + // the API server. + if (!executorExited.exitCausedByApp) { + deleteExecutorFromClusterAndDataStructures(executorId) + } + } + } + } + + def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { + val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) + if (reasonCheckCount >= MAX_EXECUTOR_LOST_REASON_CHECKS) { + removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) + deleteExecutorFromClusterAndDataStructures(executorId) + } else { + executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) + } + } + + def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { + disconnectedPodsByExecutorIdPendingRemoval -= executorId + executorReasonCheckAttemptCounts -= executorId + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.remove(executorId).map { pod => + kubernetesClient.pods().delete(pod) + runningPodsToExecutors.remove(pod.getMetadata.getName) + }.getOrElse(logWarning(s"Unable to remove pod for unknown executor $executorId")) + } + } + } + + private def getInitialTargetExecutorNumber(defaultNumExecutors: Int = 1): Int = { + if (Utils.isDynamicAllocationEnabled(conf)) { + val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) + val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf) + val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", 1) + require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, + s"initial executor number $initialNumExecutors must between min executor number " + + s"$minNumExecutors and max executor number $maxNumExecutors") + + initialNumExecutors + } else { + conf.getInt("spark.executor.instances", defaultNumExecutors) + } + + } + + override def applicationId(): String = conf.get("spark.app.id", super.applicationId()) + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + } + + override def start(): Unit = { + super.start() + executorWatchResource.set( + kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .watch(new ExecutorPodsWatcher())) + + allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS) + + if (!Utils.isDynamicAllocationEnabled(conf)) { + doRequestTotalExecutors(initialExecutors) + } + } + + override def stop(): Unit = { + // stop allocation of new resources and caches. + allocatorExecutor.shutdown() + + // send stop message to executors so they shut down cleanly + super.stop() + + // then delete the executor pods + // TODO investigate why Utils.tryLogNonFatalError() doesn't work in this context. + // When using Utils.tryLogNonFatalError some of the code fails but without any logs or + // indication as to why. + try { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.values.foreach(kubernetesClient.pods().delete(_)) + runningExecutorsToPods.clear() + runningPodsToExecutors.clear() + } + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs.clear() + } + val resource = executorWatchResource.getAndSet(null) + if (resource != null) { + resource.close() + } + } catch { + case e: Throwable => logError("Uncaught exception while shutting down controllers.", e) + } + try { + logInfo("Closing kubernetes client") + kubernetesClient.close() + } catch { + case e: Throwable => logError("Uncaught exception closing Kubernetes client.", e) + } + } + + /** + * @return A map of K8s cluster nodes to the number of tasks that could benefit from data + * locality if an executor launches on the cluster node. + */ + private def getNodesWithLocalTaskCounts() : Map[String, Int] = { + val executorPodsWithIPs = EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs.values.toList // toList makes a defensive copy. + } + val nodeToLocalTaskCount = mutable.Map[String, Int]() ++ + KubernetesClusterSchedulerBackend.this.synchronized { + hostToLocalTaskCount + } + for (pod <- executorPodsWithIPs) { + // Remove cluster nodes that are running our executors already. + // TODO: This prefers spreading out executors across nodes. In case users want + // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut + // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html + nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || + nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || + nodeToLocalTaskCount.remove( + InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + } + nodeToLocalTaskCount.toMap[String, Int] + } + + /** + * Allocates a new executor pod + * + * @param nodeToLocalTaskCount A map of K8s cluster nodes to the number of tasks that could + * benefit from data locality if an executor launches on the cluster + * node. + * @return A tuple of the new executor name and the Pod data structure. + */ + private def allocateNewExecutorPod(nodeToLocalTaskCount: Map[String, Int]): (String, Pod) = { + val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString + val executorPod = executorPodFactory.createExecutorPod( + executorId, + applicationId(), + driverUrl, + conf.getExecutorEnv, + driverPod, + nodeToLocalTaskCount) + try { + (executorId, kubernetesClient.pods.create(executorPod)) + } catch { + case throwable: Throwable => + logError("Failed to allocate executor pod.", throwable) + throw throwable + } + } + + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { + totalExpectedExecutors.set(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + for (executor <- executorIds) { + val maybeRemovedExecutor = runningExecutorsToPods.remove(executor) + maybeRemovedExecutor.foreach { executorPod => + kubernetesClient.pods().delete(executorPod) + disconnectedPodsByExecutorIdPendingRemoval(executor) = executorPod + runningPodsToExecutors.remove(executorPod.getMetadata.getName) + } + if (maybeRemovedExecutor.isEmpty) { + logWarning(s"Unable to remove pod for unknown executor $executor") + } + } + } + true + } + + def getExecutorPodByIP(podIP: String): Option[Pod] = { + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs.get(podIP) + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + + private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 + + override def eventReceived(action: Action, pod: Pod): Unit = { + if (action == Action.MODIFIED && pod.getStatus.getPhase == "Running" + && pod.getMetadata.getDeletionTimestamp == null) { + val podIP = pod.getStatus.getPodIP + val clusterNodeName = pod.getSpec.getNodeName + logDebug(s"Executor pod $pod ready, launched at $clusterNodeName as IP $podIP.") + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs += ((podIP, pod)) + } + } else if ((action == Action.MODIFIED && pod.getMetadata.getDeletionTimestamp != null) || + action == Action.DELETED || action == Action.ERROR) { + val podName = pod.getMetadata.getName + val podIP = pod.getStatus.getPodIP + logDebug(s"Executor pod $podName at IP $podIP was at $action.") + if (podIP != null) { + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs -= podIP + } + } + if (action == Action.ERROR) { + logInfo(s"Received pod $podName exited event. Reason: " + pod.getStatus.getReason) + handleErroredPod(pod) + } else if (action == Action.DELETED) { + logInfo(s"Received delete pod $podName event. Reason: " + pod.getStatus.getReason) + handleDeletedPod(pod) + } + } + } + + override def onClose(cause: KubernetesClientException): Unit = { + logDebug("Executor pod watch closed.", cause) + } + + def getExecutorExitStatus(pod: Pod): Int = { + val containerStatuses = pod.getStatus.getContainerStatuses + if (!containerStatuses.isEmpty) { + // we assume the first container represents the pod status. This assumption may not hold + // true in the future. Revisit this if side-car containers start running inside executor + // pods. + getExecutorExitStatus(containerStatuses.get(0)) + } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS + } + + def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { + Option(containerStatus.getState).map(containerState => + Option(containerState.getTerminated).map(containerStateTerminated => + containerStateTerminated.getExitCode.intValue()).getOrElse(UNKNOWN_EXIT_CODE) + ).getOrElse(UNKNOWN_EXIT_CODE) + } + + def isPodAlreadyReleased(pod: Pod): Boolean = { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + !runningPodsToExecutors.contains(pod.getMetadata.getName) + } + } + + def handleErroredPod(pod: Pod): Unit = { + val containerExitStatus = getExecutorExitStatus(pod) + // container was probably actively killed by the driver. + val exitReason = if (isPodAlreadyReleased(pod)) { + ExecutorExited(containerExitStatus, exitCausedByApp = false, + s"Container in pod " + pod.getMetadata.getName + + " exited from explicit termination request.") + } else { + val containerExitReason = containerExitStatus match { + case VMEM_EXCEEDED_EXIT_CODE | PMEM_EXCEEDED_EXIT_CODE => + memLimitExceededLogMessage(pod.getStatus.getReason) + case _ => + // Here we can't be sure that that exit was caused by the application but this seems + // to be the right default since we know the pod was not explicitly deleted by + // the user. + s"Pod ${pod.getMetadata.getName}'s executor container exited with exit status" + + s" code $containerExitStatus." + } + ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) + } + podsWithKnownExitReasons.put(pod.getMetadata.getName, exitReason) + } + + def handleDeletedPod(pod: Pod): Unit = { + val exitMessage = if (isPodAlreadyReleased(pod)) { + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." + } else { + s"Pod ${pod.getMetadata.getName} deleted or lost." + } + val exitReason = ExecutorExited( + getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) + podsWithKnownExitReasons.put(pod.getMetadata.getName, exitReason) + } + } +} + +private object KubernetesClusterSchedulerBackend { + private val VMEM_EXCEEDED_EXIT_CODE = -103 + private val PMEM_EXCEEDED_EXIT_CODE = -104 + private val UNKNOWN_EXIT_CODE = -111 + // Number of times we are allowed check for the loss reason for an executor before we give up + // and assume the executor failed for good, and attribute it to a framework fault. + val MAX_EXECUTOR_LOST_REASON_CHECKS = 10 + + def memLimitExceededLogMessage(diagnostics: String): String = { + s"Pod/Container killed for exceeding memory limits. $diagnostics" + + " Consider boosting spark executor memory overhead." + } +} + diff --git a/resource-managers/kubernetes/core/src/test/resources/log4j.properties b/resource-managers/kubernetes/core/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..ad95fadb7c0c0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactorySuite.scala new file mode 100644 index 0000000000000..515ec413f4e31 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactorySuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.KubernetesClient +import org.mockito.MockitoAnnotations +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants +import org.apache.spark.network.netty.SparkTransportConf + +class ExecutorPodFactoryImplSuite extends SparkFunSuite with BeforeAndAfter { + private val driverPodName: String = "driver-pod" + private val driverPodUid: String = "driver-uid" + private val driverUrl: String = "driver-url" + private val executorPrefix: String = "base" + private val executorImage: String = "executor-image" + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .withUid(driverPodUid) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + private var sc: SparkContext = _ + + before { + SparkContext.clearActiveContext() + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) + .set(EXECUTOR_DOCKER_IMAGE, executorImage) + sc = new SparkContext("local", "test") + } + private var kubernetesClient: KubernetesClient = _ + + test("basic executor pod has reasonable defaults") { + val factory = new ExecutorPodFactoryImpl(baseConf) + val executor = factory.createExecutorPod("1", "dummy", "dummy", + Seq[(String, String)](), driverPod, Map[String, Int]()) + + // The executor pod name and default labels. + assert(executor.getMetadata.getName == s"$executorPrefix-exec-1") + assert(executor.getMetadata.getLabels.size() == 3) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.getSpec.getContainers.size() == 1) + assert(executor.getSpec.getContainers.get(0).getImage == executorImage) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() == 0) + assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() == 1) + assert(executor.getSpec.getContainers.get(0).getResources. + getLimits.get("memory").getAmount == "1408Mi") + + // The pod has no node selector, volumes. + assert(executor.getSpec.getNodeSelector.size() == 0) + assert(executor.getSpec.getVolumes.size() == 0) + + checkEnv(executor, Set()) + checkOwnerReferences(executor, driverPodUid) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, + "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod("1", + "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getHostname.length == 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod("1", + "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) + + checkEnv(executor, Set("SPARK_JAVA_OPT_0", "SPARK_EXECUTOR_EXTRA_CLASSPATH", "qux")) + checkOwnerReferences(executor, driverPodUid) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() == 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid == driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController == true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executor: Pod, additionalEnvVars: Set[String]): Unit = { + val defaultEnvs = Set(constants.ENV_EXECUTOR_ID, + constants.ENV_DRIVER_URL, constants.ENV_EXECUTOR_CORES, + constants.ENV_EXECUTOR_MEMORY, constants.ENV_APPLICATION_ID, + constants.ENV_MOUNTED_CLASSPATH, constants.ENV_EXECUTOR_POD_IP, + constants.ENV_EXECUTOR_PORT) ++ additionalEnvVars + + assert(executor.getSpec.getContainers.size() == 1) + assert(executor.getSpec.getContainers.get(0).getEnv().size() == defaultEnvs.size) + val setEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { + x => x.getName + }.toSet + assert(defaultEnvs == setEnvs) + } +} \ No newline at end of file diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..1089ea7ecb8e5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{any, eq => mockitoEq} +import org.mockito.Mockito.{doNothing, never, times, verify, when} +import org.scalatest.BeforeAndAfter +import org.scalatest.mock.MockitoSugar._ +import scala.collection.JavaConverters._ +import scala.concurrent.Future + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpoint, RpcEndpointAddress, RpcEndpointRef, RpcEnv, RpcTimeout} +import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +private[spark] class KubernetesClusterSchedulerBackendSuite + extends SparkFunSuite with BeforeAndAfter { + + private val APP_ID = "test-spark-app" + private val DRIVER_POD_NAME = "spark-driver-pod" + private val NAMESPACE = "test-namespace" + private val SPARK_DRIVER_HOST = "localhost" + private val SPARK_DRIVER_PORT = 7077 + private val POD_ALLOCATION_INTERVAL = 60L + private val DRIVER_URL = RpcEndpointAddress( + SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val FIRST_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod1") + .endMetadata() + .withNewSpec() + .withNodeName("node1") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private val SECOND_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod2") + .endMetadata() + .withNewSpec() + .withNodeName("node2") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.101") + .endStatus() + .build() + + private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + private type LABELLED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + private type IN_NAMESPACE_PODS = NonNamespaceOperation[ + Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + + @Mock + private var sparkContext: SparkContext = _ + + @Mock + private var listenerBus: LiveListenerBus = _ + + @Mock + private var taskSchedulerImpl: TaskSchedulerImpl = _ + + @Mock + private var allocatorExecutor: ScheduledExecutorService = _ + + @Mock + private var requestExecutorsService: ExecutorService = _ + + @Mock + private var executorPodFactory: ExecutorPodFactory = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var podsWithLabelOperations: LABELLED_PODS = _ + + @Mock + private var podsInNamespace: IN_NAMESPACE_PODS = _ + + @Mock + private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + + @Mock + private var rpcEnv: RpcEnv = _ + + @Mock + private var driverEndpointRef: RpcEndpointRef = _ + + @Mock + private var executorPodsWatch: Watch = _ + + private var sparkConf: SparkConf = _ + private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ + private var allocatorRunnable: ArgumentCaptor[Runnable] = _ + private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ + private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .addToLabels(SPARK_APP_ID_LABEL, APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .endMetadata() + .build() + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf() + .set("spark.app.id", APP_ID) + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_NAMESPACE, NAMESPACE) + .set("spark.driver.host", SPARK_DRIVER_HOST) + .set("spark.driver.port", SPARK_DRIVER_PORT.toString) + .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL) + executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) + when(sparkContext.conf).thenReturn(sparkConf) + when(sparkContext.listenerBus).thenReturn(listenerBus) + when(taskSchedulerImpl.sc).thenReturn(sparkContext) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) + when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) + .thenReturn(executorPodsWatch) + when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) + when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) + when(podsWithDriverName.get()).thenReturn(driverPod) + when(allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable.capture(), + mockitoEq(0L), + mockitoEq(POD_ALLOCATION_INTERVAL), + mockitoEq(TimeUnit.SECONDS))).thenReturn(null) + // Creating Futures in Scala backed by a Java executor service resolves to running + // ExecutorService#execute (as opposed to submit) + doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) + when(rpcEnv.setupEndpoint( + mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) + .thenReturn(driverEndpointRef) + when(driverEndpointRef.ask[Boolean] + (any(classOf[Any])) + (any())).thenReturn(mock[Future[Boolean]]) + } + + test("Basic lifecycle expectations when starting and stopping the scheduler.") { + val scheduler = newSchedulerBackend(true) + scheduler.start() + assert(executorPodsWatcherArgument.getValue != null) + assert(allocatorRunnable.getValue != null) + scheduler.stop() + verify(executorPodsWatch).close() + } + + test("Static allocation should request executors upon first allocator run.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend(true) + scheduler.start() + requestExecutorRunnable.getValue.run() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + verify(podOperations).create(FIRST_EXECUTOR_POD) + verify(podOperations).create(SECOND_EXECUTOR_POD) + } + + test("Killing executors deletes the executor pods") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend(true) + scheduler.start() + requestExecutorRunnable.getValue.run() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + scheduler.doKillExecutors(Seq("2")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations).delete(SECOND_EXECUTOR_POD) + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + + test("Executors should be requested in batches.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend(true) + scheduler.start() + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(FIRST_EXECUTOR_POD) + verify(podOperations, never()).create(SECOND_EXECUTOR_POD) + val registerFirstExecutorMessage = RegisterExecutor( + "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + allocatorRunnable.getValue.run() + verify(podOperations).create(SECOND_EXECUTOR_POD) + } + + test("Deleting executors and then running an allocator pass after finding the loss reason" + + " should only delete the pod once.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend(true) + scheduler.start() + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + scheduler.doRequestTotalExecutors(0) + requestExecutorRunnable.getAllValues.asScala.last.run() + scheduler.doKillExecutors(Seq("1")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations, times(1)).delete(FIRST_EXECUTOR_POD) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + + val exitedPod = exitPod(FIRST_EXECUTOR_POD, 0) + executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).delete(FIRST_EXECUTOR_POD) + verify(driverEndpointRef, times(1)).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 0, + exitCausedByApp = false, + s"Container in pod ${exitedPod.getMetadata.getName} exited from" + + s" explicit termination request."))) + } + + test("Executors that disconnect from application errors are noted as exits caused by app.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend(true) + scheduler.start() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + executorPodsWatcherArgument.getValue.eventReceived( + Action.ERROR, exitPod(FIRST_EXECUTOR_POD, 1)) + + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + scheduler.doRequestTotalExecutors(1) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getAllValues.asScala.last.run() + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 1, + exitCausedByApp = true, + s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + + " exit status code 1."))) + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + + test("Executors should only try to get the loss reason a number of times before giving up and" + + " removing the executor.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend(true) + scheduler.start() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + 1 to KubernetesClusterSchedulerBackend.MAX_EXECUTOR_LOST_REASON_CHECKS foreach { _ => + allocatorRunnable.getValue.run() + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).delete(FIRST_EXECUTOR_POD) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) + } + + private def newSchedulerBackend(externalShuffle: Boolean): KubernetesClusterSchedulerBackend = { + new KubernetesClusterSchedulerBackend( + taskSchedulerImpl, + rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) + } + + private def exitPod(basePod: Pod, exitCode: Int): Pod = { + new PodBuilder(FIRST_EXECUTOR_POD) + .editStatus() + .addNewContainerStatus() + .withNewState() + .withNewTerminated() + .withExitCode(exitCode) + .endTerminated() + .endState() + .endContainerStatus() + .endStatus() + .build() + } + + private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Unit = { + when(executorPodFactory.createExecutorPod( + executorId.toString, + APP_ID, + DRIVER_URL, + sparkConf.getExecutorEnv, + driverPod, + Map.empty)).thenReturn(expectedPod) + } + +} From 5a5ce256e6a458eb190af8c9435b455d15360c4e Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 18 Sep 2017 14:46:45 -0700 Subject: [PATCH 32/37] Fix unit tests --- .../KubernetesClusterSchedulerBackend.scala | 22 ++++++++++++++++ ...bernetesClusterSchedulerBackendSuite.scala | 25 ++++++++++--------- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala index 42a3e3fd50492..351b0a880b810 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala @@ -427,6 +427,28 @@ private[spark] class KubernetesClusterSchedulerBackend( podsWithKnownExitReasons.put(pod.getMetadata.getName, exitReason) } } + + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new KubernetesDriverEndpoint(rpcEnv, properties) + } + + private class KubernetesDriverEndpoint( + rpcEnv: RpcEnv, + sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + if (disableExecutor(executorId)) { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.get(executorId).foreach { pod => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + } + } + } + } + } + } } private object KubernetesClusterSchedulerBackend { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala index 1089ea7ecb8e5..25e590ee35f0c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala @@ -18,22 +18,23 @@ package org.apache.spark.scheduler.cluster.kubernetes import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} +import scala.collection.JavaConverters._ +import scala.concurrent.Future + import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import org.mockito.Mockito.{mock => _, _} import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.kubernetes.config._ import org.apache.spark.deploy.kubernetes.constants._ -import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpoint, RpcEndpointAddress, RpcEndpointRef, RpcEnv, RpcTimeout} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -174,7 +175,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite } test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend(true) + val scheduler = newSchedulerBackend() scheduler.start() assert(executorPodsWatcherArgument.getValue != null) assert(allocatorRunnable.getValue != null) @@ -186,7 +187,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite sparkConf .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend(true) + val scheduler = newSchedulerBackend() scheduler.start() requestExecutorRunnable.getValue.run() expectPodCreationWithId(1, FIRST_EXECUTOR_POD) @@ -201,7 +202,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite sparkConf .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend(true) + val scheduler = newSchedulerBackend() scheduler.start() requestExecutorRunnable.getValue.run() expectPodCreationWithId(1, FIRST_EXECUTOR_POD) @@ -219,7 +220,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite sparkConf .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend(true) + val scheduler = newSchedulerBackend() scheduler.start() requestExecutorRunnable.getValue.run() when(podOperations.create(any(classOf[Pod]))) @@ -243,7 +244,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite sparkConf .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend(true) + val scheduler = newSchedulerBackend() scheduler.start() requestExecutorRunnable.getValue.run() when(podOperations.create(any(classOf[Pod]))) @@ -280,7 +281,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite sparkConf .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend(true) + val scheduler = newSchedulerBackend() scheduler.start() expectPodCreationWithId(1, FIRST_EXECUTOR_POD) when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) @@ -315,7 +316,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite sparkConf .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend(true) + val scheduler = newSchedulerBackend() scheduler.start() expectPodCreationWithId(1, FIRST_EXECUTOR_POD) when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) @@ -340,7 +341,7 @@ private[spark] class KubernetesClusterSchedulerBackendSuite RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - private def newSchedulerBackend(externalShuffle: Boolean): KubernetesClusterSchedulerBackend = { + private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { new KubernetesClusterSchedulerBackend( taskSchedulerImpl, rpcEnv, From c423539afaf96f2c6768646b00bdbb4c9c5ed3d7 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 18 Sep 2017 15:24:46 -0700 Subject: [PATCH 33/37] Cleaned up extraneous constants --- .../spark/deploy/kubernetes/config.scala | 356 ------------------ .../spark/deploy/kubernetes/constants.scala | 28 -- .../KubernetesClusterSchedulerBackend.scala | 15 +- 3 files changed, 7 insertions(+), 392 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala index 9dfd13e1817f8..53f3d5e60c658 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala @@ -59,24 +59,12 @@ package object config extends Logging { "spark.kubernetes.authenticate.driver" private[spark] val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX = "spark.kubernetes.authenticate.driver.mounted" - private[spark] val APISERVER_AUTH_RESOURCE_STAGING_SERVER_CONF_PREFIX = - "spark.kubernetes.authenticate.resourceStagingServer" - private[spark] val APISERVER_AUTH_SHUFFLE_SERVICE_CONF_PREFIX = - "spark.kubernetes.authenticate.shuffleService" private[spark] val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" private[spark] val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" private[spark] val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" private[spark] val CLIENT_CERT_FILE_CONF_SUFFIX = "clientCertFile" private[spark] val CA_CERT_FILE_CONF_SUFFIX = "caCertFile" - private[spark] val RESOURCE_STAGING_SERVER_USE_SERVICE_ACCOUNT_CREDENTIALS = - ConfigBuilder( - s"$APISERVER_AUTH_RESOURCE_STAGING_SERVER_CONF_PREFIX.useServiceAccountCredentials") - .doc("Use a service account token and CA certificate in the resource staging server to" + - " watch the API server's objects.") - .booleanConf - .createWithDefault(true) - private[spark] val KUBERNETES_SERVICE_ACCOUNT_NAME = ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") .doc("Service account that is used when running the driver pod. The driver pod uses" + @@ -86,13 +74,6 @@ package object config extends Logging { .stringConf .createOptional - private[spark] val SPARK_SHUFFLE_SERVICE_HOST = - ConfigBuilder("spark.shuffle.service.host") - .doc("Host for Spark Shuffle Service") - .internal() - .stringConf - .createOptional - // Note that while we set a default for this when we start up the // scheduler, the specific default value is dynamically determined // based on the executor memory. @@ -168,45 +149,6 @@ package object config extends Logging { .stringConf .createWithDefault("spark") - private[spark] val KUBERNETES_SHUFFLE_NAMESPACE = - ConfigBuilder("spark.kubernetes.shuffle.namespace") - .doc("Namespace of the shuffle service") - .stringConf - .createWithDefault("default") - - private[spark] val KUBERNETES_SHUFFLE_SVC_IP = - ConfigBuilder("spark.kubernetes.shuffle.ip") - .doc("This setting is for debugging only. Setting this " + - "allows overriding the IP that the executor thinks its colocated " + - "shuffle service is on") - .stringConf - .createOptional - - private[spark] val KUBERNETES_SHUFFLE_LABELS = - ConfigBuilder("spark.kubernetes.shuffle.labels") - .doc("Labels to identify the shuffle service") - .stringConf - .createOptional - - private[spark] val KUBERNETES_SHUFFLE_DIR = - ConfigBuilder("spark.kubernetes.shuffle.dir") - .doc("Path to the shared shuffle directories.") - .stringConf - .createOptional - - private[spark] val KUBERNETES_SHUFFLE_APISERVER_URI = - ConfigBuilder("spark.kubernetes.shuffle.apiServer.url") - .doc("URL to the Kubernetes API server that the shuffle service will monitor for Spark pods.") - .stringConf - .createWithDefault(KUBERNETES_MASTER_INTERNAL_URL) - - private[spark] val KUBERNETES_SHUFFLE_USE_SERVICE_ACCOUNT_CREDENTIALS = - ConfigBuilder(s"$APISERVER_AUTH_SHUFFLE_SERVICE_CONF_PREFIX.useServiceAccountCredentials") - .doc("Whether or not to use service account credentials when contacting the API server from" + - " the shuffle service.") - .booleanConf - .createWithDefault(true) - private[spark] val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") .doc("Number of pods to launch at once in each round of dynamic allocation. ") @@ -219,216 +161,6 @@ package object config extends Logging { .longConf .createWithDefault(1) - private[spark] val WAIT_FOR_APP_COMPLETION = - ConfigBuilder("spark.kubernetes.submission.waitAppCompletion") - .doc("In cluster mode, whether to wait for the application to finish before exiting the" + - " launcher process.") - .booleanConf - .createWithDefault(true) - - private[spark] val REPORT_INTERVAL = - ConfigBuilder("spark.kubernetes.report.interval") - .doc("Interval between reports of the current app status in cluster mode.") - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefaultString("1s") - - // Spark resource staging server. - private[spark] val RESOURCE_STAGING_SERVER_API_SERVER_URL = - ConfigBuilder("spark.kubernetes.resourceStagingServer.apiServer.url") - .doc("URL for the Kubernetes API server. The resource staging server monitors the API" + - " server to check when pods no longer are using mounted resources. Note that this isn't" + - " to be used in Spark applications, as the API server URL should be set via spark.master.") - .stringConf - .createWithDefault(KUBERNETES_MASTER_INTERNAL_URL) - - private[spark] val RESOURCE_STAGING_SERVER_API_SERVER_CA_CERT_FILE = - ConfigBuilder("spark.kubernetes.resourceStagingServer.apiServer.caCertFile") - .doc("CA certificate for the resource staging server to use when contacting the Kubernetes" + - " API server over TLS.") - .stringConf - .createOptional - - private[spark] val RESOURCE_STAGING_SERVER_PORT = - ConfigBuilder("spark.kubernetes.resourceStagingServer.port") - .doc("Port for the Kubernetes resource staging server to listen on.") - .intConf - .createWithDefault(10000) - - private[spark] val RESOURCE_STAGING_SERVER_INITIAL_ACCESS_EXPIRATION_TIMEOUT = - ConfigBuilder("spark.kubernetes.resourceStagingServer.initialAccessExpirationTimeout") - .doc("The resource staging server will wait for any resource bundle to be accessed for a" + - " first time for this period. If this timeout expires before the resources are accessed" + - " the first time, the resources are cleaned up under the assumption that the dependents" + - " of the given resource bundle failed to launch at all.") - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefaultString("30m") - - private[spark] val RESOURCE_STAGING_SERVER_KEY_PEM = - ConfigBuilder("spark.ssl.kubernetes.resourceStagingServer.keyPem") - .doc("Key PEM file to use when having the Kubernetes dependency server listen on TLS.") - .stringConf - .createOptional - - private[spark] val RESOURCE_STAGING_SERVER_SSL_NAMESPACE = "kubernetes.resourceStagingServer" - private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE = - "kubernetes.resourceStagingServer.internal" - private[spark] val RESOURCE_STAGING_SERVER_CERT_PEM = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.serverCertPem") - .doc("Certificate PEM file to use when having the resource staging server" + - " listen on TLS.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_CLIENT_CERT_PEM = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.clientCertPem") - .doc("Certificate PEM file to use when the client contacts the resource staging server." + - " This must strictly be a path to a file on the submitting machine's disk.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_CLIENT_CERT_PEM = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.clientCertPem") - .doc("Certificate PEM file to use when the init-container contacts the resource staging" + - " server. If this is not provided, it defaults to the value of" + - " spark.ssl.kubernetes.resourceStagingServer.clientCertPem. This can be a URI with" + - " a scheme of local:// which denotes that the file is pre-mounted on the init-container's" + - " disk. A uri without a scheme or a scheme of file:// will result in this file being" + - " mounted from the submitting machine's disk as a secret into the pods.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_KEYSTORE_PASSWORD_FILE = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.keyStorePasswordFile") - .doc("File containing the keystore password for the Kubernetes resource staging server.") - .stringConf - .createOptional - - private[spark] val RESOURCE_STAGING_SERVER_KEYSTORE_KEY_PASSWORD_FILE = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.keyPasswordFile") - .doc("File containing the key password for the Kubernetes resource staging server.") - .stringConf - .createOptional - - private[spark] val RESOURCE_STAGING_SERVER_SSL_ENABLED = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.enabled") - .doc("Whether or not to use SSL when communicating with the resource staging server.") - .booleanConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_SSL_ENABLED = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.enabled") - .doc("Whether or not to use SSL when communicating with the resource staging server from" + - " the init-container. If this is not provided, defaults to" + - " the value of spark.ssl.kubernetes.resourceStagingServer.enabled") - .booleanConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_TRUSTSTORE_FILE = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.trustStore") - .doc("File containing the trustStore to communicate with the Kubernetes dependency server." + - " This must strictly be a path on the submitting machine's disk.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_TRUSTSTORE_FILE = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.trustStore") - .doc("File containing the trustStore to communicate with the Kubernetes dependency server" + - " from the init-container. If this is not provided, defaults to the value of" + - " spark.ssl.kubernetes.resourceStagingServer.trustStore. This can be a URI with a scheme" + - " of local:// indicating that the trustStore is pre-mounted on the init-container's" + - " disk. If no scheme, or a scheme of file:// is provided, this file is mounted from the" + - " submitting machine's disk as a Kubernetes secret into the pods.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_TRUSTSTORE_PASSWORD = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.trustStorePassword") - .doc("Password for the trustStore for communicating to the dependency server.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_TRUSTSTORE_PASSWORD = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.trustStorePassword") - .doc("Password for the trustStore for communicating to the dependency server from the" + - " init-container. If this is not provided, defaults to" + - " spark.ssl.kubernetes.resourceStagingServer.trustStorePassword.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_TRUSTSTORE_TYPE = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_SSL_NAMESPACE.trustStoreType") - .doc("Type of trustStore for communicating with the dependency server.") - .stringConf - .createOptional - private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_TRUSTSTORE_TYPE = - ConfigBuilder(s"spark.ssl.$RESOURCE_STAGING_SERVER_INTERNAL_SSL_NAMESPACE.trustStoreType") - .doc("Type of trustStore for communicating with the dependency server from the" + - " init-container. If this is not provided, defaults to" + - " spark.ssl.kubernetes.resourceStagingServer.trustStoreType") - .stringConf - .createOptional - - // Driver and Init-Container parameters - private[spark] val RESOURCE_STAGING_SERVER_URI = - ConfigBuilder("spark.kubernetes.resourceStagingServer.uri") - .doc("Base URI for the Spark resource staging server.") - .stringConf - .createOptional - - private[spark] val RESOURCE_STAGING_SERVER_INTERNAL_URI = - ConfigBuilder("spark.kubernetes.resourceStagingServer.internal.uri") - .doc("Base URI for the Spark resource staging server when the init-containers access it for" + - " downloading resources. If this is not provided, it defaults to the value provided in" + - " spark.kubernetes.resourceStagingServer.uri, the URI that the submission client uses to" + - " upload the resources from outside the cluster.") - .stringConf - .createOptional - - private[spark] val INIT_CONTAINER_DOWNLOAD_JARS_RESOURCE_IDENTIFIER = - ConfigBuilder("spark.kubernetes.initcontainer.downloadJarsResourceIdentifier") - .doc("Identifier for the jars tarball that was uploaded to the staging service.") - .internal() - .stringConf - .createOptional - - private[spark] val INIT_CONTAINER_DOWNLOAD_JARS_SECRET_LOCATION = - ConfigBuilder("spark.kubernetes.initcontainer.downloadJarsSecretLocation") - .doc("Location of the application secret to use when the init-container contacts the" + - " resource staging server to download jars.") - .internal() - .stringConf - .createWithDefault(s"$INIT_CONTAINER_SECRET_VOLUME_MOUNT_PATH/" + - s"$INIT_CONTAINER_SUBMITTED_JARS_SECRET_KEY") - - private[spark] val INIT_CONTAINER_DOWNLOAD_FILES_RESOURCE_IDENTIFIER = - ConfigBuilder("spark.kubernetes.initcontainer.downloadFilesResourceIdentifier") - .doc("Identifier for the files tarball that was uploaded to the staging service.") - .internal() - .stringConf - .createOptional - - private[spark] val INIT_CONTAINER_DOWNLOAD_FILES_SECRET_LOCATION = - ConfigBuilder("spark.kubernetes.initcontainer.downloadFilesSecretLocation") - .doc("Location of the application secret to use when the init-container contacts the" + - " resource staging server to download files.") - .internal() - .stringConf - .createWithDefault( - s"$INIT_CONTAINER_SECRET_VOLUME_MOUNT_PATH/$INIT_CONTAINER_SUBMITTED_FILES_SECRET_KEY") - - private[spark] val INIT_CONTAINER_REMOTE_JARS = - ConfigBuilder("spark.kubernetes.initcontainer.remoteJars") - .doc("Comma-separated list of jar URIs to download in the init-container. This is" + - " calculated from spark.jars.") - .internal() - .stringConf - .createOptional - - private[spark] val INIT_CONTAINER_REMOTE_FILES = - ConfigBuilder("spark.kubernetes.initcontainer.remoteFiles") - .doc("Comma-separated list of file URIs to download in the init-container. This is" + - " calculated from spark.files.") - .internal() - .stringConf - .createOptional - - private[spark] val INIT_CONTAINER_DOCKER_IMAGE = - ConfigBuilder("spark.kubernetes.initcontainer.docker.image") - .doc("Image for the driver and executor's init-container that downloads dependencies.") - .stringConf - .createWithDefault(s"spark-init:$sparkVersion") - private[spark] val INIT_CONTAINER_JARS_DOWNLOAD_LOCATION = ConfigBuilder("spark.kubernetes.mountdependencies.jarsDownloadDir") .doc("Location to download jars to in the driver and executors. When using" + @@ -437,94 +169,6 @@ package object config extends Logging { .stringConf .createWithDefault("/var/spark-data/spark-jars") - private[spark] val INIT_CONTAINER_FILES_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountdependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using" + - " spark-submit, this directory must be empty and will be mounted as an empty directory" + - " volume on the driver and executor pods.") - .stringConf - .createWithDefault("/var/spark-data/spark-files") - - private[spark] val INIT_CONTAINER_MOUNT_TIMEOUT = - ConfigBuilder("spark.kubernetes.mountdependencies.mountTimeout") - .doc("Timeout before aborting the attempt to download and unpack local dependencies from" + - " remote locations and the resource staging server when initializing the driver and" + - " executor pods.") - .timeConf(TimeUnit.MINUTES) - .createWithDefault(5) - - private[spark] val EXECUTOR_SUBMITTED_SMALL_FILES_SECRET = - ConfigBuilder("spark.kubernetes.mountdependencies.smallfiles.executor.secretName") - .doc("Name of the secret that should be mounted into the executor containers for" + - " distributing submitted small files without the resource staging server.") - .internal() - .stringConf - .createOptional - - private[spark] val EXECUTOR_SUBMITTED_SMALL_FILES_SECRET_MOUNT_PATH = - ConfigBuilder("spark.kubernetes.mountdependencies.smallfiles.executor.secretMountPath") - .doc(s"Mount path in the executors for the secret given by" + - s" ${EXECUTOR_SUBMITTED_SMALL_FILES_SECRET.key}") - .internal() - .stringConf - .createOptional - - private[spark] val EXECUTOR_INIT_CONTAINER_CONFIG_MAP = - ConfigBuilder("spark.kubernetes.initcontainer.executor.configmapname") - .doc("Name of the config map to use in the init-container that retrieves submitted files" + - " for the executor.") - .internal() - .stringConf - .createOptional - - private[spark] val EXECUTOR_INIT_CONTAINER_CONFIG_MAP_KEY = - ConfigBuilder("spark.kubernetes.initcontainer.executor.configmapkey") - .doc("Key for the entry in the init container config map for submitted files that" + - " corresponds to the properties for this init-container.") - .internal() - .stringConf - .createOptional - - private[spark] val EXECUTOR_INIT_CONTAINER_SECRET = - ConfigBuilder("spark.kubernetes.initcontainer.executor.stagingServerSecret.name") - .doc("Name of the secret to mount into the init-container that retrieves submitted files.") - .internal() - .stringConf - .createOptional - - private[spark] val EXECUTOR_INIT_CONTAINER_SECRET_MOUNT_DIR = - ConfigBuilder("spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir") - .doc("Directory to mount the resource staging server secrets into for the executor" + - " init-containers. This must be exactly the same as the directory that the submission" + - " client mounted the secret into because the config map's properties specify the" + - " secret location as to be the same between the driver init-container and the executor" + - " init-container. Thus the submission client will always set this and the driver will" + - " never rely on a constant or convention, in order to protect against cases where the" + - " submission client has a different version from the driver itself, and hence might" + - " have different constants loaded in constants.scala.") - .internal() - .stringConf - .createOptional - - private[spark] val KUBERNETES_DRIVER_LIMIT_CORES = - ConfigBuilder("spark.kubernetes.driver.limit.cores") - .doc("Specify the hard cpu limit for the driver pod") - .stringConf - .createOptional - - private[spark] val KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED = - ConfigBuilder("spark.kubernetes.driver.hdfslocality.clusterNodeNameDNSLookup.enabled") - .doc("Whether or not HDFS locality support code should look up DNS for full hostnames of" + - " cluster nodes. In some K8s clusters, notably GKE, cluster node names are short" + - " hostnames, and so comparing them against HDFS datanode hostnames always fail. To fix," + - " enable this flag. This is disabled by default because DNS lookup can be expensive." + - " The driver can slow down and fail to respond to executor heartbeats in time." + - " If enabling this flag, make sure your DNS server has enough capacity" + - " for the workload.") - .internal() - .booleanConf - .createWithDefault(false) - private[spark] val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") .doc("Specify the hard cpu limit for a single executor pod") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala index 0a2bc46249f3a..07bf56e8bf0dc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala @@ -18,13 +18,11 @@ package org.apache.spark.deploy.kubernetes package object constants { // Labels - private[spark] val SPARK_DRIVER_LABEL = "spark-driver" private[spark] val SPARK_APP_ID_LABEL = "spark-app-selector" private[spark] val SPARK_EXECUTOR_ID_LABEL = "spark-exec-id" private[spark] val SPARK_ROLE_LABEL = "spark-role" private[spark] val SPARK_POD_DRIVER_ROLE = "driver" private[spark] val SPARK_POD_EXECUTOR_ROLE = "executor" - private[spark] val SPARK_APP_NAME_ANNOTATION = "spark-app-name" // Credentials secrets private[spark] val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = @@ -70,34 +68,8 @@ package object constants { private[spark] val ENV_PYSPARK_FILES = "PYSPARK_FILES" private[spark] val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" private[spark] val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" - private[spark] val ENV_MOUNTED_FILES_FROM_SECRET_DIR = "SPARK_MOUNTED_FILES_FROM_SECRET_DIR" - - // Bootstrapping dependencies with the init-container - private[spark] val INIT_CONTAINER_ANNOTATION = "pod.beta.kubernetes.io/init-containers" - private[spark] val INIT_CONTAINER_SECRET_VOLUME_MOUNT_PATH = - "/mnt/secrets/spark-init" - private[spark] val INIT_CONTAINER_SUBMITTED_JARS_SECRET_KEY = - "downloadSubmittedJarsSecret" - private[spark] val INIT_CONTAINER_SUBMITTED_FILES_SECRET_KEY = - "downloadSubmittedFilesSecret" - private[spark] val INIT_CONTAINER_STAGING_SERVER_TRUSTSTORE_SECRET_KEY = "trustStore" - private[spark] val INIT_CONTAINER_STAGING_SERVER_CLIENT_CERT_SECRET_KEY = "ssl-certificate" - private[spark] val INIT_CONTAINER_CONFIG_MAP_KEY = "download-submitted-files" - private[spark] val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" - private[spark] val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files" - private[spark] val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" - private[spark] val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" - private[spark] val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" - private[spark] val INIT_CONTAINER_PROPERTIES_FILE_PATH = - s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" - private[spark] val DEFAULT_SHUFFLE_MOUNT_NAME = "shuffle" - private[spark] val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" - - // Bootstrapping dependencies via a secret - private[spark] val MOUNTED_SMALL_FILES_SECRET_MOUNT_PATH = "/etc/spark-submitted-files" // Miscellaneous - private[spark] val ANNOTATION_EXECUTOR_NODE_AFFINITY = "scheduler.alpha.kubernetes.io/affinity" private[spark] val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" private[spark] val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" private[spark] val MEMORY_OVERHEAD_FACTOR = 0.10 diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala index 351b0a880b810..35589361d046a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala @@ -18,23 +18,22 @@ package org.apache.spark.scheduler.cluster.kubernetes import java.io.Closeable import java.net.InetAddress -import java.util.Collections -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action import scala.collection.{concurrent, mutable} import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.{SparkEnv, SparkException} +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.SparkException import org.apache.spark.deploy.kubernetes.config._ import org.apache.spark.deploy.kubernetes.constants._ -import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpointAddress, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RetrieveSparkAppConfig, SparkAppConfig} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils From 12f9c218507ee63697811dd0602d45630e2c9940 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 18 Sep 2017 15:48:05 -0700 Subject: [PATCH 34/37] Cleaned POMs --- pom.xml | 7 +++++ resource-managers/kubernetes/core/pom.xml | 38 ----------------------- 2 files changed, 7 insertions(+), 38 deletions(-) diff --git a/pom.xml b/pom.xml index 0bbbf20a76d68..42ffe2fb405f5 100644 --- a/pom.xml +++ b/pom.xml @@ -2637,6 +2637,13 @@ + + kubernetes + + resource-managers/kubernetes/core + + + hive-thriftserver diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a4b18c527c969..b21f24af5c0c9 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -70,37 +70,7 @@ - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${fasterxml.jackson.version} - - - org.glassfish.jersey.containers - jersey-container-servlet - - - org.glassfish.jersey.media - jersey-media-multipart - - - com.squareup.retrofit2 - retrofit - - - com.squareup.retrofit2 - converter-jackson - - - com.squareup.retrofit2 - converter-scalars - - - com.fasterxml.jackson.jaxrs - jackson-jaxrs-json-provider - javax.ws.rs javax.ws.rs-api @@ -112,14 +82,6 @@ - - org.bouncycastle - bcpkix-jdk15on - - - org.bouncycastle - bcprov-jdk15on - org.mockito mockito-core From bb7b0fb36c4d1da1042d066ec024801ceca37459 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 18 Sep 2017 16:42:56 -0700 Subject: [PATCH 35/37] Fix pom --- resource-managers/kubernetes/core/pom.xml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index b21f24af5c0c9..1637c0f7aa716 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -71,10 +71,13 @@ + - javax.ws.rs - javax.ws.rs-api + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${fasterxml.jackson.version} + com.google.guava From e2e45dc3f0f6ed20a875c147c88aad0d8a65b8eb Mon Sep 17 00:00:00 2001 From: foxish Date: Wed, 20 Sep 2017 01:00:11 -0700 Subject: [PATCH 36/37] Clean up deprecated configuration --- .../kubernetes/ConfigurationUtils.scala | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala index 1a008c236d00f..aafb6f3aabe6d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala @@ -41,31 +41,6 @@ object ConfigurationUtils extends Logging { }).getOrElse(Map.empty[String, String]) } - def combinePrefixedKeyValuePairsWithDeprecatedConf( - sparkConf: SparkConf, - prefix: String, - deprecatedConf: OptionalConfigEntry[String], - configType: String): Map[String, String] = { - val deprecatedKeyValuePairsString = sparkConf.get(deprecatedConf) - deprecatedKeyValuePairsString.foreach { _ => - logWarning(s"Configuration with key ${deprecatedConf.key} is deprecated. Use" + - s" configurations with prefix $prefix instead.") - } - val fromDeprecated = parseKeyValuePairs( - deprecatedKeyValuePairsString, - deprecatedConf.key, - configType) - val fromPrefix = sparkConf.getAllWithPrefix(prefix) - val combined = fromDeprecated.toSeq ++ fromPrefix - combined.groupBy(_._1).foreach { - case (key, values) => - require(values.size == 1, - s"Cannot have multiple values for a given $configType key, got key $key with" + - s" values $values") - } - combined.toMap - } - def parsePrefixedKeyValuePairs( sparkConf: SparkConf, prefix: String, From 9c3d11e6b0cc81c3b9ab5e12e995da70bacc23c4 Mon Sep 17 00:00:00 2001 From: foxish Date: Wed, 20 Sep 2017 01:31:43 -0700 Subject: [PATCH 37/37] clean up imports --- .../main/scala/org/apache/spark/deploy/kubernetes/config.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala index 53f3d5e60c658..2e37749df4233 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala @@ -16,10 +16,7 @@ */ package org.apache.spark.deploy.kubernetes -import java.util.concurrent.TimeUnit - import org.apache.spark.{SPARK_VERSION => sparkVersion} -import org.apache.spark.deploy.kubernetes.constants._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.network.util.ByteUnit