From 7b9224a7a39f078fd986d6fa38e0cc8c69a23427 Mon Sep 17 00:00:00 2001 From: Ihor Date: Thu, 2 Jan 2025 05:11:50 +0100 Subject: [PATCH] Spark 3.5.x support (#1178) * Copy-pasted scala impl for 3.5 from 3.2 * Implemented support for Spark 3.5.1. Fixes are relevant for 3.5.0, 3.4.x, 3.3.4+ as well. * Fixes: Crash on Databricks when using zip deployment, exceptions in console after successful run on windows. Added logging to help with troubleshooting * Updated CI and Nightly pipelines with Spark 3.5, fixed incompatible tests --- azure-pipelines-e2e-tests-template.yml | 14 +- azure-pipelines-pr.yml | 12 +- azure-pipelines.yml | 55 ++- .../DeltaFixture.cs | 1 + .../IpcTests/SparkContextTests.cs | 10 +- .../IpcTests/Sql/CatalogTests.cs | 1 - .../TypeConverterTests.cs | 1 + .../PayloadWriter.cs | 1 + .../TestData.cs | 1 + .../Processor/TaskContextProcessor.cs | 22 +- .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 10 +- .../Microsoft.Spark/Sql/Catalog/Catalog.cs | 14 +- .../Microsoft.Spark/Utils/TypeConverter.cs | 5 + src/csharp/Microsoft.Spark/Versions.cs | 1 + src/scala/microsoft-spark-3-5/pom.xml | 83 ++++ .../spark/api/dotnet/CallbackClient.scala | 72 ++++ .../spark/api/dotnet/CallbackConnection.scala | 112 +++++ .../spark/api/dotnet/DotnetBackend.scala | 113 +++++ .../api/dotnet/DotnetBackendHandler.scala | 337 +++++++++++++++ .../spark/api/dotnet/DotnetException.scala | 13 + .../apache/spark/api/dotnet/DotnetRDD.scala | 30 ++ .../apache/spark/api/dotnet/DotnetUtils.scala | 39 ++ .../spark/api/dotnet/JVMObjectTracker.scala | 55 +++ .../spark/api/dotnet/JvmBridgeUtils.scala | 33 ++ .../org/apache/spark/api/dotnet/SerDe.scala | 387 ++++++++++++++++++ .../apache/spark/api/dotnet/ThreadPool.scala | 72 ++++ .../dotnet/DotNetUserAppException.scala | 22 + .../spark/deploy/dotnet/DotnetRunner.scala | 309 ++++++++++++++ .../spark/internal/config/dotnet/Dotnet.scala | 28 ++ .../spark/mllib/api/dotnet/MLUtils.scala | 26 ++ .../sql/api/dotnet/DotnetForeachBatch.scala | 33 ++ .../spark/sql/api/dotnet/SQLUtils.scala | 37 ++ .../org/apache/spark/sql/test/TestUtils.scala | 30 ++ .../org/apache/spark/util/dotnet/Utils.scala | 349 ++++++++++++++++ .../api/dotnet/DotnetBackendHandlerTest.scala | 68 +++ .../spark/api/dotnet/DotnetBackendTest.scala | 39 ++ .../apache/spark/api/dotnet/Extensions.scala | 20 + .../api/dotnet/JVMObjectTrackerTest.scala | 42 ++ .../apache/spark/api/dotnet/SerDeTest.scala | 373 +++++++++++++++++ .../apache/spark/util/dotnet/UtilsTest.scala | 82 ++++ src/scala/pom.xml | 1 + 41 files changed, 2920 insertions(+), 33 deletions(-) create mode 100644 src/scala/microsoft-spark-3-5/pom.xml create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala create mode 
100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/test/TestUtils.scala create mode 100644 src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/util/dotnet/Utils.scala create mode 100644 src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala create mode 100644 src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala create mode 100644 src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala create mode 100644 src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala create mode 100644 src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala create mode 100644 src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala diff --git a/azure-pipelines-e2e-tests-template.yml b/azure-pipelines-e2e-tests-template.yml index 6f926f9d7..8fe5de8c5 100644 --- a/azure-pipelines-e2e-tests-template.yml +++ b/azure-pipelines-e2e-tests-template.yml @@ -122,14 +122,14 @@ stages: script: | echo "Download Hadoop utils for Windows." $hadoopBinaryUrl = "https://github.com/steveloughran/winutils/releases/download/tag_2017-08-29-hadoop-2.8.1-native/hadoop-2.8.1.zip" - # Spark 3.3.3 version binary use Hadoop3 dependency - if ("3.3.3" -contains "${{ test.version }}") { + # Spark 3.3.0+ version binary uses Hadoop3 dependency + if ([version]"3.3.0" -le [version]"${{ test.version }}") { $hadoopBinaryUrl = "https://github.com/SparkSnail/winutils/releases/download/hadoop-3.3.5/hadoop-3.3.5.zip" } curl -k -L -o hadoop.zip $hadoopBinaryUrl Expand-Archive -Path hadoop.zip -Destination . 
New-Item -ItemType Directory -Force -Path hadoop\bin - if ("3.3.3" -contains "${{ test.version }}") { + if ([version]"3.3.0" -le [version]"${{ test.version }}") { cp hadoop-3.3.5\winutils.exe hadoop\bin # Hadoop 3.3 need to add hadoop.dll to environment varibles to avoid UnsatisfiedLinkError cp hadoop-3.3.5\hadoop.dll hadoop\bin @@ -142,12 +142,8 @@ stages: - pwsh: | echo "Downloading Spark ${{ test.version }}" $sparkBinaryName = "spark-${{ test.version }}-bin-hadoop2.7" - # In spark 3.3.0, 3.3.1, 3.3.2, 3.3.4, the binary name with hadoop2 dependency has changed to spark-${{ test.version }}-bin-hadoop2.tgz - if ("3.3.0", "3.3.1", "3.3.2", "3.3.4" -contains "${{ test.version }}") { - $sparkBinaryName = "spark-${{ test.version }}-bin-hadoop2" - } - # In spark 3.3.3, the binary don't provide hadoop2 version, so we use hadoop3 version - if ("3.3.3" -contains "${{ test.version }}") { + # Spark 3.3.0+ uses Hadoop3 + if ([version]"3.3.0" -le [version]"${{ test.version }}") { $sparkBinaryName = "spark-${{ test.version }}-bin-hadoop3" } curl -k -L -o spark-${{ test.version }}.tgz https://archive.apache.org/dist/spark/spark-${{ test.version }}/${sparkBinaryName}.tgz diff --git a/azure-pipelines-pr.yml b/azure-pipelines-pr.yml index d213cc574..f22b54cc1 100644 --- a/azure-pipelines-pr.yml +++ b/azure-pipelines-pr.yml @@ -31,7 +31,7 @@ variables: backwardCompatibleTestOptions_Linux_3_1: "" forwardCompatibleTestOptions_Linux_3_1: "" - # Skip all forward/backward compatibility tests since Spark 3.2 is not supported before this release. + # Skip all forward/backward compatibility tests since Spark 3.2 and 3.5 are not supported before this release. backwardCompatibleTestOptions_Windows_3_2: "--filter FullyQualifiedName=NONE" forwardCompatibleTestOptions_Windows_3_2: $(backwardCompatibleTestOptions_Windows_3_2) backwardCompatibleTestOptions_Linux_3_2: $(backwardCompatibleTestOptions_Windows_3_2) @@ -41,6 +41,11 @@ variables: forwardCompatibleTestOptions_Windows_3_3: $(backwardCompatibleTestOptions_Windows_3_3) backwardCompatibleTestOptions_Linux_3_3: $(backwardCompatibleTestOptions_Windows_3_3) forwardCompatibleTestOptions_Linux_3_3: $(backwardCompatibleTestOptions_Windows_3_3) + + backwardCompatibleTestOptions_Windows_3_5: "--filter FullyQualifiedName=NONE" + forwardCompatibleTestOptions_Windows_3_5: $(backwardCompatibleTestOptions_Windows_3_5) + backwardCompatibleTestOptions_Linux_3_5: $(backwardCompatibleTestOptions_Windows_3_5) + forwardCompatibleTestOptions_Linux_3_5: $(backwardCompatibleTestOptions_Windows_3_5) # Azure DevOps variables are transformed into environment variables, with these variables we # avoid the first time experience and telemetry to speed up the build. @@ -73,6 +78,11 @@ parameters: - '3.3.2' - '3.3.3' - '3.3.4' + - '3.5.0' + - '3.5.1' + - '3.5.2' + - '3.5.3' + # List of OS types to run E2E tests, run each test in both 'Windows' and 'Linux' environments - name: listOfE2ETestsPoolTypes type: object diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 384ec0779..a6ae0b15b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -31,12 +31,17 @@ variables: backwardCompatibleTestOptions_Linux_3_1: "" forwardCompatibleTestOptions_Linux_3_1: "" - # Skip all forward/backward compatibility tests since Spark 3.2 is not supported before this release. + # Skip all forward/backward compatibility tests since Spark 3.2 and 3.5 are not supported before this release. 
backwardCompatibleTestOptions_Windows_3_2: "--filter FullyQualifiedName=NONE" forwardCompatibleTestOptions_Windows_3_2: $(backwardCompatibleTestOptions_Windows_3_2) backwardCompatibleTestOptions_Linux_3_2: $(backwardCompatibleTestOptions_Windows_3_2) forwardCompatibleTestOptions_Linux_3_2: $(backwardCompatibleTestOptions_Windows_3_2) + backwardCompatibleTestOptions_Windows_3_5: "--filter FullyQualifiedName=NONE" + forwardCompatibleTestOptions_Windows_3_5: $(backwardCompatibleTestOptions_Windows_3_5) + backwardCompatibleTestOptions_Linux_3_5: $(backwardCompatibleTestOptions_Windows_3_5) + forwardCompatibleTestOptions_Linux_3_5: $(backwardCompatibleTestOptions_Windows_3_5) + # Azure DevOps variables are transformed into environment variables, with these variables we # avoid the first time experience and telemetry to speed up the build. DOTNET_CLI_TELEMETRY_OPTOUT: 1 @@ -413,3 +418,51 @@ stages: testOptions: "" backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_3_2) forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_3_2) + - version: '3.5.0' + enableForwardCompatibleTests: false + enableBackwardCompatibleTests: false + jobOptions: + - pool: 'Windows' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_3_5) + - pool: 'Linux' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_3_5) + - version: '3.5.1' + enableForwardCompatibleTests: false + enableBackwardCompatibleTests: false + jobOptions: + - pool: 'Windows' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_3_5) + - pool: 'Linux' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_3_5) + - version: '3.5.2' + enableForwardCompatibleTests: false + enableBackwardCompatibleTests: false + jobOptions: + - pool: 'Windows' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_3_5) + - pool: 'Linux' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_3_5) + - version: '3.5.3' + enableForwardCompatibleTests: false + enableBackwardCompatibleTests: false + jobOptions: + - pool: 'Windows' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_3_5) + - pool: 'Linux' + testOptions: "" + backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_3_5) + forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_3_5) \ No newline at end of file diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs index c893336f3..54a3e886b 100644 --- a/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs +++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs @@ -27,6 +27,7 @@ public DeltaFixture() (3, 3, 2) => "delta-core_2.12:2.3.0", (3, 3, 3) => "delta-core_2.12:2.3.0", (3, 
3, 4) => "delta-core_2.12:2.3.0", + (3, 5, _) => "delta-spark_2.12:3.2.0", _ => throw new NotSupportedException($"Spark {sparkVersion} not supported.") }; diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs index 9b87c39d0..0044c3ec4 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs @@ -57,16 +57,22 @@ public void TestSignaturesV2_4_X() /// /// Test signatures for APIs introduced in Spark 3.1.*. + /// In Spark 3.5, Spark throws an exception when trying to delete + /// archive.zip from the temp folder, which causes failures in other tests. /// - [SkipIfSparkVersionIsLessThan(Versions.V3_1_0)] + [SkipIfSparkVersionIsNotInRange(Versions.V3_1_0, Versions.V3_3_0)] public void TestSignaturesV3_1_X() { SparkContext sc = SparkContext.GetOrCreate(new SparkConf()); string archivePath = $"{TestEnvironment.ResourceDirectory}archive.zip"; + sc.AddArchive(archivePath); - Assert.IsType(sc.ListArchives().ToArray()); + var archives = sc.ListArchives().ToArray(); + + Assert.IsType(archives); + Assert.NotEmpty(archives.Where(a => a.EndsWith("archive.zip"))); } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs index f5f37dd91..630fb3c54 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs @@ -59,7 +59,6 @@ public void TestSignaturesV2_4_X() Assert.IsType(catalog.FunctionExists("functionname")); Assert.IsType(catalog.GetDatabase("default")); Assert.IsType(catalog.GetFunction("abs")); - Assert.IsType(catalog.GetFunction(null, "abs")); Assert.IsType(catalog.GetTable("users")); Assert.IsType
(catalog.GetTable("default", "users")); Assert.IsType(catalog.IsCached("users")); diff --git a/src/csharp/Microsoft.Spark.UnitTest/TypeConverterTests.cs b/src/csharp/Microsoft.Spark.UnitTest/TypeConverterTests.cs index 332e0d29b..34fec9f96 100644 --- a/src/csharp/Microsoft.Spark.UnitTest/TypeConverterTests.cs +++ b/src/csharp/Microsoft.Spark.UnitTest/TypeConverterTests.cs @@ -20,6 +20,7 @@ public void TestBaseCase() Assert.Equal((short)1, TypeConverter.ConvertTo((short)1)); Assert.Equal((ushort)1, TypeConverter.ConvertTo((ushort)1)); Assert.Equal(1, TypeConverter.ConvertTo(1)); + Assert.Equal(1L, TypeConverter.ConvertTo(1)); Assert.Equal(1u, TypeConverter.ConvertTo(1u)); Assert.Equal(1L, TypeConverter.ConvertTo(1L)); Assert.Equal(1ul, TypeConverter.ConvertTo(1ul)); diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs index 4798950c4..a96d6130b 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs @@ -351,6 +351,7 @@ internal PayloadWriter Create(Version version = null) new BroadcastVariableWriterV2_4_X(), new CommandWriterV2_4_X()); case Versions.V3_3_0: + case Versions.V3_5_1: return new PayloadWriter( version, new TaskContextWriterV3_3_X(), diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs index b7a751317..a4e6f49d0 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs @@ -20,6 +20,7 @@ public static IEnumerable VersionData() => new object[] { Versions.V3_0_0 }, new object[] { Versions.V3_2_0 }, new object[] { Versions.V3_3_0 }, + new object[] { Versions.V3_5_1 }, }; internal static Payload GetDefaultPayload() diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs index 9addde22b..8b952d33b 100644 --- a/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs @@ -31,31 +31,32 @@ internal TaskContext Process(Stream stream) private static TaskContext ReadTaskContext_2_x(Stream stream) => new() { + IsBarrier = SerDe.ReadBool(stream), + Port = SerDe.ReadInt32(stream), + Secret = SerDe.ReadString(stream), + StageId = SerDe.ReadInt32(stream), PartitionId = SerDe.ReadInt32(stream), AttemptNumber = SerDe.ReadInt32(stream), AttemptId = SerDe.ReadInt64(stream), }; + // Needed for 3.3.0+ + // https://issues.apache.org/jira/browse/SPARK-36173 private static TaskContext ReadTaskContext_3_3(Stream stream) => new() { + IsBarrier = SerDe.ReadBool(stream), + Port = SerDe.ReadInt32(stream), + Secret = SerDe.ReadString(stream), + StageId = SerDe.ReadInt32(stream), PartitionId = SerDe.ReadInt32(stream), AttemptNumber = SerDe.ReadInt32(stream), AttemptId = SerDe.ReadInt64(stream), - // CPUs field is added into TaskContext from 3.3.0 https://issues.apache.org/jira/browse/SPARK-36173 CPUs = SerDe.ReadInt32(stream) }; - private static void ReadBarrierInfo(Stream stream) - { - // Read barrier-related payload. Note that barrier is currently not supported. 
- SerDe.ReadBool(stream); // IsBarrier - SerDe.ReadInt32(stream); // BoundPort - SerDe.ReadString(stream); // Secret - } - private static void ReadTaskContextProperties(Stream stream, TaskContext taskContext) { int numProperties = SerDe.ReadInt32(stream); @@ -87,7 +88,6 @@ private static class TaskContextProcessorV2_4_X { internal static TaskContext Process(Stream stream) { - ReadBarrierInfo(stream); TaskContext taskContext = ReadTaskContext_2_x(stream); ReadTaskContextProperties(stream, taskContext); @@ -99,7 +99,6 @@ private static class TaskContextProcessorV3_0_X { internal static TaskContext Process(Stream stream) { - ReadBarrierInfo(stream); TaskContext taskContext = ReadTaskContext_2_x(stream); ReadTaskContextResources(stream); ReadTaskContextProperties(stream, taskContext); @@ -112,7 +111,6 @@ private static class TaskContextProcessorV3_3_X { internal static TaskContext Process(Stream stream) { - ReadBarrierInfo(stream); TaskContext taskContext = ReadTaskContext_3_3(stream); ReadTaskContextResources(stream); ReadTaskContextProperties(stream, taskContext); diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index f7bd145e3..119c7bdf2 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -8,6 +8,7 @@ using System.Diagnostics; using System.IO; using System.Net; +using System.Net.Sockets; using System.Text; using System.Threading; using Microsoft.Spark.Network; @@ -184,7 +185,7 @@ private object CallJavaMethod( ISocketWrapper socket = null; try - { + { // Limit the number of connections to the JVM backend. Netty is configured // to use a set number of threads to process incoming connections. Each // new connection is delegated to these threads in a round robin fashion. @@ -299,6 +300,13 @@ private object CallJavaMethod( } else { + if (e.InnerException is SocketException) + { + _logger.LogError( + "Scala worker abandoned the connection, likely fatal crash on Java side. \n" + + "Ensure Spark runs with sufficient memory."); + } + // In rare cases we may hit the Netty connection thread deadlock. // If max backend threads is 10 and we are currently using 10 active // connections (0 in the _sockets queue). When we hit this exception, diff --git a/src/csharp/Microsoft.Spark/Sql/Catalog/Catalog.cs b/src/csharp/Microsoft.Spark/Sql/Catalog/Catalog.cs index ab15be82e..ce7f9125f 100644 --- a/src/csharp/Microsoft.Spark/Sql/Catalog/Catalog.cs +++ b/src/csharp/Microsoft.Spark/Sql/Catalog/Catalog.cs @@ -248,20 +248,22 @@ public Database GetDatabase(string dbName) => new Database((JvmObjectReference)Reference.Invoke("getDatabase", dbName)); /// - /// Get the function with the specified name. If you are trying to get an in-built - /// function then use the unqualified name. + /// Get the function with the specified name. This function can be a temporary function + /// or a function. /// /// Is either a qualified or unqualified name that designates a - /// function. If no database identifier is provided, it refers to a temporary function or - /// a function in the current database. + /// function. It follows the same resolution rule with SQL: search for built-in/temp + /// functions first then functions in the current database(namespace). /// `Function` object which includes the class name, database, description, /// whether it is temporary and the name of the function. 
public Function GetFunction(string functionName) => new Function((JvmObjectReference)Reference.Invoke("getFunction", functionName)); /// - /// Get the function with the specified name. If you are trying to get an in-built function - /// then pass null as the dbName. + /// Get the function with the specified name in the specified database under the Hive + /// Metastore. + /// To get built-in functions, or functions in other catalogs, please use `getFunction(functionName)` with + /// qualified function name instead. /// /// Is a name that designates a database. Built-in functions will be /// in database null rather than default. diff --git a/src/csharp/Microsoft.Spark/Utils/TypeConverter.cs b/src/csharp/Microsoft.Spark/Utils/TypeConverter.cs index 6c00c79dd..b10190136 100644 --- a/src/csharp/Microsoft.Spark/Utils/TypeConverter.cs +++ b/src/csharp/Microsoft.Spark/Utils/TypeConverter.cs @@ -33,6 +33,11 @@ private static object Convert(object obj, Type toType) { return ConvertToDictionary(hashtable, toType); } + // Fails to convert int to long otherwise + else if (toType.IsPrimitive) + { + return System.Convert.ChangeType(obj, toType); + } return obj; } diff --git a/src/csharp/Microsoft.Spark/Versions.cs b/src/csharp/Microsoft.Spark/Versions.cs index 2f6ce10f5..2d02ca388 100644 --- a/src/csharp/Microsoft.Spark/Versions.cs +++ b/src/csharp/Microsoft.Spark/Versions.cs @@ -13,5 +13,6 @@ internal static class Versions internal const string V3_1_1 = "3.1.1"; internal const string V3_2_0 = "3.2.0"; internal const string V3_3_0 = "3.3.0"; + internal const string V3_5_1 = "3.5.1"; } } diff --git a/src/scala/microsoft-spark-3-5/pom.xml b/src/scala/microsoft-spark-3-5/pom.xml new file mode 100644 index 000000000..660607edd --- /dev/null +++ b/src/scala/microsoft-spark-3-5/pom.xml @@ -0,0 +1,83 @@ + + 4.0.0 + + com.microsoft.scala + microsoft-spark + ${microsoft-spark.version} + + microsoft-spark-3-5_2.12 + 2019 + + UTF-8 + 2.12.18 + 2.12 + 3.5.1 + + + + + org.scala-lang + scala-library + ${scala.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${spark.version} + provided + + + junit + junit + 4.13.1 + test + + + org.specs + specs + 1.2.5 + test + + + + + src/main/scala + src/test/scala + + + org.scala-tools + maven-scala-plugin + 2.15.2 + + + + compile + testCompile + + + + + ${scala.version} + + -target:jvm-1.8 + -deprecation + -feature + + + + + + diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala new file mode 100644 index 000000000..aea355dfa --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.DataOutputStream + +import org.apache.spark.internal.Logging + +import scala.collection.mutable.Queue + +/** + * CallbackClient is used to communicate with the Dotnet CallbackServer. + * The client manages and maintains a pool of open CallbackConnections. 
+ * Any callback request is delegated to a new CallbackConnection or + * unused CallbackConnection. + * @param address The address of the Dotnet CallbackServer + * @param port The port of the Dotnet CallbackServer + */ +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { + private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() + + private[this] var isShutdown: Boolean = false + + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = + getOrCreateConnection() match { + case Some(connection) => + try { + connection.send(callbackId, writeBody) + addConnection(connection) + } catch { + case e: Exception => + logError(s"Error calling callback [callback id = $callbackId].", e) + connection.close() + throw e + } + case None => throw new Exception("Unable to get or create connection.") + } + + private def getOrCreateConnection(): Option[CallbackConnection] = synchronized { + if (isShutdown) { + logInfo("Cannot get or create connection while client is shutdown.") + return None + } + + if (connectionPool.nonEmpty) { + return Some(connectionPool.dequeue()) + } + + Some(new CallbackConnection(serDe, address, port)) + } + + private def addConnection(connection: CallbackConnection): Unit = synchronized { + assert(connection != null) + connectionPool.enqueue(connection) + } + + def shutdown(): Unit = synchronized { + if (isShutdown) { + logInfo("Shutdown called, but already shutdown.") + return + } + + logInfo("Shutting down.") + connectionPool.foreach(_.close) + connectionPool.clear + isShutdown = true + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala new file mode 100644 index 000000000..604cf029b --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.internal.Logging + +/** + * CallbackConnection is used to process the callback communication + * between the JVM and Dotnet. It uses a TCP socket to communicate with + * the Dotnet CallbackServer and the socket is expected to be reused. 
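To make the callback flow of the CallbackClient defined above concrete, here is a minimal usage sketch. It is not part of the patch: the address, port, callback id, and payload are assumed values, and a .NET CallbackServer is assumed to be listening on that port.

```scala
package org.apache.spark.api.dotnet

import java.io.DataOutputStream

// Minimal sketch, not part of the patch. Host, port, callback id and payload
// are assumptions; a .NET CallbackServer is assumed to be listening.
object CallbackClientSketch {
  def main(args: Array[String]): Unit = {
    val tracker = new JVMObjectTracker              // package-private, hence this package
    val serDe = new SerDe(tracker)
    val client = new CallbackClient(serDe, "127.0.0.1", 55555)

    // Each send reuses a pooled CallbackConnection or opens a new one.
    client.send(1, (dos: DataOutputStream, sd: SerDe) => {
      sd.writeString(dos, "example payload")        // body is serialized with the shared SerDe
    })

    client.shutdown()                               // closes all pooled connections
  }
}
```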
+ * @param address The address of the Dotnet CallbackServer + * @param port The port of the Dotnet CallbackServer + */ +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { + private[this] val socket: Socket = new Socket(address, port) + private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) + private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) + + def send( + callbackId: Int, + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { + logInfo(s"Calling callback [callback id = $callbackId] ...") + + try { + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) + + val byteArrayOutputStream = new ByteArrayOutputStream() + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) + byteArrayOutputStream.writeTo(outputStream); + } catch { + case e: Exception => { + throw new Exception("Error writing to stream.", e) + } + } + + logInfo(s"Signaling END_OF_STREAM.") + try { + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + outputStream.flush() + + val endOfStreamResponse = readFlag(inputStream) + endOfStreamResponse match { + case CallbackFlags.END_OF_STREAM => + logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.") + case _ => { + throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " + + s"Received: $endOfStreamResponse") + } + } + } catch { + case e: Exception => { + throw new Exception("Error while verifying end of stream.", e) + } + } + } + + def close(): Unit = { + try { + serDe.writeInt(outputStream, CallbackFlags.CLOSE) + outputStream.flush() + } catch { + case e: Exception => logInfo("Unable to send close to .NET callback server.", e) + } + + close(socket) + close(outputStream) + close(inputStream) + } + + private def close(s: Socket): Unit = { + try { + assert(s != null) + s.close() + } catch { + case e: Exception => logInfo("Unable to close socket.", e) + } + } + + private def close(c: Closeable): Unit = { + try { + assert(c != null) + c.close() + } catch { + case e: Exception => logInfo("Unable to close closeable.", e) + } + } + + private def readFlag(inputStream: DataInputStream): Int = { + val callbackFlag = serDe.readInt(inputStream) + if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { + val exceptionMessage = serDe.readString(inputStream) + throw new DotnetException(exceptionMessage) + } + callbackFlag + } + + private object CallbackFlags { + val CLOSE: Int = -1 + val CALLBACK: Int = -2 + val DOTNET_EXCEPTION_THROWN: Int = -3 + val END_OF_STREAM: Int = -4 + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala new file mode 100644 index 000000000..c6f528aee --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import java.net.InetSocketAddress +import java.util.concurrent.TimeUnit +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS +import org.apache.spark.{SparkConf, SparkEnv} + +/** + * Netty server that invokes JVM calls based upon receiving messages from .NET. + * The implementation mirrors the RBackend. + * + */ +class DotnetBackend extends Logging { + self => // for accessing the this reference in inner class(ChannelInitializer) + private[this] var channelFuture: ChannelFuture = _ + private[this] var bootstrap: ServerBootstrap = _ + private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None + + def init(portNumber: Int): Int = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) + logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") + bossGroup = new NioEventLoopGroup(numBackendThreads) + val workerGroup = bossGroup + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast( + "frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. 
strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort + } + + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback server has already been shutdown.") + } + callbackClient = None + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.config().childGroup() != null) { + bootstrap.config().childGroup().shutdownGracefully() + } + bootstrap = null + + objectTracker.clear() + + // Send close to .NET callback server. + shutdownCallbackClient() + + // Shutdown the thread pool whose executors could still be running. + ThreadPool.shutdown() + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala new file mode 100644 index 000000000..2863e5b3a --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + +/** + * Handler for DotnetBackend. + * This implementation is similar to RBackendHandler. 
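As a rough illustration of how the DotnetBackend defined above is driven, the sketch below binds it to an ephemeral port and blocks until it is closed. This is only an assumed usage pattern (the real lifecycle lives in DotnetRunner, added elsewhere in this PR); port 0 and the println are illustrative.

```scala
package org.apache.spark.api.dotnet

// Minimal sketch, not part of the patch: an assumed way to drive the
// DotnetBackend above. The real entry point is DotnetRunner.
object DotnetBackendSketch {
  def main(args: Array[String]): Unit = {
    val backend = new DotnetBackend()
    // init() binds the Netty server; passing 0 asks the OS for a free port,
    // and the actual bound port is returned.
    val boundPort = backend.init(0)
    println(s"DotnetBackend listening on 127.0.0.1:$boundPort")
    try {
      backend.run()   // blocks until the server channel is closed
    } finally {
      backend.close() // shuts down Netty groups, callback client and ThreadPool
    }
  }
}
```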
+ */ +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) + extends SimpleChannelInboundHandler[Array[Byte]] + with Logging { + + private[this] val serDe = new SerDe(objectsTracker) + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val reply = handleBackendRequest(msg) + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = serDe.readBoolean(dis) + val processId = serDe.readInt(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) + + if (objId == "DotnetHandler") { + methodName match { + case "stopBackend" => + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") + server.close() + case "rm" => + try { + val t = serDe.readObjectType(dis) + assert(t == 'c') + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + serDe.writeInt(dos, -1) + } + case "rmThread" => + try { + assert(serDe.readObjectType(dis) == 'i') + val processId = serDe.readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) + val result = ThreadPool.tryDeleteThread(processId, threadToDelete) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + serDe.writeInt(dos, -1) + } + case "connectCallback" => + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + + // Sends reference of CallbackClient to dotnet side, + // so that dotnet process can send the client back to Java side + // when calling any API containing callback functions. + serDe.writeObject(dos, server.callbackClient) + case "closeCallback" => + logInfo("Requesting to close callback client") + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") + case _ => dos.writeInt(-1) + } + } else { + ThreadPool + .run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) + } + + bos.toByteArray + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Skip logging the exception message if the connection was disconnected from + // the .NET side so that .NET side doesn't have to explicitly close the connection via + // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, + // so skipping this kind of exception message does not affect the debugging. + if ( + !cause.getMessage.contains("An existing connection was forcibly closed by the remote host") + && !cause.getMessage.contains("Connection reset") + ) { + logError("Exception caught: ", cause) + } + + // Close the connection when an exception is raised. 
+ ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + var args: Array[java.lang.Object] = null + var methods: Array[java.lang.reflect.Method] = null + + try { + val cls = if (isStatic) { + Utils.classForName(objId) + } else { + objectsTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + args = readArgs(numArgs, dis) + methods = cls.getMethods + + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) + + if (index.isEmpty) { + logWarning( + s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + + val ret = selectedMethods(index.get).invoke(obj, args: _*) + + // Write status bit + serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args: _*) + + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException( + "invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Throwable => + val jvmObj = objectsTracker.get(objId) + val jvmObjName = jvmObj match { + case Some(jObj) => jObj.getClass.getName + case None => "NullObject" + } + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") + + logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") + + if (methods != null) { + logDebug(s"All methods for $jvmObjName:") + methods.foreach(m => logDebug(m.toString)) + } + + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + serDe.readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. 
+ def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 until numArgs) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + + if (!parameterWrapperType.isInstance(args(i))) { + // non primitive types + if (!parameterType.isPrimitive && args(i) != null) { + return false + } + + // primitive types + if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { + return false + } + } + } + + true + } + + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- parameterTypesOfMethods.indices) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. + } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if ((parameterType.isPrimitive || args(i) != null) && + !parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } + + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. 
+ + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + for (i <- 0 until numArgs) { + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) + } + } + } + None + } + + def logError(id: String, e: Exception): Unit = {} +} + + diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala new file mode 100644 index 000000000..c70d16b03 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala @@ -0,0 +1,13 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +class DotnetException(message: String, cause: Throwable) + extends Exception(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala new file mode 100644 index 000000000..f5277c215 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD + +object DotnetRDD { + def createPythonRDD( + parent: RDD[_], + func: PythonFunction, + preservePartitoning: Boolean): PythonRDD = { + new PythonRDD(parent, func, preservePartitoning) + } + + def createJavaRDDFromArray( + sc: SparkContext, + arr: Array[Array[Byte]], + numSlices: Int): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(sc.parallelize(arr, numSlices)) + } + + def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd) +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala new file mode 100644 index 000000000..9f556338b --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import scala.collection.JavaConverters._ + +/** DotnetUtils object that hosts some helper functions + * help data type conversions between dotnet and scala + */ +object DotnetUtils { + + /** A helper function to convert scala Map to java.util.Map + * @param value - scala Map + * @return java.util.Map + */ + def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava + + /** Convert java data type to corresponding scala type + * @param value - java.lang.Object + * @return Any + */ + def mapScalaToJava(value: java.lang.Object): Any = { + value match { + case i: java.lang.Integer => i.toInt + case d: java.lang.Double => d.toDouble + case f: java.lang.Float => f.toFloat + case b: java.lang.Boolean => b.booleanValue() + case l: java.lang.Long => l.toLong + case s: java.lang.Short => s.toShort + case by: java.lang.Byte => by.toByte + case c: java.lang.Character => c.toChar + case _ => value + } + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..81cfaf88b --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because get method return Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. + private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala new file mode 100644 index 000000000..06a476f67 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import org.apache.spark.SparkConf + +/* + * Utils for JvmBridge. 
+ */ +object JvmBridgeUtils { + def getKeyValuePairAsString(kvp: (String, String)): String = { + return kvp._1 + "=" + kvp._2 + } + + def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = { + val sb = new StringBuilder + + for (kvp <- kvpArray) { + sb.append(getKeyValuePairAsString(kvp)) + sb.append(";") + } + + sb.toString + } + + def getSparkConfAsString(sparkConf: SparkConf): String = { + getKeyValuePairArrayAsString(sparkConf.getAll) + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala new file mode 100644 index 000000000..a3df3788a --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.{DataInputStream, DataOutputStream} +import java.nio.charset.StandardCharsets +import java.sql.{Date, Time, Timestamp} + +import org.apache.spark.sql.Row + +import scala.collection.JavaConverters._ + +/** + * Class responsible for serialization and deserialization between CLR & JVM. + * This implementation of methods is mostly identical to the SerDe implementation in R. + */ +class SerDe(val tracker: JVMObjectTracker) { + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'g' => new java.lang.Long(readLong(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => tracker.getObject(readString(dis)) + case 'R' => readRowArr(dis) + case 'O' => readObjectArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + private def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + private def readLong(in: DataInputStream): Long = { + in.readLong() + } + + private def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + private def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + val str = new String(bytes, "UTF-8") + str + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + readStringBytes(in, len) + } + + def readBoolean(in: DataInputStream): Boolean = { + in.readBoolean() + } + + private def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + private def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t + } + + private def readRow(in: DataInputStream): Row = { + val len = readInt(in) + 
Row.fromSeq((0 until len).map(_ => readObject(in))) + } + + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + private def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + private def readLongArr(in: DataInputStream): Array[Long] = { + val len = readInt(in) + (0 until len).map(_ => readLong(in)).toArray + } + + private def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + val len = readInt(in) + (0 until len).map(_ => readDoubleArr(in)).toArray + } + + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + private def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + private def readRowArr(in: DataInputStream): java.util.List[Row] = { + val len = readInt(in) + (0 until len).map(_ => readRow(in)).toList.asJava + } + + private def readObjectArr(in: DataInputStream): Seq[Any] = { + val len = readInt(in) + (0 until len).map(_ => readObject(in)) + } + + private def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'g' => readLongArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'A' => readDoubleArrArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) + keys.zip(values).toMap.asJava + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Using the same mapping as SparkR implementation for now + // Methods to write out data from Java to .NET. 
+ // + // Type mapping from Java to .NET: + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Float -> double + // Double -> double + // Long -> long + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "doublearray" => dos.writeByte('A') + case "long" => dos.writeByte('g') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null || value == Unit) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "long" | "java.lang.Long" => + writeType(dos, "long") + writeLong(dos, value.asInstanceOf[Long]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Timestamp]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeLongArr(dos, value.asInstanceOf[Array[Long]]) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[[D" => + writeType(dos, "list") + writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeLong(out: DataOutputStream, value: Long): Unit = { + out.writeLong(value) + } + + private def writeDouble(out: DataOutputStream, value: Double): Unit = { + 
out.writeDouble(value) + } + + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + out.writeBoolean(value) + } + + private def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + private def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } + + def writeString(out: DataOutputStream, value: String): Unit = { + val utf8 = value.getBytes(StandardCharsets.UTF_8) + val len = utf8.length + out.writeInt(len) + out.write(utf8, 0, len) + } + + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = tracker.put(value) + writeString(out, objId) + } + + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + writeType(out, "long") + out.writeInt(value.length) + value.foreach(v => out.writeLong(v)) + } + + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + writeType(out, "doublearray") + out.writeInt(value.length) + value.foreach(v => writeDoubleArr(out, v)) + } + + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..50551a7d9 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from (processId, threadId) to corresponding executor. + */ + private val executors: mutable.HashMap[(Int, Int), ExecutorService] = + new mutable.HashMap[(Int, Int), ExecutorService]() + + /** + * Run some code on a particular thread. 
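+   * The call blocks until the submitted task has completed on that thread.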
+ * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. + */ + def run(processId: Int, threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(processId, threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() + }) + + future.get() + } + + /** + * Try to delete a particular thread. + * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. + */ + def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized { + executors.remove((processId, threadId)) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } + } + + /** + * Shutdown any running ExecutorServices. + */ + def shutdown(): Unit = synchronized { + executors.foreach(_._2.shutdown()) + executors.clear() + } + + /** + * Get the executor if it exists, otherwise create a new one. + * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @return The new or existing executor with the given id. + */ + private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized { + executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor) + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala new file mode 100644 index 000000000..4551a70bd --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.deploy.dotnet + +import org.apache.spark.SparkException + +/** + * This exception type describes an exception thrown by a .NET user application. + * + * @param exitCode Exit code returned by the .NET application. + * @param dotNetStackTrace Stacktrace extracted from .NET application logs. + */ +private[spark] class DotNetUserAppException(exitCode: Int, dotNetStackTrace: Option[String]) + extends SparkException( + dotNetStackTrace match { + case None => s"User application exited with $exitCode" + case Some(e) => s"User application exited with $exitCode and .NET exception: $e" + }) diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala new file mode 100644 index 000000000..a3fa10551 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala @@ -0,0 +1,309 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.deploy.dotnet + +import java.io.File +import java.net.URI +import java.nio.file.attribute.PosixFilePermissions +import java.nio.file.{FileSystems, Files, Paths} +import java.util.Locale +import java.util.concurrent.{Semaphore, TimeUnit} + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.output.TeeOutputStream +import org.apache.hadoop.fs.Path +import org.apache.spark +import org.apache.spark.api.dotnet.DotnetBackend +import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.{ + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK, + ERROR_BUFFER_SIZE, ERROR_REDIRECITON_ENABLED +} +import org.apache.spark.util.dotnet.{Utils => DotnetUtils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException} + +import scala.collection.JavaConverters._ +import scala.io.StdIn +import scala.util.Try + +/** + * DotnetRunner class used to launch Spark .NET applications using spark-submit. + * It executes .NET application as a subprocess and then has it connect back to + * the JVM to access system properties etc. + */ +object DotnetRunner extends Logging { + private val DEBUG_PORT = 5567 + private val supportedSparkMajorMinorVersionPrefix = "3.5" + private val supportedSparkVersions = Set[String]("3.5.0", "3.5.1", "3.5.2", "3.5.3") + + val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION) + + def main(args: Array[String]): Unit = { + if (args.length == 0) { + throw new IllegalArgumentException("At least one argument is expected.") + } + + DotnetUtils.validateSparkVersions( + sys.props + .getOrElse( + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key, + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString) + .toBoolean, + spark.SPARK_VERSION, + SPARK_VERSION, + supportedSparkMajorMinorVersionPrefix, + supportedSparkVersions) + + val settings = initializeSettings(args) + + // Determines if this needs to be run in debug mode. + // In debug mode this runner will not launch a .NET process. + val runInDebugMode = settings._1 + @volatile var dotnetBackendPortNumber = settings._2 + var dotnetExecutable = "" + var otherArgs: Array[String] = null + + if (!runInDebugMode) { + if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) { + var zipFileName = args(0) + val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI) + val workingDir = new File("").getAbsoluteFile + val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath())) + + // Standalone cluster mode where .NET application is remotely located. + if (zipFileUri.getScheme() != "file") { + zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName + } + + logInfo(s"Unzipping .NET driver $zipFileName to $driverDir") + DotnetUtils.unzip(new File(zipFileName), driverDir) + + // Reuse windows-specific formatting in PythonRunner. + dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1))) + otherArgs = args.slice(2, args.length) + } else { + // Reuse windows-specific formatting in PythonRunner. 
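+        // Non-zip deployment: args(0) is the path to the .NET driver executable.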
+ dotnetExecutable = PythonRunner.formatPath(args(0)) + otherArgs = args.slice(1, args.length) + } + } else { + otherArgs = args.slice(1, args.length) + } + + val processParameters = new java.util.ArrayList[String] + processParameters.add(dotnetExecutable) + otherArgs.foreach(arg => processParameters.add(arg)) + + logInfo(s"Starting DotnetBackend with $dotnetExecutable.") + + // Time to wait for DotnetBackend to initialize in seconds. + val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt + + // Launch a DotnetBackend server for the .NET process to connect to; this will let it see our + // Java system properties etc. + val dotnetBackend = new DotnetBackend() + val initialized = new Semaphore(0) + val dotnetBackendThread = new Thread("DotnetBackend") { + override def run() { + // need to get back dotnetBackendPortNumber because if the value passed to init is 0 + // the port number is dynamically assigned in the backend + dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) + logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") + initialized.release() + dotnetBackend.run() + } + } + + dotnetBackendThread.start() + + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + if (!runInDebugMode) { + var returnCode = -1 + var process: Process = null + val enableLogRedirection: Boolean = sys.props + .getOrElse( + ERROR_REDIRECITON_ENABLED.key, + ERROR_REDIRECITON_ENABLED.defaultValue.get.toString).toBoolean + val stderrBuffer: Option[CircularBuffer] = Option(enableLogRedirection).collect { + case true => new CircularBuffer( + sys.props.getOrElse( + ERROR_BUFFER_SIZE.key, + ERROR_BUFFER_SIZE.defaultValue.get.toString).toInt) + } + + try { + val builder = new ProcessBuilder(processParameters) + val env = builder.environment() + env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString) + + for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { + env.put(key, value) + logInfo(s"Adding key=$key and value=$value to environment") + } + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + process = builder.start() + + // Redirect stdin of JVM process to stdin of .NET process. + new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start() + // Redirect stdout and stderr of .NET process to System.out and to buffer + // if log direction is enabled. If not, redirect only to System.out. + new RedirectThread( + process.getInputStream, + stderrBuffer match { + case Some(buffer) => new TeeOutputStream(System.out, buffer) + case _ => System.out + }, + "redirect .NET stdout and stderr").start() + + process.waitFor() + } catch { + case t: Throwable => + logThrowable(t) + } finally { + returnCode = closeDotnetProcess(process) + closeBackend(dotnetBackend) + } + if (returnCode != 0) { + if (stderrBuffer.isDefined) { + throw new DotNetUserAppException(returnCode, Some(stderrBuffer.get.toString)) + } else { + throw new SparkUserAppException(returnCode) + } + } else { + logInfo(s".NET application exited successfully") + } + // TODO: The following is causing the following error: + // INFO ApplicationMaster: Final app status: FAILED, exitCode: 16, + // (reason: Shutdown hook called before final status was reported.) + // DotnetUtils.exit(returnCode) + } else { + // scalastyle:off println + println("***********************************************************************") + println("* .NET Backend running debug mode. 
Press enter to exit *") + println("***********************************************************************") + // scalastyle:on println + + StdIn.readLine() + closeBackend(dotnetBackend) + DotnetUtils.exit(0) + } + } else { + logError(s"DotnetBackend did not initialize in $backendTimeout seconds") + DotnetUtils.exit(-1) + } + } + + // When the executable is downloaded as part of zip file, check if the file exists + // after zip file is unzipped under the given dir. Once it is found, change the + // permission to executable (only for Unix systems, since the zip file may have been + // created under Windows. Finally, the absolute path for the executable is returned. + private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = { + val path = Paths.get(dir.getAbsolutePath, dotnetExecutable) + val resolvedExecutable = if (Files.isRegularFile(path)) { + path.toAbsolutePath.toString + } else { + Files + .walk(FileSystems.getDefault.getPath(dir.getAbsolutePath)) + .iterator() + .asScala + .find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match { + case Some(path) => path.toAbsolutePath.toString + case None => + throw new IllegalArgumentException( + s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}") + } + } + + if (DotnetUtils.supportPosix) { + Files.setPosixFilePermissions( + Paths.get(resolvedExecutable), + PosixFilePermissions.fromString("rwxr-xr-x")) + } + + resolvedExecutable + } + + /** + * Download HDFS file into the supplied directory and return its local path. + * Will throw an exception if there are errors during downloading. + */ + private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = { + val sparkConf = new SparkConf() + val filePath = new Path(hdfsFilePath) + + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val jarFileName = filePath.getName + val localFile = new File(driverDir, jarFileName) + + if (!localFile.exists()) { // May already exist if running multiple workers on one node + logInfo(s"Copying user file $filePath to $driverDir") + DotnetUtils.fetchFileWithbackwardCompatibility( + hdfsFilePath, + new File(driverDir), + sparkConf, + hadoopConf, + System.currentTimeMillis(), + useCache = false) + } + + if (!localFile.exists()) { + throw new Exception(s"Did not see expected $jarFileName in $driverDir") + } + + localFile + } + + private def closeBackend(dotnetBackend: DotnetBackend): Unit = { + logInfo("Closing DotnetBackend") + dotnetBackend.close() + } + + private def closeDotnetProcess(dotnetProcess: Process): Int = { + if (dotnetProcess == null) { + return -1 + } else if (!dotnetProcess.isAlive) { + return dotnetProcess.exitValue() + } + + // Try to (gracefully on Linux) kill the process and resort to force if interrupted + var returnCode = -1 + logInfo("Closing .NET process") + try { + dotnetProcess.destroy() + returnCode = dotnetProcess.waitFor() + } catch { + case _: InterruptedException => + logInfo( + "Thread interrupted while waiting for graceful close. 
Forcefully closing .NET process") + returnCode = dotnetProcess.destroyForcibly().waitFor() + case t: Throwable => + logThrowable(t) + } + + returnCode + } + + private def initializeSettings(args: Array[String]): (Boolean, Int) = { + val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase( + "debug") + var portNumber = 0 + if (runInDebugMode) { + if (args.length == 1) { + portNumber = DEBUG_PORT + } else if (args.length == 2) { + portNumber = Integer.parseInt(args(1)) + } + } + + (runInDebugMode, portNumber) + } + + private def logThrowable(throwable: Throwable): Unit = + logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}") +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala new file mode 100644 index 000000000..18ba4c6e5 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.internal.config.dotnet + +import org.apache.spark.internal.config.ConfigBuilder + +private[spark] object Dotnet { + val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf + .createWithDefault(10) + + val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK = + ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf + .createWithDefault(false) + + val ERROR_REDIRECITON_ENABLED = + ConfigBuilder("spark.nonjvm.error.forwarding.enabled").booleanConf + .createWithDefault(false) + + val ERROR_BUFFER_SIZE = + ConfigBuilder("spark.nonjvm.error.buffer.size") + .intConf + .checkValue(_ >= 0, "The error buffer size must not be negative") + .createWithDefault(10240) +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala new file mode 100644 index 000000000..3e3c3e0e3 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala @@ -0,0 +1,26 @@ + +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.mllib.api.dotnet + +import org.apache.spark.ml._ +import scala.collection.JavaConverters._ + +/** MLUtils object that hosts helper functions + * related to ML usage + */ +object MLUtils { + + /** A helper function to let pipeline accept java.util.ArrayList + * format stages in scala code + * @param pipeline - The pipeline to be set stages + * @param value - A java.util.ArrayList of PipelineStages to be set as stages + * @return The pipeline + */ + def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline = + pipeline.setStages(value.asScala.toArray) +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala new file mode 100644 index 000000000..5d06d4304 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import org.apache.spark.api.dotnet.CallbackClient +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.streaming.DataStreamWriter + +class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int) extends Logging { + def call(batchDF: DataFrame, batchId: Long): Unit = + callbackClient.send( + callbackId, + (dos, serDe) => { + serDe.writeJObj(dos, batchDF) + serDe.writeLong(dos, batchId) + }) +} + +object DotnetForeachBatchHelper { + def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = { + val dotnetForeachFunc = client match { + case Some(value) => new DotnetForeachBatchFunction(value, callbackId) + case None => throw new Exception("CallbackClient is null.") + } + + dsw.foreachBatch(dotnetForeachFunc.call _) + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala new file mode 100644 index 000000000..31dcd061b --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction, SimplePythonFunction} +import org.apache.spark.broadcast.Broadcast + +object SQLUtils { + + /** + * Exposes createPythonFunction to the .NET client to enable registering UDFs. + */ + def createPythonFunction( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVersion: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: PythonAccumulatorV2): PythonFunction = { + // From 3.4.0 use SimplePythonFunction. 
https://github.com/apache/spark/commit/18ff15729268def5ee1bdf5dfcb766bd1d699684 + SimplePythonFunction( + command, + envVars, + pythonIncludes, + pythonExec, + pythonVersion, + broadcastVars, + accumulator) + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/test/TestUtils.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/test/TestUtils.scala new file mode 100644 index 000000000..1cd45aa95 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/sql/test/TestUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.MemoryStream + +object TestUtils { + + /** + * Helper method to create typed MemoryStreams intended for use in unit tests. + * @param sqlContext The SQLContext. + * @param streamType The type of memory stream to create. This string is the `Name` + * property of the dotnet type. + * @return A typed MemoryStream. + */ + def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = { + import sqlContext.implicits._ + + streamType match { + case "Int32" => MemoryStream[Int] + case "String" => MemoryStream[String] + case _ => throw new Exception(s"$streamType not supported") + } + } +} diff --git a/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/util/dotnet/Utils.scala b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/util/dotnet/Utils.scala new file mode 100644 index 000000000..27e48c018 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/main/scala/org/apache/spark/util/dotnet/Utils.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.util.dotnet + +import java.io._ +import java.nio.file.attribute.PosixFilePermission +import java.nio.file.attribute.PosixFilePermission._ +import java.nio.file.{FileSystems, Files} +import java.util.{Timer, TimerTask} +import org.apache.spark.SparkConf +import org.apache.spark.SecurityManager +import org.apache.hadoop.conf.Configuration +import org.apache.spark.util.Utils +import java.io.File +import java.lang.NoSuchMethodException +import java.lang.reflect.InvocationTargetException +import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile} +import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK + +import scala.collection.JavaConverters._ +import scala.collection.Set + +/** + * Utility methods. + */ +object Utils extends Logging { + private val posixFilePermissions = Array( + OWNER_READ, + OWNER_WRITE, + OWNER_EXECUTE, + GROUP_READ, + GROUP_WRITE, + GROUP_EXECUTE, + OTHERS_READ, + OTHERS_WRITE, + OTHERS_EXECUTE) + + val supportPosix: Boolean = + FileSystems.getDefault.supportedFileAttributeViews().contains("posix") + + /** + * Provides a backward-compatible implementation of the `fetchFile` method + * from Apache Spark's `org.apache.spark.util.Utils` class. 
+ * + * This method handles differences in method signatures between Spark versions, + * specifically the inclusion or absence of a `SecurityManager` parameter. It uses + * reflection to dynamically resolve and invoke the correct version of `fetchFile`. + * + * @param url The source URL of the file to be fetched. + * @param targetDir The directory where the fetched file will be saved. + * @param conf The Spark configuration object used to determine runtime settings. + * @param hadoopConf Hadoop configuration settings for file access. + * @param timestamp A timestamp indicating the cache validity of the fetched file. + * @param useCache Whether to use Spark's caching mechanism to reuse previously downloaded files. + * @param shouldUntar Whether to untar the downloaded file if it is a tarball. Defaults to `true`. + * + * @return A `File` object pointing to the fetched and stored file. + * + * @throws IllegalArgumentException If neither method signature is found. + * @throws Throwable If an error occurs during reflection or method invocation. + * + * Note: + * - This method was introduced as a fix for DataBricks-specific file copying issues + * and was referenced in PR #1048. + * - Reflection is used to ensure compatibility across Spark environments. + */ + def fetchFileWithbackwardCompatibility( + url: String, + targetDir: File, + conf: SparkConf, + hadoopConf: Configuration, + timestamp: Long, + useCache: Boolean, + shouldUntar: Boolean = true): File = { + + val signatureWithSecurityManager = Array( + classOf[String], + classOf[File], + classOf[SparkConf], + classOf[SecurityManager], + classOf[Configuration], + java.lang.Long.TYPE, + java.lang.Boolean.TYPE, + java.lang.Boolean.TYPE + ) + + val signatureWithoutSecurityManager = Array( + classOf[String], + classOf[File], + classOf[SparkConf], + classOf[Configuration], + classOf[Long], + classOf[Boolean], + classOf[Boolean] + ) + + val utilsClass = Class.forName("org.apache.spark.util.Utils$") + val utilsObject = utilsClass.getField("MODULE$").get(null) + + val (needSecurityManagerArg, method) = { + try { + (true, utilsClass.getMethod("fetchFile", signatureWithSecurityManager: _*)) + } catch { + case _: NoSuchMethodException => + (false, utilsClass.getMethod("fetchFile", signatureWithoutSecurityManager: _*)) + } + } + + val args: Seq[Any] = + Seq( + url, + targetDir, + conf + ) ++ (if (needSecurityManagerArg) Seq(null) else Nil) ++ Seq( + hadoopConf, + timestamp, + useCache, + shouldUntar) + + // Unwrap InvocationTargetException to preserve exception in case of errors: + try { + method.invoke(utilsObject, args.map(_.asInstanceOf[Object]): _*).asInstanceOf[File] + } catch { + case e: InvocationTargetException => + throw e.getCause() + } + } + + /** + * Compress all files under given directory into one zip file and drop it to the target directory + * + * @param sourceDir source directory to zip + * @param targetZipFile target zip file + */ + def zip(sourceDir: File, targetZipFile: File): Unit = { + var fos: FileOutputStream = null + var zos: ZipArchiveOutputStream = null + try { + fos = new FileOutputStream(targetZipFile) + zos = new ZipArchiveOutputStream(fos) + + val sourcePath = sourceDir.toPath + FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file => + var in: FileInputStream = null + try { + val path = file.toPath + val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString) + if (supportPosix) { + entry.setUnixMode( + permissionsToMode(Files.getPosixFilePermissions(path).asScala) + | (if 
(entry.getName.endsWith(".exe")) 0x1ED else 0x1A4)) + } else if (entry.getName.endsWith(".exe")) { + entry.setUnixMode(0x1ED) // 755 + } else { + entry.setUnixMode(0x1A4) // 644 + } + zos.putArchiveEntry(entry) + + in = new FileInputStream(file) + IOUtils.copy(in, zos) + zos.closeArchiveEntry() + } finally { + IOUtils.closeQuietly(in) + } + } + } finally { + IOUtils.closeQuietly(zos) + IOUtils.closeQuietly(fos) + } + } + + /** + * Unzip a file to the given directory + * + * @param file file to be unzipped + * @param targetDir target directory + */ + def unzip(file: File, targetDir: File): Unit = { + var zipFile: ZipFile = null + try { + targetDir.mkdirs() + zipFile = new ZipFile(file) + zipFile.getEntries.asScala.foreach { entry => + val targetFile = new File(targetDir, entry.getName) + + if (targetFile.exists()) { + logWarning( + s"Target file/directory $targetFile already exists. Skip it for now. " + + s"Make sure this is expected.") + } else { + if (entry.isDirectory) { + targetFile.mkdirs() + } else { + targetFile.getParentFile.mkdirs() + val input = zipFile.getInputStream(entry) + val output = new FileOutputStream(targetFile) + IOUtils.copy(input, output) + IOUtils.closeQuietly(input) + IOUtils.closeQuietly(output) + if (supportPosix) { + val permissions = modeToPermissions(entry.getUnixMode) + // When run in Unix system, permissions will be empty, thus skip + // setting the empty permissions (which will empty the previous permissions). + if (permissions.nonEmpty) { + Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava) + } + } + } + } + } + } catch { + case e: Exception => logError("exception caught during decompression:" + e) + } finally { + ZipFile.closeQuietly(zipFile) + } + } + + /** + * Exits the JVM, trying to do it nicely, otherwise doing it nastily. + * + * @param status the exit status, zero for OK, non-zero for error + * @param maxDelayMillis the maximum delay in milliseconds + */ + def exit(status: Int, maxDelayMillis: Long) { + try { + logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis") + + // setup a timer, so if nice exit fails, the nasty exit happens + val timer = new Timer() + timer.schedule(new TimerTask() { + + override def run() { + Runtime.getRuntime.halt(status) + } + }, maxDelayMillis) + // try to exit nicely + System.exit(status); + } catch { + // exit nastily if we have a problem + case _: Throwable => Runtime.getRuntime.halt(status) + } finally { + // should never get here + Runtime.getRuntime.halt(status) + } + } + + /** + * Exits the JVM, trying to do it nicely, wait 1 second + * + * @param status the exit status, zero for OK, non-zero for error + */ + def exit(status: Int): Unit = { + exit(status, 1000) + } + + /** + * Normalize the Spark version by taking the first three numbers. + * For example: + * x.y.z => x.y.z + * x.y.z.xxx.yyy => x.y.z + * x.y => x.y + * + * @param version the Spark version to normalize + * @return Normalized Spark version. + */ + def normalizeSparkVersion(version: String): String = { + version + .split('.') + .take(3) + .zipWithIndex + .map({ + case (element, index) => { + index match { + case 2 => element.split("\\D+").lift(0).getOrElse("") + case _ => element + } + } + }) + .mkString(".") + } + + /** + * Validates the normalized spark version by verifying: + * - Spark version starts with sparkMajorMinorVersionPrefix. + * - If ignoreSparkPatchVersion is + * - true: valid + * - false: check if the spark version is in supportedSparkVersions. 
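+   * Throws IllegalArgumentException if the version is not supported.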
+ * @param ignoreSparkPatchVersion Ignore spark patch version. + * @param sparkVersion The spark version. + * @param normalizedSparkVersion: The normalized spark version. + * @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against. + * @param supportedSparkVersions The set of supported spark versions. + */ + def validateSparkVersions( + ignoreSparkPatchVersion: Boolean, + sparkVersion: String, + normalizedSparkVersion: String, + supportedSparkMajorMinorVersionPrefix: String, + supportedSparkVersions: Set[String]): Unit = { + if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) { + throw new IllegalArgumentException( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.") + } else if (ignoreSparkPatchVersion) { + logWarning( + s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.") + } else if (!supportedSparkVersions(normalizedSparkVersion)) { + val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ") + throw new IllegalArgumentException( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported versions: '$supportedVersions'." + + "Patch version can be ignored, use setting 'spark.dotnet.ignoreSparkPatchVersionCheck'" ) + } + } + + private[spark] def listZipFileEntries(file: File): Array[String] = { + var zipFile: ZipFile = null + try { + zipFile = new ZipFile(file) + zipFile.getEntries.asScala.map(_.getName).toArray + } finally { + ZipFile.closeQuietly(zipFile) + } + } + + private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = { + posixFilePermissions.foldLeft(0) { (mode, perm) => + (mode << 1) | (if (permissions.contains(perm)) 1 else 0) + } + } + + private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = { + posixFilePermissions.zipWithIndex + .filter { case (_, i) => (mode & (0x100 >>> i)) != 0 } + .map(_._1) + .toSet + } +} diff --git a/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..7088537e1 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + + +package org.apache.spark.api.dotnet + +import Extensions._ +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var handler: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + handler = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + serDe.writeInt(m, 1) // processId + serDe.writeInt(m, 1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = handler.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('j', reply.readByte()) + assertEquals(1, reply.readInt()) + val trackingId = new String(reply.readNBytes(1), "UTF-8") + assertEquals("1", trackingId) + val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull + assertEquals(classOf[CallbackClient], client.getClass) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..445486bbd --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var backend: DotnetBackend = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // Specifying port = 0 to select port dynamically. 
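+    // The first registration succeeds; the second attempt below is expected to throw.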
+ backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(backend.callbackClient.isDefined) + assertThrows(classOf[Exception], () => { + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + }) + } +} diff --git a/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..c6904403b --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +private[dotnet] object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readNBytes(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..43ae79005 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + val thirdId = tracker.put(new Object) + + tracker.clear() + + assert(tracker.get(firstId).isEmpty) + assert(tracker.get(secondId).isEmpty) + assert(tracker.get(thirdId).isEmpty) + } + + @Test + def shouldResetCounter(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + + tracker.clear() + + val thirdId = tracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } +} diff --git a/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala new file mode 100644 index 000000000..41401d680 --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -0,0 +1,373 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions._ +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters._ + +@Test +class SerDeTest { + private var serDe: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + serDe = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val input = givenInput(in => { + in.writeByte('n') + }) + + assertEquals(null, serDe.readObject(input)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val input = givenInput(in => { + in.writeByte('_') + }) + + assertThrows(classOf[IllegalArgumentException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadInteger(): Unit = { + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) + }) + + assertEquals(42, serDe.readObject(input)) + } + + @Test + def shouldReadLong(): Unit = { + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) + }) + + assertEquals(42L, serDe.readObject(input)) + } + + @Test + def shouldReadDouble(): Unit = { + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) + }) + + assertEquals(42.42, serDe.readObject(input)) + } + + @Test + def shouldReadBoolean(): Unit = { + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) + }) + + assertEquals(true, serDe.readObject(input)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, serDe.readObject(input)) + } + + @Test + def shouldReadMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value + }) + + assertEquals( + mapAsJavaMap(Map( + 11 -> true, + 22 -> 42.42, + 33 -> null)), + serDe.readObject(input)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size + }) + + assertEquals(mapAsJavaMap(Map()), serDe.readObject(input)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length + }) + + 
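+    // An empty integer ('i') list deserializes to an empty Array[Int].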
assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type + }) + + assertThrows(classOf[IllegalArgumentException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadDate(): Unit = { + val input = givenInput(in => { + val date = "2020-12-31" + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val input = givenInput(in => { + val objectIndex = "1" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, serDe.readObject(input)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val input = givenInput(in => { + val objectIndex = "42" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows(classOf[NoSuchElementException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2st row + in.writeByte('b') // type of 1st element in 2nd row + in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3nd element in 2nd row + in.writeLong(99) + }) + + assertEquals( + seqAsJavaList(Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99)))), + serDe.readObject(input)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) + }) + + assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) + } + + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, null) + serDe.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + serDe.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type + assertEquals(in.readInt(), sparkDotnet.length) // length + assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + 
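+    // Float values are widened to double on the wire (see SerDe.writeObject).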
val in = whenOutput(out => { + serDe.writeObject(out, 42.24f.asInstanceOf[Object]) + serDe.writeObject(out, 42L.asInstanceOf[Object]) + serDe.writeObject(out, 42.asInstanceOf[Object]) + serDe.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + serDe.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + serDe.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + serDe.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifiers + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)) + } + + private def whenOutput = givenInput _ + + private def assertEndOfStream (in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } +} diff --git a/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala new file mode 100644 index 000000000..736aa20bd --- /dev/null +++ b/src/scala/microsoft-spark-3-5/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.util.dotnet + +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK +import org.junit.Assert.{assertEquals, assertThrows} +import org.junit.Test + +@Test +class UtilsTest { + + @Test + def shouldIgnorePatchVersion(): Unit = { + val sparkVersion = "3.5.1" + val sparkMajorMinorVersionPrefix = "3.5" + val supportedSparkVersions = Set[String]("3.5.0") + + Utils.validateSparkVersions( + true, + sparkVersion, + Utils.normalizeSparkVersion(sparkVersion), + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + } + + @Test + def shouldThrowForUnsupportedVersion(): Unit = { + val sparkVersion = "3.5.1" + val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) + val sparkMajorMinorVersionPrefix = "3.5" + val supportedSparkVersions = Set[String]("3.5.0") + + val exception = assertThrows( + classOf[IllegalArgumentException], + () => { + Utils.validateSparkVersions( + false, + sparkVersion, + normalizedSparkVersion, + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + }) + + assertEquals( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'." + + "Patch version can be ignored, use setting 'spark.dotnet.ignoreSparkPatchVersionCheck'", + + exception.getMessage) + } + + @Test + def shouldThrowForUnsupportedMajorMinorVersion(): Unit = { + val sparkVersion = "3.3.0" + val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) + val sparkMajorMinorVersionPrefix = "3.5" + val supportedSparkVersions = Set[String]("3.5.0") + + val exception = assertThrows( + classOf[IllegalArgumentException], + () => { + Utils.validateSparkVersions( + false, + sparkVersion, + normalizedSparkVersion, + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + }) + + assertEquals( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.", + exception.getMessage) + } +} diff --git a/src/scala/pom.xml b/src/scala/pom.xml index 06d30cb13..c69f4e035 100644 --- a/src/scala/pom.xml +++ b/src/scala/pom.xml @@ -16,6 +16,7 @@ microsoft-spark-3-1 microsoft-spark-3-2 microsoft-spark-3-3 + microsoft-spark-3-5