diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..32c609fc --- /dev/null +++ b/.editorconfig @@ -0,0 +1,32 @@ +root = true + +[*] +charset=utf-8 +end_of_line=lf +insert_final_newline=true +indent_style=space +indent_size=4 +max_line_length=120 + +[*.json] +indent_size=2 + +[*.yaml] +indent_size=2 + +[*.ipynb] +insert_final_newline=false + +[*.{kt,kts}] +ij_kotlin_code_style_defaults = KOTLIN_OFFICIAL + +ktlint_code_style = ktlint_official +ktlint_experimental = enabled +ktlint_standard_filename = disabled +ktlint_standard_no-empty-first-line-in-class-body = disabled +ktlint_class_signature_rule_force_multiline_when_parameter_count_greater_or_equal_than = 4 +ktlint_ignore_back_ticked_identifier = true +ktlint_standard_multiline-expression-wrapping = disabled + +[*/build/**/*] +ktlint = disabled \ No newline at end of file diff --git a/.gitignore b/.gitignore index d1f280d2..dfd37e4a 100644 --- a/.gitignore +++ b/.gitignore @@ -377,3 +377,4 @@ orcpath/ **/.allure/ **/allure-results/ /generated_* +/.kotlin/ diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index c5c068b6..a429871f 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -4,7 +4,7 @@ object Versions : Dsl { const val project = "2.0.0-SNAPSHOT" const val kotlinSparkApiGradlePlugin = "2.0.0-SNAPSHOT" const val groupID = "org.jetbrains.kotlinx.spark" - const val kotlin = "2.0.0-RC3" + const val kotlin = "2.0.0" const val jvmTarget = "1.8" val jvmLanguageVersion = JavaLanguageVersion.of(8) const val jupyterJvmTarget = "8" diff --git a/gradle/bootstraps/compiler-plugin.jar b/gradle/bootstraps/compiler-plugin.jar index 29aa66d2..7eccc37e 100644 Binary files a/gradle/bootstraps/compiler-plugin.jar and b/gradle/bootstraps/compiler-plugin.jar differ diff --git a/gradle/bootstraps/gradle-plugin.jar b/gradle/bootstraps/gradle-plugin.jar index 10a5be5d..9a50bd53 100644 Binary files a/gradle/bootstraps/gradle-plugin.jar and b/gradle/bootstraps/gradle-plugin.jar differ diff --git a/jupyter/build.gradle.kts b/jupyter/build.gradle.kts index fb5e90de..c8cc9bad 100644 --- a/jupyter/build.gradle.kts +++ b/jupyter/build.gradle.kts @@ -47,16 +47,22 @@ dependencies { // https://github.com/FasterXML/jackson-bom/issues/52 if (Versions.spark == "3.3.1") implementation(jacksonDatabind) + if (Versions.sparkConnect) { + // IMPORTANT! + compileOnly(sparkSqlApi) + implementation(sparkConnectClient) + } else { + implementation(sparkSql) + } + api( kotlinxHtml, - sparkSql, - sparkRepl, - sparkStreaming, - hadoopClient, ) implementation( kotlinStdLib, + hadoopClient, + reflect ) testImplementation( @@ -65,6 +71,8 @@ dependencies { kotlinScriptingJvm, ) + compileOnly(scalaLibrary) + testCompileOnly(scalaLibrary) } } diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/HtmlRendering.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/HtmlRendering.kt deleted file mode 100644 index ad083962..00000000 --- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/HtmlRendering.kt +++ /dev/null @@ -1,136 +0,0 @@ -/*- - * =LICENSE= - * Kotlin Spark API: API for Spark 3.2+ (Scala 2.12) - * ---------- - * Copyright (C) 2019 - 2022 JetBrains - * ---------- - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * =LICENSEEND= - */ -package org.jetbrains.kotlinx.spark.api.jupyter - -import kotlinx.html.* -import kotlinx.html.stream.appendHTML -import org.apache.spark.SparkException -import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.sql.Dataset -import org.apache.spark.unsafe.array.ByteArrayMethods -import org.jetbrains.kotlinx.spark.api.asKotlinIterable -import org.jetbrains.kotlinx.spark.api.asKotlinIterator -import org.jetbrains.kotlinx.spark.api.asKotlinList -import scala.Product -import java.io.InputStreamReader -import java.io.Serializable - -private fun createHtmlTable(fillTable: TABLE.() -> Unit): String = buildString { - appendHTML().head { - style("text/css") { - unsafe { - val resource = "/table.css" - val res = SparkIntegration::class.java - .getResourceAsStream(resource) ?: error("Resource '$resource' not found") - val readRes = InputStreamReader(res).readText() - raw("\n" + readRes) - } - } - } - - appendHTML().table("dataset", fillTable) -} - - -internal fun JavaRDDLike.toHtml(limit: Int = 20, truncate: Int = 30): String = try { - createHtmlTable { - val numRows = limit.coerceIn(0 until ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) - val tmpRows = take(numRows).toList() - - val hasMoreData = tmpRows.size - 1 > numRows - val rows = tmpRows.take(numRows) - - tr { th { +"Values" } } - - for (row in rows) tr { - td { - val string = when (row) { - is ByteArray -> row.joinToString(prefix = "[", postfix = "]") { "%02X".format(it) } - - is CharArray -> row.iterator().asSequence().toList().toString() - is ShortArray -> row.iterator().asSequence().toList().toString() - is IntArray -> row.iterator().asSequence().toList().toString() - is LongArray -> row.iterator().asSequence().toList().toString() - is FloatArray -> row.iterator().asSequence().toList().toString() - is DoubleArray -> row.iterator().asSequence().toList().toString() - is BooleanArray -> row.iterator().asSequence().toList().toString() - is Array<*> -> row.iterator().asSequence().toList().toString() - is Iterable<*> -> row.iterator().asSequence().toList().toString() - is scala.collection.Iterable<*> -> row.asKotlinIterable().iterator().asSequence().toList().toString() - is Iterator<*> -> row.asSequence().toList().toString() - is scala.collection.Iterator<*> -> row.asKotlinIterator().asSequence().toList().toString() - is Product -> row.productIterator().asKotlinIterator().asSequence().toList().toString() - is Serializable -> row.toString() - // maybe others? - - is Any? -> row.toString() - else -> row.toString() - } - - +if (truncate > 0 && string.length > truncate) { - // do not show ellipses for strings shorter than 4 characters. - if (truncate < 4) string.substring(0, truncate) - else string.substring(0, truncate - 3) + "..." - } else { - string - } - } - } - - if (hasMoreData) tr { - +"only showing top $numRows ${if (numRows == 1) "row" else "rows"}" - } - } -} catch (e: SparkException) { - // Whenever toString() on the contents doesn't work, since the class might be unknown... 
- """${toString()} - |Cannot render this RDD of this class.""".trimMargin() -} - -internal fun Dataset.toHtml(limit: Int = 20, truncate: Int = 30): String = createHtmlTable { - val numRows = limit.coerceIn(0 until ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) - val tmpRows = getRows(numRows, truncate).asKotlinList().map { it.asKotlinList() } - - val hasMoreData = tmpRows.size - 1 > numRows - val rows = tmpRows.take(numRows + 1) - - tr { - for (header in rows.first()) th { - +if (truncate > 0 && header.length > truncate) { - // do not show ellipses for strings shorter than 4 characters. - if (truncate < 4) header.substring(0, truncate) - else header.substring(0, truncate - 3) + "..." - } else { - header - } - - } - } - - for (row in rows.drop(1)) tr { - for (item in row) td { - +item - } - } - - if (hasMoreData) tr { - +"only showing top $numRows ${if (numRows == 1) "row" else "rows"}" - } -} diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt index 30b9b27b..13b22b47 100644 --- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt +++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,11 +19,6 @@ */ package org.jetbrains.kotlinx.spark.api.jupyter -import org.apache.spark.api.java.JavaDoubleRDD -import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.jupyter.api.Code @@ -52,16 +47,12 @@ import kotlin.reflect.KMutableProperty import kotlin.reflect.full.createType import kotlin.reflect.full.findAnnotation import kotlin.reflect.full.isSubtypeOf -import kotlin.reflect.full.memberFunctions import kotlin.reflect.full.memberProperties import kotlin.reflect.full.primaryConstructor -import kotlin.reflect.full.valueParameters import kotlin.reflect.typeOf - -abstract class Integration(private val notebook: Notebook, private val options: MutableMap) : +abstract class Integration(protected val notebook: Notebook, private val options: MutableMap) : JupyterIntegration() { - protected val kotlinVersion = /*$"\""+kotlin+"\""$*/ /*-*/ "" protected val scalaCompatVersion = /*$"\""+scalaCompat+"\""$*/ /*-*/ "" protected val scalaVersion = /*$"\""+scala+"\""$*/ /*-*/ "" @@ -74,14 +65,14 @@ abstract class Integration(private val notebook: Notebook, private val options: .value .getOrThrow() as Properties - - protected open val usingProperties = arrayOf( - displayLimitName, - displayTruncateName, - sparkName, - scalaName, - versionName, - ) + protected open val usingProperties = + arrayOf( + displayLimitName, + displayTruncateName, + sparkName, + scalaName, + versionName, + ) /** * Will be run after importing all dependencies @@ -94,47 +85,51 @@ abstract class Integration(private val notebook: Notebook, private val options: open fun 
KotlinKernelHost.beforeCellExecution() = Unit - open fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) = Unit + open fun KotlinKernelHost.afterCellExecution( + snippetInstance: Any, + result: FieldValue, + ) = Unit open fun Builder.onLoadedAlsoDo() = Unit - open val dependencies: Array = arrayOf( - "org.apache.spark:spark-repl_$scalaCompatVersion:$sparkVersion", - "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion", - "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion", - "org.apache.spark:spark-sql_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-yarn_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-streaming_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-mllib_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-sql_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-graphx_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-launcher_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-catalyst_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-streaming_$scalaCompatVersion:$sparkVersion", - "org.apache.spark:spark-core_$scalaCompatVersion:$sparkVersion", - "org.scala-lang:scala-library:$scalaVersion", - "org.scala-lang.modules:scala-xml_$scalaCompatVersion:2.0.1", - "org.scala-lang:scala-reflect:$scalaVersion", - "org.scala-lang:scala-compiler:$scalaVersion", - "commons-io:commons-io:2.11.0", - ) - - open val imports: Array = arrayOf( - "org.jetbrains.kotlinx.spark.api.plugin.annotations.*", - "org.jetbrains.kotlinx.spark.api.*", - "org.jetbrains.kotlinx.spark.api.tuples.*", - *(1..22).map { "scala.Tuple$it" }.toTypedArray(), - "org.apache.spark.sql.functions.*", - "org.apache.spark.*", - "org.apache.spark.sql.*", - "org.apache.spark.api.java.*", - "scala.collection.Seq", - "org.apache.spark.rdd.*", - "java.io.Serializable", - "org.apache.spark.streaming.api.java.*", - "org.apache.spark.streaming.api.*", - "org.apache.spark.streaming.*", - ) + open val dependencies: Array = + arrayOf( + "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion", + "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion", + "org.apache.spark:spark-sql_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-yarn_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-streaming_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-mllib_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-sql_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-graphx_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-launcher_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-catalyst_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-streaming_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-core_$scalaCompatVersion:$sparkVersion", + "org.scala-lang:scala-library:$scalaVersion", + "org.scala-lang.modules:scala-xml_$scalaCompatVersion:2.0.1", + "org.scala-lang:scala-reflect:$scalaVersion", + "org.scala-lang:scala-compiler:$scalaVersion", + "commons-io:commons-io:2.11.0", + ) + + open val imports: Array = + arrayOf( + "org.jetbrains.kotlinx.spark.api.plugin.annotations.*", + "org.jetbrains.kotlinx.spark.api.*", + "org.jetbrains.kotlinx.spark.api.tuples.*", + *(1..22).map { "scala.Tuple$it" }.toTypedArray(), + "org.apache.spark.sql.functions.*", + "org.apache.spark.*", + "org.apache.spark.sql.*", + "org.apache.spark.api.java.*", + "scala.collection.Seq", + "org.apache.spark.rdd.*", + "java.io.Serializable", + 
"org.apache.spark.streaming.api.java.*", + "org.apache.spark.streaming.api.*", + "org.apache.spark.streaming.*", + ) // Needs to be set by integration var spark: SparkSession? = null @@ -144,18 +139,18 @@ abstract class Integration(private val notebook: Notebook, private val options: import(*imports) onLoaded { - val mutableOptions = options.toMutableMap() declare( VariableDeclaration( name = sparkPropertiesName, - value = object : Properties, MutableMap by mutableOptions { - override fun toString(): String = "Properties: $mutableOptions" - }, + value = + object : Properties, MutableMap by mutableOptions { + override fun toString(): String = "Properties: $mutableOptions" + }, type = typeOf(), isMutable = true, - ) + ), ) onLoaded() @@ -184,11 +179,14 @@ abstract class Integration(private val notebook: Notebook, private val options: } onClassAnnotation { - for (klass in it) { - if (klass.isData) { - execute(generateSparkifyClass(klass)) + val newClassCode = buildString { + for (klass in it) { + if (klass.isData) { + appendLine(generateSparkifyClass(klass)) + } } } + execute(newClassCode) } // Render Dataset @@ -196,45 +194,57 @@ abstract class Integration(private val notebook: Notebook, private val options: renderDataset(it) } + //#if sparkConnect == false // using compile time KType, convert this JavaRDDLike to Dataset and render it notebook.renderersProcessor.registerWithoutOptimizing( - createRendererByCompileTimeType> { + createRendererByCompileTimeType> { if (spark == null) return@createRendererByCompileTimeType it.value.toString() - val rdd = (it.value as JavaRDDLike<*, *>).rdd() - val type = when { - it.type.isSubtypeOf(typeOf()) -> - typeOf() - - it.type.isSubtypeOf(typeOf>()) -> - Tuple2::class.createType( - listOf( - it.type.arguments.first(), - it.type.arguments.last(), + val rdd = (it.value as org.apache.spark.api.java.JavaRDDLike<*, *>).rdd() + val type = + when { + it.type.isSubtypeOf(typeOf()) -> + typeOf() + + it.type.isSubtypeOf(typeOf>()) -> + Tuple2::class.createType( + listOf( + it.type.arguments.first(), + it.type.arguments.last(), + ), ) - ) - it.type.isSubtypeOf(typeOf>()) -> - it.type.arguments.first().type!! + it.type.isSubtypeOf(typeOf>()) -> + it.type.arguments + .first() + .type!! - else -> it.type.arguments.first().type!! - } + else -> + it.type.arguments + .first() + .type!! + } val ds = spark!!.createDataset(rdd, kotlinEncoderFor(type)) renderDataset(ds) - } + }, ) - + //#endif + //#if sparkConnect == false // using compile time KType, convert this RDD to Dataset and render it notebook.renderersProcessor.registerWithoutOptimizing( - createRendererByCompileTimeType> { + createRendererByCompileTimeType> { if (spark == null) return@createRendererByCompileTimeType it.value.toString() - val rdd = it.value as RDD<*> - val type = it.type.arguments.first().type!! + val rdd = it.value as org.apache.spark.rdd.RDD<*> + val type = + it.type.arguments + .first() + .type!! 
val ds = spark!!.createDataset(rdd, kotlinEncoderFor(type)) renderDataset(ds) - } + }, ) + //#endif onLoadedAlsoDo() } @@ -260,7 +270,6 @@ abstract class Integration(private val notebook: Notebook, private val options: textResult("") } - // TODO wip private fun generateSparkifyClass(klass: KClass<*>): Code { // val name = "`${klass.simpleName!!}${'$'}Generated`" @@ -269,26 +278,30 @@ abstract class Integration(private val notebook: Notebook, private val options: val visibility = klass.visibility?.name?.lowercase() ?: "" val memberProperties = klass.memberProperties - val properties = constructorArgs.associateWith { - memberProperties.first { it.name == it.name } - } + val properties = + constructorArgs.associateWith { + memberProperties.first { it.name == it.name } + } - val constructorParamsCode = properties.entries.joinToString("\n") { (param, prop) -> - // TODO check override - if (param.isOptional) TODO() - val modifier = if (prop is KMutableProperty<*>) "var" else "val" - val paramVisiblity = prop.visibility?.name?.lowercase() ?: "" - val columnName = param.findAnnotation()?.name ?: param.name!! + val constructorParamsCode = + properties.entries.joinToString("\n") { (param, prop) -> + // TODO check override + if (param.isOptional) TODO() + val modifier = if (prop is KMutableProperty<*>) "var" else "val" + val paramVisiblity = prop.visibility?.name?.lowercase() ?: "" + val columnName = param.findAnnotation()?.name ?: param.name!! - "| @get:kotlin.jvm.JvmName(\"$columnName\") $paramVisiblity $modifier ${param.name}: ${param.type}," - } + "| @get:kotlin.jvm.JvmName(\"$columnName\") $paramVisiblity $modifier ${param.name}: ${param.type}," + } - val productElementWhenParamsCode = properties.entries.joinToString("\n") { (param, _) -> - "| ${param.index} -> this.${param.name}" - } + val productElementWhenParamsCode = + properties.entries.joinToString("\n") { (param, _) -> + "| ${param.index} -> this.${param.name}" + } @Language("kotlin") - val code = """ + val code = + """ |$visibility data class $name( $constructorParamsCode |): scala.Product, java.io.Serializable { @@ -299,7 +312,7 @@ abstract class Integration(private val notebook: Notebook, private val options: | else -> throw IndexOutOfBoundsException() | } |} - """.trimMargin() + """.trimMargin() return code } } diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/OldSparkIntegration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/OldSparkIntegration.kt new file mode 100644 index 00000000..a7301933 --- /dev/null +++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/OldSparkIntegration.kt @@ -0,0 +1,140 @@ +/*- + * =LICENSE= + * Kotlin Spark API: API for Spark 3.2+ (Scala 2.12) + * ---------- + * Copyright (C) 2019 - 2022 JetBrains + * ---------- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * =LICENSEEND= + */ +@file:Suppress("UsePropertyAccessSyntax") + +package org.jetbrains.kotlinx.spark.api.jupyter + +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost +import org.jetbrains.kotlinx.jupyter.api.Notebook +import org.jetbrains.kotlinx.spark.api.SparkSession +import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.appNameName +import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.sparkMasterName + +/** + * %use spark + */ +@Suppress("UNUSED_VARIABLE", "LocalVariableName") +class OldSparkIntegration( + notebook: Notebook, + options: MutableMap, +) : Integration(notebook, options) { + override fun KotlinKernelHost.onLoaded() { + val _0 = execute("""%dumpClassesForSpark""") + + properties { + putIfAbsent(sparkMasterName, "local[*]") + putIfAbsent(appNameName, "Kotlin Spark API - Jupyter") + putIfAbsent("spark.sql.codegen.wholeStage", "false") + putIfAbsent("fs.hdfs.impl", org.apache.hadoop.hdfs.DistributedFileSystem::class.java.name) + putIfAbsent("fs.file.impl", org.apache.hadoop.fs.LocalFileSystem::class.java.name) + } + + @Language("kts") + val _1 = + listOf( + """ + val spark = org.jetbrains.kotlinx.spark.api.SparkSession + .builder() + .apply { + ${ + buildString { + val sparkProps = properties.filterKeys { it !in usingProperties } + println("received properties: $properties, providing Spark with: $sparkProps") + + sparkProps.forEach { (key, value) -> + appendLine("config(\"${key}\", \"$value\")") + } + } + } + } + .getOrCreate() + """.trimIndent(), + """ + spark.sparkContext.setLogLevel(org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR) + """.trimIndent(), + """ + val sc by lazy { + org.apache.spark.api.java.JavaSparkContext(spark.sparkContext) + } + """.trimIndent(), + """ + println("Spark session (Spark: $sparkVersion, Scala: $scalaCompatVersion, v: $version) has been started and is running. No `withSpark { }` necessary, you can access `spark` and `sc` directly. 
To use Spark streaming, use `%use spark-streaming` instead.") + """.trimIndent(), + """ + inline fun List.toDS(): Dataset = toDS(spark) + """.trimIndent(), + """ + inline fun List.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames) + """.trimIndent(), + """ + inline fun Array.toDS(): Dataset = toDS(spark) + """.trimIndent(), + """ + inline fun Array.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames) + """.trimIndent(), + """ + inline fun dsOf(vararg arg: T): Dataset = spark.dsOf(*arg) + """.trimIndent(), + """ + inline fun dfOf(vararg arg: T): Dataset = spark.dfOf(*arg) + """.trimIndent(), + """ + inline fun emptyDataset(): Dataset = spark.emptyDataset(kotlinEncoderFor()) + """.trimIndent(), + """ + inline fun dfOf(colNames: Array, vararg arg: T): Dataset = spark.dfOf(colNames, *arg) + """.trimIndent(), + """ + inline fun RDD.toDS(): Dataset = toDS(spark) + """.trimIndent(), + """ + inline fun JavaRDDLike.toDS(): Dataset = toDS(spark) + """.trimIndent(), + """ + inline fun RDD.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames) + """.trimIndent(), + """ + inline fun JavaRDDLike.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames) + """.trimIndent(), + """ + fun List.toRDD(numSlices: Int = sc.defaultParallelism()): JavaRDD = sc.toRDD(this, numSlices) + """.trimIndent(), + """ + fun rddOf(vararg elements: T, numSlices: Int = sc.defaultParallelism()): JavaRDD = sc.toRDD(elements.toList(), numSlices) + """.trimIndent(), + """ + val udf: UDFRegistration get() = spark.udf() + """.trimIndent(), + """ + inline fun > NAMED_UDF.register(): NAMED_UDF = spark.udf().register(namedUdf = this) + """.trimIndent(), + """ + inline fun > UserDefinedFunction.register(name: String): NAMED_UDF = spark.udf().register(name = name, udf = this) + """.trimIndent(), + ).map(::execute) + + spark = execute("spark").value as SparkSession + } + + override fun KotlinKernelHost.onShutdown() { + execute("""spark.stop()""") + } +} diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Properties.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Properties.kt index 4c06edc6..7c7507c6 100644 --- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Properties.kt +++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Properties.kt @@ -1,11 +1,5 @@ package org.jetbrains.kotlinx.spark.api.jupyter -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.decodeFromJsonElement -import kotlinx.serialization.json.put - interface Properties : MutableMap { companion object { @@ -18,16 +12,20 @@ interface Properties : MutableMap { internal const val versionName = "v" internal const val displayLimitName = "displayLimit" internal const val displayTruncateName = "displayTruncate" + internal const val remoteName = "remote" + internal const val debugName = "debug" } - /** The value which limits the number of rows while displaying an RDD or Dataset. + /** + * The value which limits the number of rows while displaying an RDD or Dataset. * Default: 20 */ var displayLimit: Int set(value) { this[displayLimitName] = value.toString() } get() = this[displayLimitName]?.toIntOrNull() ?: 20 - /** The value which limits the number characters per cell while displaying an RDD or Dataset. + /** + * The value which limits the number characters per cell while displaying an RDD or Dataset. * `-1` for no limit. 
* Default: 30 */ @@ -35,6 +33,9 @@ interface Properties : MutableMap { set(value) { this[displayTruncateName] = value.toString() } get() = this[displayTruncateName]?.toIntOrNull() ?: 30 + var debug: Boolean + set(value) { this[debugName] = value.toString() } + get() = this[debugName]?.toBoolean() ?: false operator fun invoke(block: Properties.() -> Unit): Properties = apply(block) } diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt index 0c4eb096..11117ed6 100644 --- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt +++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,100 +21,232 @@ package org.jetbrains.kotlinx.spark.api.jupyter - +import org.apache.spark.sql.connect.client.Artifact +import org.apache.spark.sql.connect.client.`Artifact$` +import org.apache.spark.sql.connect.client.ClassFinder import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.jupyter.api.CodePreprocessor import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost import org.jetbrains.kotlinx.jupyter.api.Notebook +import org.jetbrains.kotlinx.jupyter.api.VariableDeclaration +import org.jetbrains.kotlinx.jupyter.api.declare import org.jetbrains.kotlinx.spark.api.SparkSession -import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.appNameName -import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.sparkMasterName - +import org.jetbrains.kotlinx.spark.api.asScalaIterator +import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.remoteName +import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.sparkPropertiesName +import org.jetbrains.kotlinx.spark.api.map +import scala.collection.Iterator +import java.net.URLClassLoader +import java.nio.file.Path +import kotlin.io.path.Path +import kotlin.io.path.extension +import kotlin.io.path.name +import kotlin.reflect.typeOf /** + * Spark connect! 
+ * * %use spark */ @Suppress("UNUSED_VARIABLE", "LocalVariableName") -@OptIn(ExperimentalStdlibApi::class) -class SparkIntegration(notebook: Notebook, options: MutableMap) : Integration(notebook, options) { +class SparkIntegration(notebook: Notebook, options: MutableMap) : + Integration(notebook, options), + CodePreprocessor, + ClassFinder { + override val usingProperties: Array + get() = super.usingProperties + remoteName + + override val dependencies: Array = + arrayOf( + "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion", + "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion", + "org.apache.spark:spark-sql-api_$scalaCompatVersion:$sparkVersion", + "org.apache.spark:spark-connect-client-jvm_$scalaCompatVersion:$sparkVersion", + "org.scala-lang:scala-library:$scalaVersion", + "org.scala-lang:scala-reflect:$scalaVersion", + "commons-io:commons-io:2.11.0", + ) + + override val imports: Array = + arrayOf( + "org.jetbrains.kotlinx.spark.api.plugin.annotations.*", + "org.jetbrains.kotlinx.spark.api.*", + "org.jetbrains.kotlinx.spark.api.tuples.*", + *(1..22).map { "scala.Tuple$it" }.toTypedArray(), + "org.apache.spark.sql.functions.*", + "org.apache.spark.*", + "org.apache.spark.sql.*", + "org.apache.spark.api.java.*", + "scala.collection.Seq", + "java.io.Serializable", + ) + + private val dumpedClasses = ClassCache() + + override fun findClasses(): Iterator = + dumpedClasses + .map { + try { + `Artifact$`.`MODULE$`.newClassArtifact(it.fileName, Artifact.LocalFile(it)) + } catch (e: Exception) { + throw RuntimeException("Error while creating class artifact for $it", e) + } + }.iterator() + .asScalaIterator() override fun KotlinKernelHost.onLoaded() { val _0 = execute("""%dumpClassesForSpark""") properties { - getOrPut(sparkMasterName) { "local[*]" } - getOrPut(appNameName) { "Kotlin Spark API - Jupyter" } - getOrPut("spark.sql.codegen.wholeStage") { "false" } - getOrPut("fs.hdfs.impl") { org.apache.hadoop.hdfs.DistributedFileSystem::class.java.name } - getOrPut("fs.file.impl") { org.apache.hadoop.fs.LocalFileSystem::class.java.name } + putIfAbsent(remoteName, "sc://localhost") + putIfAbsent("spark.sql.legacy.allowUntypedScalaUDF", "true") +// getOrPut("spark.sql.codegen.wholeStage") { "false" } + putIfAbsent("fs.hdfs.impl", org.apache.hadoop.hdfs.DistributedFileSystem::class.java.name) + putIfAbsent("fs.file.impl", org.apache.hadoop.fs.LocalFileSystem::class.java.name) } + declare( + VariableDeclaration( + name = ::dumpedClasses.name, + value = dumpedClasses, + type = typeOf(), + ), + ) + @Language("kts") - val _1 = listOf( - """ + val _1 = + listOf( + """ val spark = org.jetbrains.kotlinx.spark.api.SparkSession .builder() + .remote("${properties[remoteName]}") .apply { ${ - buildString { - val sparkProps = properties.filterKeys { it !in usingProperties } - println("received properties: $properties, providing Spark with: $sparkProps") + buildString { + val sparkProps = properties.filterKeys { it !in usingProperties } + println("received properties: $properties, providing Spark with: $sparkProps") - sparkProps.forEach { (key, value) -> - appendLine("config(\"${key}\", \"$value\")") + sparkProps.forEach { (key, value) -> + appendLine("config(\"${key}\", \"$value\")") + } } } - } } - .getOrCreate()""".trimIndent(), - """ - spark.sparkContext.setLogLevel(org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR)""".trimIndent(), - """ - val sc by lazy { - org.apache.spark.api.java.JavaSparkContext(spark.sparkContext) - }""".trimIndent(), - """ - println("Spark session (Spark: $sparkVersion, 
Scala: $scalaCompatVersion, v: $version) has been started and is running. No `withSpark { }` necessary, you can access `spark` and `sc` directly. To use Spark streaming, use `%use spark-streaming` instead.")""".trimIndent(), - """ - inline fun List.toDS(): Dataset = toDS(spark)""".trimIndent(), - """ - inline fun List.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames)""".trimIndent(), - """ - inline fun Array.toDS(): Dataset = toDS(spark)""".trimIndent(), - """ - inline fun Array.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames)""".trimIndent(), - """ - inline fun dsOf(vararg arg: T): Dataset = spark.dsOf(*arg)""".trimIndent(), - """ - inline fun dfOf(vararg arg: T): Dataset = spark.dfOf(*arg)""".trimIndent(), - """ - inline fun emptyDataset(): Dataset = spark.emptyDataset(kotlinEncoderFor())""".trimIndent(), - """ - inline fun dfOf(colNames: Array, vararg arg: T): Dataset = spark.dfOf(colNames, *arg)""".trimIndent(), - """ - inline fun RDD.toDS(): Dataset = toDS(spark)""".trimIndent(), - """ - inline fun JavaRDDLike.toDS(): Dataset = toDS(spark)""".trimIndent(), - """ - inline fun RDD.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames)""".trimIndent(), - """ - inline fun JavaRDDLike.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames)""".trimIndent(), - """ - fun List.toRDD(numSlices: Int = sc.defaultParallelism()): JavaRDD = sc.toRDD(this, numSlices)""".trimIndent(), - """ - fun rddOf(vararg elements: T, numSlices: Int = sc.defaultParallelism()): JavaRDD = sc.toRDD(elements.toList(), numSlices)""".trimIndent(), - """ - val udf: UDFRegistration get() = spark.udf()""".trimIndent(), - """ - inline fun > NAMED_UDF.register(): NAMED_UDF = spark.udf().register(namedUdf = this)""".trimIndent(), - """ - inline fun > UserDefinedFunction.register(name: String): NAMED_UDF = spark.udf().register(name = name, udf = this)""".trimIndent(), - ).map(::execute) + .getOrCreate() + """.trimIndent(), + """ + println("Spark Connect session (Spark: $sparkVersion, Scala: $scalaCompatVersion, v: $version, remote: ${properties[remoteName]}) has been started and is running. No `withSparkConnect { }` necessary, you can access `spark` directly.") + """.trimIndent(), + """ + inline fun List.toDS(): Dataset = toDS(spark) + """.trimIndent(), + """ + inline fun List.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames) + """.trimIndent(), + """ + inline fun Array.toDS(): Dataset = toDS(spark) + """.trimIndent(), + """ + inline fun Array.toDF(vararg colNames: String): Dataset = toDF(spark, *colNames) + """.trimIndent(), + """ + inline fun dsOf(vararg arg: T): Dataset = spark.dsOf(*arg) + """.trimIndent(), + """ + inline fun dfOf(vararg arg: T): Dataset = spark.dfOf(*arg) + """.trimIndent(), + """ + inline fun emptyDataset(): Dataset = spark.emptyDataset(kotlinEncoderFor()) + """.trimIndent(), + """ + inline fun dfOf(colNames: Array, vararg arg: T): Dataset = spark.dfOf(colNames, *arg) + """.trimIndent(), + """ + val udf: UDFRegistration get() = spark.udf() + """.trimIndent(), + """ + inline fun > NAMED_UDF.register(): NAMED_UDF = spark.udf().register(namedUdf = this) + """.trimIndent(), + """ + inline fun > UserDefinedFunction.register(name: String): NAMED_UDF = spark.udf().register(name = name, udf = this) + """.trimIndent(), + """ + /** This function is run automatically at the beginning of each cell to make its .class contents available to Spark. 
*/ + fun dumpClassesToSpark() { + val outputFiles = java.io.File(System.getProperty("spark.repl.class.outputDir")) + .listFiles { it -> it.extension == "class" } + ?.map { it.toPath() } + if (outputFiles != null) { + dumpedClasses += outputFiles + if ($sparkPropertiesName.debug) println("Dumped classes: ${'$'}outputFiles") + } + } + """.trimIndent(), + ).map(::execute) spark = execute("spark").value as SparkSession - } - override fun KotlinKernelHost.onShutdown() { - execute("""spark.stop()""") + // Add all jars in the classpath to Spark as artifacts + buildList { + var current: ClassLoader? = execute("this::class.java.classLoader").value as ClassLoader + while (current != null) { + add(current) + current = current.parent + } + }.filterIsInstance() + .flatMap { it.getURLs().map { it.path } } + .filter { it.endsWith(".jar") } + .forEach { + try { + spark!!.addArtifact(it) + } catch (e: Exception) { + println("Error while adding artifact $it: $e") + } + } + + spark!!.registerClassFinder(this@SparkIntegration) + notebook.codePreprocessorsProcessor.register(this@SparkIntegration) } + + /** + * Makes it so that `dumpClassesToSpark()` is run automatically + * at the beginning of each cell. + */ + override fun process( + code: String, + host: KotlinKernelHost, + ): CodePreprocessor.Result = + CodePreprocessor.Result( + code + .lines() + .toMutableList() + .also { + it.add( + index = it.indexOfLast { it.startsWith("import ") } + 1, + element = "dumpClassesToSpark()", + ) + }.joinToString("\n"), + ) } +/** + * Spark connect .class cache designed to keep the newest class files (unique by name) with their + * given path. + */ +class ClassCache : Iterable { + private val cache: MutableMap = mutableMapOf() + + fun add(path: Path) { + require(path.extension == "class") { "Must be a .class file" } + val name = path.fileName.toString() + cache[name] = path + } + + fun addAll(paths: Iterable) = paths.forEach { add(it) } + + operator fun plusAssign(path: Path) = add(path) + + operator fun plusAssign(path: Iterable) = addAll(path) + + override fun iterator(): kotlin.collections.Iterator = cache.values.iterator() +} diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt deleted file mode 100644 index 122e3122..00000000 --- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt +++ /dev/null @@ -1,144 +0,0 @@ -/*- - * =LICENSE= - * Kotlin Spark API: API for Spark 3.2+ (Scala 2.12) - * ---------- - * Copyright (C) 2019 - 2022 JetBrains - * ---------- - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * =LICENSEEND= - */ -package org.jetbrains.kotlinx.spark.api.jupyter - - -import org.apache.spark.streaming.StreamingContextState -import org.apache.spark.streaming.api.java.JavaStreamingContext -import org.intellij.lang.annotations.Language -import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost -import org.jetbrains.kotlinx.jupyter.api.Notebook -import org.jetbrains.kotlinx.jupyter.api.VariableDeclaration -import org.jetbrains.kotlinx.jupyter.api.declare -import kotlin.reflect.typeOf - -/** - * %use spark-streaming - */ -@Suppress("UNUSED_VARIABLE", "LocalVariableName") -class SparkStreamingIntegration(notebook: Notebook, options: MutableMap) : Integration(notebook, options) { - - override val imports: Array = super.imports + arrayOf( - "org.apache.spark.deploy.SparkHadoopUtil", - "org.apache.hadoop.conf.Configuration", - ) - - private val sscCollection = mutableSetOf() - - override fun KotlinKernelHost.onLoaded() { - - declare( - VariableDeclaration( - name = ::sscCollection.name, - value = sscCollection, - isMutable = false, - type = typeOf>(), - ) - ) - - val _0 = execute("""%dumpClassesForSpark""") - - @Language("kts") - val _1 = listOf( - """ - @JvmOverloads - fun withSparkStreaming( - batchDuration: Duration = Durations.seconds(1L), - checkpointPath: String? = null, - hadoopConf: Configuration = SparkHadoopUtil.get().conf(), - createOnError: Boolean = false, - props: Map = emptyMap(), - master: String = SparkConf().get("spark.master", "local[*]"), - appName: String = "Kotlin Spark Sample", - timeout: Long = -1L, - startStreamingContext: Boolean = true, - func: KSparkStreamingSession.() -> Unit, - ) { - - // will only be set when a new context is created - var kSparkStreamingSession: KSparkStreamingSession? = null - - val creatingFunc = { - val sc = SparkConf() - .setAppName(appName) - .setMaster(master) - .setAll( - props - .map { (key, value) -> key X value.toString() } - .asScalaIterable() - ) - - val ssc = JavaStreamingContext(sc, batchDuration) - ssc.checkpoint(checkpointPath) - - kSparkStreamingSession = KSparkStreamingSession(ssc) - func(kSparkStreamingSession!!) - - ssc - } - - val ssc = when { - checkpointPath != null -> - JavaStreamingContext.getOrCreate(checkpointPath, creatingFunc, hadoopConf, createOnError) - - else -> creatingFunc() - } - sscCollection += ssc - - if (startStreamingContext) { - ssc.start() - kSparkStreamingSession?.invokeRunAfterStart() - } - ssc.awaitTerminationOrTimeout(timeout) - ssc.stop() - } - """.trimIndent(), - """ - println("To start a Spark (Spark: $sparkVersion, Scala: $scalaCompatVersion, v: $version) Streaming session, simply use `withSparkStreaming { }` inside a cell. To use Spark normally, use `withSpark { }` in a cell, or use `%use spark` to start a Spark session for the whole notebook.")""".trimIndent(), - ).map(::execute) - } - - private fun cleanUp(e: Throwable): String { - while (sscCollection.isNotEmpty()) - sscCollection.first().let { - while (it.state != StreamingContextState.STOPPED) { - try { - it.stop(true, true) - } catch (_: Exception) { - } - } - sscCollection.remove(it) - } - - return "Spark streams cleaned up. 
Cause: $e" - } - - override fun Builder.onLoadedAlsoDo() { - renderThrowable { - cleanUp(it) - } - } - - override fun KotlinKernelHost.onInterrupt() { - println( - cleanUp(InterruptedException("Kernel was interrupted.")) - ) - } -} diff --git a/jupyter/src/main/resources/kotest.properties b/jupyter/src/main/resources/kotest.properties new file mode 100644 index 00000000..77319870 --- /dev/null +++ b/jupyter/src/main/resources/kotest.properties @@ -0,0 +1,2 @@ +kotest.framework.classpath.scanning.config.disable=true +kotest.framework.classpath.scanning.autoscan.disable=true \ No newline at end of file diff --git a/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt b/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt deleted file mode 100644 index 9368ebc4..00000000 --- a/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt +++ /dev/null @@ -1,478 +0,0 @@ -/*- - * =LICENSE= - * Kotlin Spark API: API for Spark 3.2+ (Scala 2.12) - * ---------- - * Copyright (C) 2019 - 2022 JetBrains - * ---------- - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * =LICENSEEND= - */ -package org.jetbrains.kotlinx.spark.api.jupyter - -import io.kotest.assertions.throwables.shouldThrowAny -import io.kotest.core.spec.style.ShouldSpec -import io.kotest.matchers.nulls.shouldNotBeNull -import io.kotest.matchers.shouldBe -import io.kotest.matchers.shouldNotBe -import io.kotest.matchers.string.shouldContain -import io.kotest.matchers.string.shouldNotContain -import io.kotest.matchers.types.shouldBeInstanceOf -import jupyter.kotlin.DependsOn -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.streaming.api.java.JavaStreamingContext -import org.intellij.lang.annotations.Language -import org.jetbrains.kotlinx.jupyter.api.Code -import org.jetbrains.kotlinx.jupyter.api.MimeTypedResult -import org.jetbrains.kotlinx.jupyter.api.MimeTypes -import org.jetbrains.kotlinx.jupyter.libraries.createLibraryHttpUtil -import org.jetbrains.kotlinx.jupyter.repl.EvalRequestData -import org.jetbrains.kotlinx.jupyter.repl.ReplForJupyter -import org.jetbrains.kotlinx.jupyter.repl.creating.createRepl -import org.jetbrains.kotlinx.jupyter.repl.result.EvalResultEx -import org.jetbrains.kotlinx.jupyter.testkit.ReplProvider -import org.jetbrains.kotlinx.jupyter.util.PatternNameAcceptanceRule -import org.jetbrains.kotlinx.spark.api.SparkSession -import java.io.Serializable -import kotlin.script.experimental.jvm.util.classpathFromClassloader - - -class JupyterTests : ShouldSpec({ - - val replProvider = ReplProvider { classpath -> - createRepl( - httpUtil = createLibraryHttpUtil(), - scriptClasspath = classpath, - isEmbedded = true, - ).apply { - eval { - librariesScanner.addLibrariesFromClassLoader( - classLoader = currentClassLoader, - host = this, - notebook = notebook, - integrationTypeNameRules = listOf( - PatternNameAcceptanceRule( - acceptsFlag = false, - pattern = "org.jetbrains.kotlinx.spark.api.jupyter.**", - ), - 
PatternNameAcceptanceRule( - acceptsFlag = true, - pattern = "org.jetbrains.kotlinx.spark.api.jupyter.SparkIntegration", - ), - ), - ) - } - } - } - - val currentClassLoader = DependsOn::class.java.classLoader - val scriptClasspath = classpathFromClassloader(currentClassLoader).orEmpty() - - fun createRepl(): ReplForJupyter = replProvider(scriptClasspath) - suspend fun withRepl(action: suspend ReplForJupyter.() -> Unit): Unit = createRepl().action() - - context("Jupyter") { - withRepl { - exec("%trackExecution") - - should("Allow functions on local data classes") { - @Language("kts") - val klass = exec("""@Sparkify data class Test(val a: Int, val b: String)""") - - @Language("kts") - val ds = exec("""val ds = dsOf(Test(1, "hi"), Test(2, "something"))""") - - @Language("kts") - val filtered = exec("""val filtered = ds.filter { it.a > 1 }""") - - @Language("kts") - val filteredShow = exec("""filtered.show()""") - } - - should("Have spark instance") { - @Language("kts") - val spark = exec("""spark""") - spark as? SparkSession shouldNotBe null - } - - should("Have JavaSparkContext instance") { - @Language("kts") - val sc = exec("""sc""") - sc as? JavaSparkContext shouldNotBe null - } - - xshould("render Datasets") { - @Language("kts") - val html = execForDisplayText( - """ - val ds = listOf(1, 2, 3).toDS() - ds - """.trimIndent() - ) - println(html) - - html shouldContain "value" - html shouldContain "1" - html shouldContain "2" - html shouldContain "3" - } - - xshould("render JavaRDDs") { - @Language("kts") - val html = execForDisplayText( - """ - val rdd: JavaRDD> = listOf( - listOf(1, 2, 3), - listOf(4, 5, 6), - ).toRDD() - rdd - """.trimIndent() - ) - println(html) - - html shouldContain "1, 2, 3" - html shouldContain "4, 5, 6" - } - - xshould("render JavaRDDs with Arrays") { - @Language("kts") - val html = execForDisplayText( - """ - val rdd: JavaRDD = rddOf( - intArrayOf(1, 2, 3), - intArrayOf(4, 5, 6), - ) - rdd - """.trimIndent() - ) - println(html) - - html shouldContain "1, 2, 3" - html shouldContain "4, 5, 6" - } - - xshould("render JavaRDDs with custom class") { - - @Language("kts") - val klass = exec( - """ - @Sparkify data class Test( - val longFirstName: String, - val second: LongArray, - val somethingSpecial: Map, - ): Serializable - """.trimIndent() - ) - - @Language("kts") - val html = execForDisplayText( - """ - val rdd = - listOf( - Test("aaaaaaaaa", longArrayOf(1L, 100000L, 24L), mapOf(1 to "one", 2 to "two")), - Test("aaaaaaaaa", longArrayOf(1L, 100000L, 24L), mapOf(1 to "one", 2 to "two")), - ).toRDD() - - rdd - """.trimIndent() - ) - html shouldContain """ - +-------------+---------------+--------------------+ - |longFirstName| second| somethingSpecial| - +-------------+---------------+--------------------+ - | aaaaaaaaa|[1, 100000, 24]|{1 -> one, 2 -> two}| - | aaaaaaaaa|[1, 100000, 24]|{1 -> one, 2 -> two}| - +-------------+---------------+--------------------+""".trimIndent() - } - - xshould("render JavaPairRDDs") { - @Language("kts") - val html = execForDisplayText( - """ - val rdd: JavaPairRDD = rddOf( - t(1, 2), - t(3, 4), - ).toJavaPairRDD() - rdd - """.trimIndent() - ) - println(html) - - html shouldContain """ - +---+---+ - | _1| _2| - +---+---+ - | 1| 2| - | 3| 4| - +---+---+""".trimIndent() - } - - xshould("render JavaDoubleRDD") { - @Language("kts") - val html = execForDisplayText( - """ - val rdd: JavaDoubleRDD = rddOf(1.0, 2.0, 3.0, 4.0,).toJavaDoubleRDD() - rdd - """.trimIndent() - ) - println(html) - - html shouldContain "1.0" - html shouldContain "2.0" 
- html shouldContain "3.0" - html shouldContain "4.0" - } - - xshould("render Scala RDD") { - @Language("kts") - val html = execForDisplayText( - """ - val rdd: RDD> = rddOf( - listOf(1, 2, 3), - listOf(4, 5, 6), - ).rdd() - rdd - """.trimIndent() - ) - println(html) - - html shouldContain "1, 2, 3" - html shouldContain "4, 5, 6" - } - - xshould("truncate dataset cells using properties") { - - @Language("kts") - val oldTruncation = exec("""sparkProperties.displayTruncate""") as Int - - @Language("kts") - val html = execForDisplayText( - """ - @Sparkify data class Test(val a: String) - sparkProperties.displayTruncate = 3 - dsOf(Test("aaaaaaaaaa")) - """.trimIndent() - ) - - @Language("kts") - val restoreTruncation = exec("""sparkProperties.displayTruncate = $oldTruncation""") - - html shouldContain "aaa" - html shouldNotContain "aaaaaaaaaa" - } - - xshould("limit dataset rows using properties") { - - @Language("kts") - val oldLimit = exec("""sparkProperties.displayLimit""") as Int - - @Language("kts") - val html = execForDisplayText( - """ - @Sparkify data class Test(val a: String) - sparkProperties.displayLimit = 3 - dsOf(Test("a"), Test("b"), Test("c"), Test("d"), Test("e")) - """.trimIndent() - ) - - @Language("kts") - val restoreLimit = exec("""sparkProperties.displayLimit = $oldLimit""") - - html shouldContain "a|" - html shouldContain "b|" - html shouldContain "c|" - html shouldNotContain "d|" - html shouldNotContain "e|" - } - - xshould("truncate rdd cells using properties") { - - @Language("kts") - val oldTruncation = exec("""sparkProperties.displayTruncate""") as Int - - @Language("kts") - val html = execForDisplayText( - """ - sparkProperties.displayTruncate = 3 - rddOf("aaaaaaaaaa") - """.trimIndent() - ) - - @Language("kts") - val restoreTruncation = exec("""sparkProperties.displayTruncate = $oldTruncation""") - - html shouldContain "aaa" - html shouldNotContain "aaaaaaaaaa" - } - - xshould("limit rdd rows using properties") { - - @Language("kts") - val oldLimit = exec("""sparkProperties.displayLimit""") as Int - - @Language("kts") - val html = execForDisplayText( - """ - sparkProperties.displayLimit = 3 - rddOf("a", "b", "c", "d", "e") - """.trimIndent() - ) - - @Language("kts") - val restoreLimit = exec("""sparkProperties.displayLimit = $oldLimit""") - - html shouldContain " a|" - html shouldContain " b|" - html shouldContain " c|" - html shouldNotContain " d|" - html shouldNotContain " e|" - } - - @Language("kts") - val _stop = exec("""spark.stop()""") - } - } -}) - -class JupyterStreamingTests : ShouldSpec({ - val replProvider = ReplProvider { classpath -> - createRepl( - httpUtil = createLibraryHttpUtil(), - scriptClasspath = classpath, - isEmbedded = true, - ).apply { - eval { - librariesScanner.addLibrariesFromClassLoader( - classLoader = currentClassLoader, - host = this, - notebook = notebook, - integrationTypeNameRules = listOf( - PatternNameAcceptanceRule( - acceptsFlag = false, - pattern = "org.jetbrains.kotlinx.spark.api.jupyter.**", - ), - PatternNameAcceptanceRule( - acceptsFlag = true, - pattern = "org.jetbrains.kotlinx.spark.api.jupyter.SparkStreamingIntegration", - ), - ), - ) - } - } - } - - val currentClassLoader = DependsOn::class.java.classLoader - val scriptClasspath = classpathFromClassloader(currentClassLoader).orEmpty() - - fun createRepl(): ReplForJupyter = replProvider(scriptClasspath) - suspend fun withRepl(action: suspend ReplForJupyter.() -> Unit): Unit = createRepl().action() - - xcontext("Jupyter") { - withRepl { - - // For when onInterrupt is 
implemented in the Jupyter kernel - should("Have sscCollection instance") { - - @Language("kts") - val sscCollection = exec("""sscCollection""") - sscCollection as? MutableSet shouldNotBe null - } - - should("Not have spark instance") { - shouldThrowAny { - @Language("kts") - val spark = exec("""spark""") - Unit - } - } - - should("Not have sc instance") { - shouldThrowAny { - @Language("kts") - val sc = exec("""sc""") - Unit - } - } - - should("stream") { - - @Language("kts") - val value = exec( - """ - import java.util.LinkedList - import org.apache.spark.api.java.function.ForeachFunction - import org.apache.spark.util.LongAccumulator - - - val input = arrayListOf("aaa", "bbb", "aaa", "ccc") - - @Volatile - var counter: LongAccumulator? = null - - withSparkStreaming(Duration(10), timeout = 1_000) { - - val queue = withSpark(ssc) { - LinkedList(listOf(sc.parallelize(input))) - } - - val inputStream = ssc.queueStream(queue) - - inputStream.foreachRDD { rdd, _ -> - withSpark(rdd) { - if (counter == null) - counter = sc.sc().longAccumulator() - - rdd.toDS().showDS().forEach { - if (it !in input) error(it + " should be in input") - counter!!.add(1L) - } - } - } - } - counter!!.sum() - """.trimIndent() - ) as Long - - value shouldBe 4L - } - - } - } -}) - - -private fun ReplForJupyter.execEx(code: Code): EvalResultEx = evalEx(EvalRequestData(code)) - -private fun ReplForJupyter.exec(code: Code): Any? = (execEx(code) as? EvalResultEx.Success)?.renderedValue - -@JvmName("execTyped") -private inline fun ReplForJupyter.exec(code: Code): T { - val res = exec(code) - res.shouldBeInstanceOf() - return res -} - -private fun ReplForJupyter.execHtml(code: Code): String { - val res = exec(code) - val html = res["text/html"] - html.shouldNotBeNull() - return html -} - -private fun ReplForJupyter.execForDisplayText(code: Code): String { - val res = exec(code) - val text = res[MimeTypes.PLAIN_TEXT] - text.shouldNotBeNull() - return text -} - -class Counter(@Volatile var value: Int) : Serializable diff --git a/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkConnectTests.kt b/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkConnectTests.kt new file mode 100644 index 00000000..56f467fd --- /dev/null +++ b/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkConnectTests.kt @@ -0,0 +1,536 @@ +/*- + * =LICENSE= + * Kotlin Spark API: API for Spark 3.2+ (Scala 2.12) + * ---------- + * Copyright (C) 2019 - 2022 JetBrains + * ---------- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * =LICENSEEND= + */ +package org.jetbrains.kotlinx.spark.api.jupyter + +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.string.shouldContain +import io.kotest.matchers.string.shouldNotContain +import io.kotest.matchers.types.shouldBeInstanceOf +import io.kotest.matchers.types.shouldNotBeTypeOf +import jupyter.kotlin.DependsOn +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.jupyter.api.Code +import org.jetbrains.kotlinx.jupyter.api.MimeTypedResult +import org.jetbrains.kotlinx.jupyter.api.MimeTypes +import org.jetbrains.kotlinx.jupyter.libraries.createLibraryHttpUtil +import org.jetbrains.kotlinx.jupyter.repl.EvalRequestData +import org.jetbrains.kotlinx.jupyter.repl.ReplForJupyter +import org.jetbrains.kotlinx.jupyter.repl.creating.createRepl +import org.jetbrains.kotlinx.jupyter.repl.result.EvalResultEx +import org.jetbrains.kotlinx.jupyter.testkit.ReplProvider +import org.jetbrains.kotlinx.jupyter.util.PatternNameAcceptanceRule +import org.jetbrains.kotlinx.spark.api.SparkSession +import java.io.Serializable +import kotlin.script.experimental.jvm.util.classpathFromClassloader + +class SparkConnectTests : + ShouldSpec({ + + val replProvider = ReplProvider { classpath -> + createRepl( + httpUtil = createLibraryHttpUtil(), + scriptClasspath = classpath, + isEmbedded = true, + ).apply { + eval { + librariesScanner.addLibrariesFromClassLoader( + classLoader = currentClassLoader, + host = this, + notebook = notebook, + libraryOptions = mapOf( + "spark" to "3.5.1", + "scala" to "2.13", + // Make sure spark connect is running at sc://localhost! + "remote" to "sc://localhost", + ), + integrationTypeNameRules = listOf( + PatternNameAcceptanceRule( + acceptsFlag = false, + pattern = "org.jetbrains.kotlinx.spark.api.jupyter.**", + ), + PatternNameAcceptanceRule( + acceptsFlag = true, + pattern = "org.jetbrains.kotlinx.spark.api.jupyter.SparkIntegration", + ), + ), + ) + } + } + } + + val currentClassLoader = DependsOn::class.java.classLoader + val scriptClasspath = classpathFromClassloader(currentClassLoader).orEmpty() + + fun createRepl(): ReplForJupyter = replProvider(scriptClasspath) + + suspend fun withRepl(action: suspend ReplForJupyter.() -> Unit): Unit = createRepl().action() + + context("Spark Connect") { + + withRepl { + exec("%trackExecution") + exec("sparkProperties.debug = true") + + should("Allow functions on local data classes") { + @Language("kts") + val klass = execEx("""@Sparkify data class Test(val a: Int, val b: String)""") + klass.shouldNotBeTypeOf() + + @Language("kts") + val ds = execEx("""val ds = dsOf(Test(1, "hi"), Test(2, "something")).filter { it.a >= 1 }.showDS()""") + ds.shouldNotBeTypeOf() + + @Language("kts") + val ds2 = execEx("""val ds = dsOf(Test(1, "hi"), Test(2, "something")).filter { it.a >= 1 }.showDS()""") + ds2.shouldNotBeTypeOf() + + @Language("kts") + val filtered = execEx("""val filtered = ds.filter { it.a > 1 }""") + filtered.shouldNotBeTypeOf() + + @Language("kts") + val filtered2 = execEx("""val filtered = ds.filter { it.a > 1 }""") + filtered2.shouldNotBeTypeOf() + + @Language("kts") + val filteredShow = execEx("""filtered.showDS()""") + filteredShow.shouldNotBeTypeOf() + + @Language("kts") + val filteredShow2 = execEx("""filtered.showDS()""") + filteredShow.shouldNotBeTypeOf() + + @Language("kts") + val toList = execEx("""ds.toList()""") + toList.shouldNotBeTypeOf() + + println(toList.toString()) + } + + 
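            // A minimal two-cell sketch of the pattern exercised above, assuming a Spark Connect
            // server is reachable at sc://localhost and using a hypothetical class name (Person):
            // the @Sparkify class is declared in its own cell so that the automatically injected
            // dumpClassesToSpark() call can ship its compiled .class file to the server before the
            // Dataset operations in a later cell use it.
            //
            //   // cell 1
            //   @Sparkify data class Person(val name: String, val age: Int)
            //
            //   // cell 2
            //   val people = dsOf(Person("Jane", 42), Person("John", 41)).filter { it.age > 41 }
            //   people.showDS()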
should("Not work in the same cell") { + @Language("kts") + val klass = + execEx( + """ + @Sparkify data class Test(val a: Int, val b: String) + val ds = dsOf(Test(1, "hi"), Test(2, "something")) + ds + """.trimIndent(), + ) + klass.shouldBeInstanceOf() + } + + should("Have spark instance") { + @Language("kts") + val spark = exec("""spark""") + spark as? SparkSession shouldNotBe null + } + +// xshould("Have JavaSparkContext instance") { +// @Language("kts") +// val sc = exec("""sc""") +// sc as? JavaSparkContext shouldNotBe null +// } + + xshould("render Datasets") { + @Language("kts") + val html = + execForDisplayText( + """ + val ds = listOf(1, 2, 3).toDS() + ds + """.trimIndent(), + ) + println(html) + + html shouldContain "value" + html shouldContain "1" + html shouldContain "2" + html shouldContain "3" + } + + xshould("render JavaRDDs") { + @Language("kts") + val html = + execForDisplayText( + """ + val rdd: JavaRDD> = listOf( + listOf(1, 2, 3), + listOf(4, 5, 6), + ).toRDD() + rdd + """.trimIndent(), + ) + println(html) + + html shouldContain "1, 2, 3" + html shouldContain "4, 5, 6" + } + + xshould("render JavaRDDs with Arrays") { + @Language("kts") + val html = + execForDisplayText( + """ + val rdd: JavaRDD = rddOf( + intArrayOf(1, 2, 3), + intArrayOf(4, 5, 6), + ) + rdd + """.trimIndent(), + ) + println(html) + + html shouldContain "1, 2, 3" + html shouldContain "4, 5, 6" + } + + xshould("render JavaRDDs with custom class") { + + @Language("kts") + val klass = + exec( + """ + @Sparkify data class Test( + val longFirstName: String, + val second: LongArray, + val somethingSpecial: Map, + ): Serializable + """.trimIndent(), + ) + + @Language("kts") + val html = + execForDisplayText( + """ + val rdd = + listOf( + Test("aaaaaaaaa", longArrayOf(1L, 100000L, 24L), mapOf(1 to "one", 2 to "two")), + Test("aaaaaaaaa", longArrayOf(1L, 100000L, 24L), mapOf(1 to "one", 2 to "two")), + ).toRDD() + + rdd + """.trimIndent(), + ) + html shouldContain + """ + +-------------+---------------+--------------------+ + |longFirstName| second| somethingSpecial| + +-------------+---------------+--------------------+ + | aaaaaaaaa|[1, 100000, 24]|{1 -> one, 2 -> two}| + | aaaaaaaaa|[1, 100000, 24]|{1 -> one, 2 -> two}| + +-------------+---------------+--------------------+ + """.trimIndent() + } + + xshould("render JavaPairRDDs") { + @Language("kts") + val html = + execForDisplayText( + """ + val rdd: JavaPairRDD = rddOf( + t(1, 2), + t(3, 4), + ).toJavaPairRDD() + rdd + """.trimIndent(), + ) + println(html) + + html shouldContain + """ + +---+---+ + | _1| _2| + +---+---+ + | 1| 2| + | 3| 4| + +---+---+ + """.trimIndent() + } + + xshould("render JavaDoubleRDD") { + @Language("kts") + val html = + execForDisplayText( + """ + val rdd: JavaDoubleRDD = rddOf(1.0, 2.0, 3.0, 4.0,).toJavaDoubleRDD() + rdd + """.trimIndent(), + ) + println(html) + + html shouldContain "1.0" + html shouldContain "2.0" + html shouldContain "3.0" + html shouldContain "4.0" + } + + xshould("render Scala RDD") { + @Language("kts") + val html = + execForDisplayText( + """ + val rdd: RDD> = rddOf( + listOf(1, 2, 3), + listOf(4, 5, 6), + ).rdd() + rdd + """.trimIndent(), + ) + println(html) + + html shouldContain "1, 2, 3" + html shouldContain "4, 5, 6" + } + + xshould("truncate dataset cells using properties") { + + @Language("kts") + val oldTruncation = exec("""sparkProperties.displayTruncate""") as Int + + @Language("kts") + val html = + execForDisplayText( + """ + @Sparkify data class Test(val a: String) + sparkProperties.displayTruncate 
= 3 + dsOf(Test("aaaaaaaaaa")) + """.trimIndent(), + ) + + @Language("kts") + val restoreTruncation = exec("""sparkProperties.displayTruncate = $oldTruncation""") + + html shouldContain "aaa" + html shouldNotContain "aaaaaaaaaa" + } + + xshould("limit dataset rows using properties") { + + @Language("kts") + val oldLimit = exec("""sparkProperties.displayLimit""") as Int + + @Language("kts") + val html = + execForDisplayText( + """ + @Sparkify data class Test(val a: String) + sparkProperties.displayLimit = 3 + dsOf(Test("a"), Test("b"), Test("c"), Test("d"), Test("e")) + """.trimIndent(), + ) + + @Language("kts") + val restoreLimit = exec("""sparkProperties.displayLimit = $oldLimit""") + + html shouldContain "a|" + html shouldContain "b|" + html shouldContain "c|" + html shouldNotContain "d|" + html shouldNotContain "e|" + } + + xshould("truncate rdd cells using properties") { + + @Language("kts") + val oldTruncation = exec("""sparkProperties.displayTruncate""") as Int + + @Language("kts") + val html = + execForDisplayText( + """ + sparkProperties.displayTruncate = 3 + rddOf("aaaaaaaaaa") + """.trimIndent(), + ) + + @Language("kts") + val restoreTruncation = exec("""sparkProperties.displayTruncate = $oldTruncation""") + + html shouldContain "aaa" + html shouldNotContain "aaaaaaaaaa" + } + + xshould("limit rdd rows using properties") { + + @Language("kts") + val oldLimit = exec("""sparkProperties.displayLimit""") as Int + + @Language("kts") + val html = + execForDisplayText( + """ + sparkProperties.displayLimit = 3 + rddOf("a", "b", "c", "d", "e") + """.trimIndent(), + ) + + @Language("kts") + val restoreLimit = exec("""sparkProperties.displayLimit = $oldLimit""") + + html shouldContain " a|" + html shouldContain " b|" + html shouldContain " c|" + html shouldNotContain " d|" + html shouldNotContain " e|" + } + + @Language("kts") + val _stop = exec("""spark.stop()""") + } + } + }) + +// class JupyterStreamingTests : ShouldSpec({ +// val replProvider = ReplProvider { classpath -> +// createRepl( +// httpUtil = createLibraryHttpUtil(), +// scriptClasspath = classpath, +// isEmbedded = true, +// ).apply { +// eval { +// librariesScanner.addLibrariesFromClassLoader( +// classLoader = currentClassLoader, +// host = this, +// notebook = notebook, +// integrationTypeNameRules = listOf( +// PatternNameAcceptanceRule( +// acceptsFlag = false, +// pattern = "org.jetbrains.kotlinx.spark.api.jupyter.**", +// ), +// PatternNameAcceptanceRule( +// acceptsFlag = true, +// pattern = "org.jetbrains.kotlinx.spark.api.jupyter.SparkStreamingIntegration", +// ), +// ), +// ) +// } +// } +// } +// +// val currentClassLoader = DependsOn::class.java.classLoader +// val scriptClasspath = classpathFromClassloader(currentClassLoader).orEmpty() +// +// fun createRepl(): ReplForJupyter = replProvider(scriptClasspath) +// suspend fun withRepl(action: suspend ReplForJupyter.() -> Unit): Unit = createRepl().action() +// +// xcontext("Jupyter") { +// withRepl { +// +// // For when onInterrupt is implemented in the Jupyter kernel +// should("Have sscCollection instance") { +// +// @Language("kts") +// val sscCollection = exec("""sscCollection""") +// sscCollection as? 
MutableSet shouldNotBe null +// } +// +// should("Not have spark instance") { +// shouldThrowAny { +// @Language("kts") +// val spark = exec("""spark""") +// Unit +// } +// } +// +// should("Not have sc instance") { +// shouldThrowAny { +// @Language("kts") +// val sc = exec("""sc""") +// Unit +// } +// } +// +// should("stream") { +// +// @Language("kts") +// val value = exec( +// """ +// import java.util.LinkedList +// import org.apache.spark.api.java.function.ForeachFunction +// import org.apache.spark.util.LongAccumulator +// +// +// val input = arrayListOf("aaa", "bbb", "aaa", "ccc") +// +// @Volatile +// var counter: LongAccumulator? = null +// +// withSparkStreaming(Duration(10), timeout = 1_000) { +// +// val queue = withSpark(ssc) { +// LinkedList(listOf(sc.parallelize(input))) +// } +// +// val inputStream = ssc.queueStream(queue) +// +// inputStream.foreachRDD { rdd, _ -> +// withSpark(rdd) { +// if (counter == null) +// counter = sc.sc().longAccumulator() +// +// rdd.toDS().showDS().forEach { +// if (it !in input) error(it + " should be in input") +// counter!!.add(1L) +// } +// } +// } +// } +// counter!!.sum() +// """.trimIndent() +// ) as Long +// +// value shouldBe 4L +// } +// +// } +// } +// }) + +internal fun ReplForJupyter.execEx(code: Code): EvalResultEx = evalEx(EvalRequestData(code)) + +internal fun ReplForJupyter.exec(code: Code): Any? = (execEx(code) as? EvalResultEx.Success)?.renderedValue + +@JvmName("execTyped") +internal inline fun ReplForJupyter.exec(code: Code): T { + val res = exec(code) + res.shouldBeInstanceOf() + return res +} + +internal fun ReplForJupyter.execHtml(code: Code): String { + val res = exec(code) + val html = res["text/html"] + html.shouldNotBeNull() + return html +} + +internal fun ReplForJupyter.execForDisplayText(code: Code): String { + val res = exec(code) + val text = res[MimeTypes.PLAIN_TEXT] + text.shouldNotBeNull() + return text +} + +class Counter( + @Volatile var value: Int, +) : Serializable diff --git a/kotlin-spark-api/build.gradle.kts b/kotlin-spark-api/build.gradle.kts index c4a748e7..276e53ac 100644 --- a/kotlin-spark-api/build.gradle.kts +++ b/kotlin-spark-api/build.gradle.kts @@ -54,6 +54,8 @@ dependencies { hadoopClient, kotlinStdLib, reflect, + ) + api( kotlinDateTime, ) diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt index e931890e..04cbe3e7 100644 --- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt +++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt @@ -354,6 +354,13 @@ object KotlinTypeInference : Serializable { }.mapKeys { it.key.name } ) { it.simpleName } + //#if sparkConnect == true + if (kClass.hasAnnotation() || UDTRegistration.exists(kClass.jvmName)) { + println("$kClass has a UDT, but UDTs are not supported with Spark-connect. 
" + + "Try to encode just primitives/java types or make your own @Sparkify data class.") + } + //#endif + return when { // primitives java / kotlin currentType == typeOf() -> AgnosticEncoders.`PrimitiveBooleanEncoder$`.`MODULE$` @@ -416,7 +423,11 @@ object KotlinTypeInference : Serializable { kClass.isSubclassOf(scala.Enumeration.Value::class) -> AgnosticEncoders.ScalaEnumEncoder(jClass.superclass, ClassTag.apply(jClass)) + // TODO test kotlin types + currentType.isSubtypeOf() -> AgnosticEncoders.STRICT_INSTANT_ENCODER() + // udts + //#if sparkConnect == false kClass.hasAnnotation() -> { val annotation = jClass.getAnnotation(SQLUserDefinedType::class.java)!! val udtClass = annotation.udt @@ -432,8 +443,7 @@ object KotlinTypeInference : Serializable { AgnosticEncoders.UDTEncoder(udt, udt.javaClass) } - - currentType.isSubtypeOf() -> TODO("kotlin.time.Duration is unsupported. Use java.time.Duration for now.") + //#endif currentType.isSubtypeOf?>() -> { val elementEncoder = encoderFor( diff --git a/kotlin-spark-api/src/main/resources/kotest.properties b/kotlin-spark-api/src/main/resources/kotest.properties new file mode 100644 index 00000000..77319870 --- /dev/null +++ b/kotlin-spark-api/src/main/resources/kotest.properties @@ -0,0 +1,2 @@ +kotest.framework.classpath.scanning.config.disable=true +kotest.framework.classpath.scanning.autoscan.disable=true \ No newline at end of file diff --git a/spark-connect-examples/build.gradle.kts b/spark-connect-examples/build.gradle.kts index f6885edc..99f8da1d 100644 --- a/spark-connect-examples/build.gradle.kts +++ b/spark-connect-examples/build.gradle.kts @@ -38,21 +38,20 @@ dependencies { Dependencies { -// implementation(hadoopClient) - // IMPORTANT! - compileOnly(sparkSqlApi) - implementation(sparkConnectClient) - - implementation(kotlinDateTime) - - compileOnly(scalaLibrary) + compileOnly( + sparkSqlApi, + scalaLibrary, + ) + implementation( + sparkConnectClient, + ) } } kotlin { jvmToolchain { - languageVersion = JavaLanguageVersion.of(17)//Versions.jvmLanguageVersion + languageVersion = JavaLanguageVersion.of(17) // Versions.jvmLanguageVersion } compilerOptions { jvmTarget = JvmTarget.fromTarget(Versions.jvmTarget) diff --git a/spark-connect-examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/Main.kt b/spark-connect-examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/Main.kt index f5bc5c84..6448e6a9 100644 --- a/spark-connect-examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/Main.kt +++ b/spark-connect-examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/Main.kt @@ -1,5 +1,6 @@ package org.jetbrains.kotlinx.spark.examples +import kotlinx.datetime.Clock import org.apache.spark.sql.connect.client.REPLClassDirMonitor import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify import org.jetbrains.kotlinx.spark.api.showDS @@ -7,7 +8,6 @@ import org.jetbrains.kotlinx.spark.api.toList import org.jetbrains.kotlinx.spark.api.tuples.X import org.jetbrains.kotlinx.spark.api.withSparkConnect import scala.Tuple2 -import java.time.LocalDate // run with `./gradlew runShadow` or set VM options: "--add-opens=java.base/java.nio=ALL-UNNAMED" in the IDE fun main() = @@ -17,21 +17,22 @@ fun main() = // make jar first, preferably a fat jar with shadow, but be careful it doesn't contain scala depencencies spark.addArtifact("/mnt/data/Projects/kotlin-spark-api/spark-connect-examples/build/libs/spark-connect-examples-2.0.0-SNAPSHOT-all.jar") - val data = listOf( - Person("Alice", 25, LocalDate.of(1996, 1, 1), "Alice" X 
Address("1 Main St", "Springfield", "IL", 62701)), - Person("Bob", 30, LocalDate.of(1991, 1, 1), "Bob" X Address("2 Main St", "Springfield", "IL", 62701)), + Person("Alice", 25, (Clock.System.now()), "Alice" X Address("1 Main St", "Springfield", "IL", 62701)), + Person("Bob", 30, (Clock.System.now()), "Bob" X Address("2 Main St", "Springfield", "IL", 62701)), Person( "Charlie", 35, - LocalDate.of(1986, 1, 1), + (Clock.System.now()), "Charlie" X Address("3 Main St", "Springfield", "IL", 62701), ), ) val ds = data.toDS().showDS() + + ds .filter { it.age > 26 } .toList() @@ -52,6 +53,23 @@ data class Address( data class Person( val name: String, val age: Int, - val birthDate: LocalDate, + val birthDate: kotlinx.datetime.Instant, val tuple: Tuple2, ) +// +//class InstantUdt : UserDefinedType() { +// override fun userClass(): Class = kotlinx.datetime.Instant::class.java +// +// override fun deserialize(datum: Any?): kotlinx.datetime.Instant? = +// when (datum) { +// null -> null +// is Long -> kotlinx.datetime.Instant.fromEpochMilliseconds(datum.microseconds.inWholeMilliseconds) +// +// else -> throw IllegalArgumentException("Unsupported datum: $datum") +// } +// +// override fun serialize(obj: kotlinx.datetime.Instant?): Long? = +// obj?.toEpochMilliseconds()?.milliseconds?.inWholeMicroseconds +// +// override fun sqlType(): DataType = InternalRow +//} \ No newline at end of file